* feature(intelligent-search): Added API to connect to Llama.cpp in EC2 and filter the response into OR filters
* Updated SQL-to-filter script and added init.sql for tables
* feature(intelligent-search): Replaced llama.cpp with Llama running on GPU, now contained in the API
* Updated Dockerfile to use GPU and download the LLM from S3
* Added link to facebook/research/llama
* Updated Dockerfile
* Updated requirements and Dockerfile base images
* Fixed minor issues: unused variables, updated COPY and replaced values
* fix(intelligent-search): Fixed WHERE statement filter
* feature(smart-charts): Added method to create charts using Llama
* style(intelligent-search): Changed attribute names to match the frontend format
* fix(intelligent-search): Fixed vulnerability in requirements and other small issues
* Added some tests before deploying the service
* Added semaphore to handle concurrency

---------

Co-authored-by: EC2 Default User <ec2-user@ip-10-0-2-226.eu-central-1.compute.internal>
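
The semaphore mentioned in the last item is not part of the file shown below (it most likely lives next to core.llm_api), so the following is only a rough sketch of that concurrency pattern; the helper name guarded_generate and the single-slot limit are assumptions, not the PR's actual code.

import asyncio

# Illustrative sketch only: serialise access to the single GPU-backed model.
llm_semaphore = asyncio.Semaphore(1)  # one generation at a time

async def guarded_generate(llm_model, prompts, **gen_kwargs):
    # Concurrent requests queue here instead of hitting the model at the same time.
    async with llm_semaphore:
        # Run the blocking generation off the event loop thread.
        return await asyncio.to_thread(llm_model.execute_prompts, prompts, **gen_kwargs)
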
78 lines | 2.9 KiB | Python

from contextlib import asynccontextmanager
from time import time
from typing import List, Optional

from decouple import config
from fastapi import FastAPI, Depends

from auth.auth_key import api_key_auth
from core.llm_api import LLM_Model
from utils import parameters, declarations
from utils.contexts import search_context_v2, search_context_v3
from utils.contexts_charts import chart_context_v2, formatable_end
from utils.sql_to_filters import filter_sql_where_statement


class FastAPI_with_LLM(FastAPI):
    """FastAPI application that holds the LLM instance built at startup."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.llm_model = None

    def build_llm(self, ckpt_dir: str, tokenizer_path: str, max_seq_len: int, max_batch_size: int):
        self.llm_model = LLM_Model(ckpt_dir=ckpt_dir,
                                   tokenizer_path=tokenizer_path,
                                   max_seq_len=max_seq_len,
                                   max_batch_size=max_batch_size)

    def clear(self):
        del self.llm_model


@asynccontextmanager
async def lifespan(app: FastAPI_with_LLM):
    # Load the model once at startup and release it on shutdown.
    app.build_llm(ckpt_dir=parameters.ckpt_dir,
                  tokenizer_path=parameters.tokenizer_path,
                  max_seq_len=parameters.max_seq_len,
                  max_batch_size=parameters.max_batch_size)
    yield
    app.clear()


app = FastAPI_with_LLM(lifespan=lifespan)
@app.post("/llm/completion", dependencies=[Depends(api_key_auth)])
|
|
async def predict(msg: declarations.LLMQuestion):
|
|
question = msg.question
|
|
t1 = time()
|
|
result = app.llm_model.execute_prompts([search_context_v3.format(user_question=question)],
|
|
temperature=parameters.temperature,
|
|
top_p=parameters.top_p,
|
|
max_gen_len=parameters.max_gen_len)
|
|
t2 = time()
|
|
processed = filter_sql_where_statement(result[0]['generation'])
|
|
if processed is None:
|
|
return {"content": None, "raw_response": result, "inference_time": t2-t1}
|
|
return {"content": processed, "raw_response": result, "inference_time": t2-t1}
|
|
|
|
|
|
@app.post("/llm/completion/charts", dependencies=[Depends(api_key_auth)])
|
|
async def chart_predict(msg: declarations.LLMQuestion):
|
|
question = msg.question
|
|
t1 = time()
|
|
result = app.llm_model.execute_prompts([chart_context_v2+formatable_end.format(user_question=question)],
|
|
temperature=parameters.temperature,
|
|
top_p=parameters.top_p,
|
|
max_gen_len=parameters.max_gen_len)
|
|
t2 = time()
|
|
processed = result[0]['generation']
|
|
if processed is None:
|
|
return {"content": None, "raw_response": result, "inference_time": t2-t1}
|
|
return {"content": processed, "raw_response": result, "inference_time": t2-t1}
|
|
|
|
|
|


@app.get('/')
async def health():
    return {'status': 200}
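
For reference, a call to the completion endpoint might look roughly like the sketch below; the host, port, API-key header name and example question are assumptions, since auth.auth_key.api_key_auth and the deployment details are not shown here. The response keys match the handlers above.

import requests

# Hypothetical client call; adjust URL and auth header to the real deployment.
response = requests.post(
    "http://localhost:8000/llm/completion",      # host/port are an assumption
    headers={"X-API-Key": "<your-api-key>"},     # header name is an assumption
    json={"question": "Which products sold best last month?"},
    timeout=120,
)
payload = response.json()
print(payload["content"])         # parsed filters, or None if the generation could not be parsed
print(payload["inference_time"])  # seconds spent on generation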