feat(intelligent-search): intelligent search service (#1545)

* feature(intelligent-search): Added API to connect to Llama.cpp in EC2 and filter the response into OR filters

* updated sql to filter script and added init.sql for tables

* feature(intelligent-search): Changed llama.cpp for llama in GPU now contained in API

* Updated Dockerfile to use GPU and download LLM from S3

* Added link to facebook/research/llama

* Updated Dockerfile

* Updated requirements and Dockerfile base images

* fixed minor issues: Not used variables, updated COPY and replace values

* fix(intelligent-search): Fixed WHERE statement filter

* feature(smart-charts): Added method to create charts using llama. style(intelligent-search): Changed names for attributes to match frontend format. fix(intelligent-search): Fixed vulnerability in requiments and small issues fix

* Added some test before deploying the service

* Added semaphore to handle concurrency

---------

Co-authored-by: EC2 Default User <ec2-user@ip-10-0-2-226.eu-central-1.compute.internal>
This commit is contained in:
MauricioGarciaS 2023-10-25 10:13:58 +02:00 committed by GitHub
parent c836768d72
commit 16efb1316c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 863 additions and 0 deletions

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "ee/intelligent_search/llama"]
path = ee/intelligent_search/llama
url = https://github.com/facebookresearch/llama.git

View file

@ -0,0 +1,28 @@
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
COPY requirements.txt .
RUN pip install -r requirements.txt
WORKDIR api
COPY llama/llama/*.py llama/
COPY auth/*.py auth/
COPY crons/*.py crons/
COPY utils/*.py utils/
COPY core/*.py core/
COPY *.sh ./
COPY *.py ./
ENV \
RANK=0 \
WORLD_SIZE=1 \
LOCAL_RANK=0 \
MASTER_PORT=29500 \
MASTER_ADDR=localhost \
CHECKPOINT_DIR=/api/llama-2-7b-chat/ \
TOKENIZER_PATH=/api/tokenizer.model \
S3_LLM_DIR= \
S3_TOKENIZER_PATH= \
AWS_ACCESS_KEY_ID= \
AWS_SECRET_ACCESS_KEY= \
LLAMA_API_AUTH_KEY=
EXPOSE 8082
ENTRYPOINT ./entrypoint.sh

View file

@ -0,0 +1,33 @@
from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, Depends, status
from decouple import config
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class AuthHandler:
def __init__(self):
"""
Authorization method using an API key.
"""
self.__api_keys = [config("LLAMA_API_AUTH_KEY")]
def __contains__(self, api_key):
return api_key in self.__api_keys
def add_key(self, key):
"""Adds new key for authentication."""
self.__api_keys.append(key)
auth_method = AuthHandler()
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
"""Method to verify auth."""
global auth_method
if api_key not in auth_method:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Forbidden"
)

View file

@ -0,0 +1,80 @@
from utils.ch_client import ClickHouseClient
from core.llm_api import LLM_Model
from threading import Semaphore
from decouple import config
import logging
FEEDBACK_LLAMA_TABLE_NAME = config('FEEDBACK_LLAMA_TABLE_NAME')
class QnA:
user_question: str
llama_response: str
user_identifier: int
project_identifier: int
def __init__(self, question: str, answer: str, user_id: int, project_id: int):
self.user_question = question
self.llama_response = answer
self.user_identifier = user_id
self.project_identifier = project_id
def __preprocess_value(**args):
processed = {}
for k,v in args.values():
if __annotations__[k] == str:
v = v.replace("'", "''")
processed[k] = f"'{v}'"
else:
processed[k] = str(v)
return processed
def to_sql(self):
processed = __preproces_value({'user_question': self.user_question,
'llama_response': self.llama_response,
'user_identifier': self.user_identifier,
'project_identifier': self.project_identifier
})
return "({project_id}, {user_id}, {user_question}, {llama_response})".format(processed)
class RequestsQueue:
__q_n_a_queue: list[QnA] = list()
queue_current_size = 0
def __init__(self, size: int = 100, max_wait_time: int = 1):
self.queue_size = size
self.max_wait_time = max_wait_time
self.semaphore = Semaphore(1)
def add_to_queue(self, question: str, answer: str, user_id: int, project_id: int):
self.__q_n_a_queue.append(
QnA(question=question,
answer=answer,
user_id=user_id,
project_id=project_id)
)
self.queue_current_size += 1
def flush_queue(self):
replace_sql = ', '.join([question_and_answer.to_sql() for question_and_answer in self.__q_n_a_queue])
query = "INSERT INTO {table_name} (projectId, userId, userQuestion, llamaResponse) VALUES {replace_sql}".format(
table_name=FEEDBACK_LLAMA_TABLE_NAME,
replace_sql=replace_sql)
try:
with ClickHouseClient() as conn:
conn.execute(query)
except Exception as e:
logging.error(f'[Flush Queue Error] {repr(e)}')
def start(self, llm_model: LLM_Model):
...
def recurrent_flush(self):
if self.semaphore.aquire(timeout=10):
# TODO: Process requests
self.semaphore.release()
else:
raise TimeoutError('LLM model overloaded with requests')

View file

@ -0,0 +1,55 @@
from llama import Llama, Dialog
from decouple import config
from utils.contexts import search_context_v2
from threading import Semaphore
class LLM_Model:
def __init__(self, **params):
"""
Initialization of pre-trained model.
Args:
ckpt_dirckpt_dir (str): The directory containing checkpoint files for the pretrained model.
tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
"""
self.generator = Llama.build(**params)
self.max_queue_size = config('LLM_MAX_QUEUE_SIZE', cast=int, default=1)
self.semaphore = Semaphore(config('LLM_MAX_BATCH_SIZE', cast=int, default=1))
self.queue = list()
self.responses = list()
def __execute_prompts(self, prompts, **params):
"""
Entry point of the program for generating text using a pretrained model.
Args:
prompts (list str): batch of prompts to be asked to LLM.
temperature (float, optional): The temperature value for controlling randomness in generation. Defaults to 0.6.
top_p (float, optional): The top-p sampling parameter for controlling diversity in generation. Defaults to 0.9.
max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
"""
return self.generator.text_completion(
prompts, **params)
def execute_prompts(self, prompts, **params):
if self.semaphore.acquire(timeout=10):
results = self.__execute_prompts(prompts, **params)
self.semaphore.release()
return results
else:
raise TimeoutError("[Error] LLM is over-requested")
async def queue_prompt(self, prompt, force=False, **params):
if self.semaphore.acquire(timeout=10):
if force:
self.responses = execute_prompts(self.queue + [prompt])
else:
self.queue.append(prompt)
# Wait until response exists
self.semaphore.release()
else:
raise TimeoutError("[Error] LLM is over-requested")

View file

@ -0,0 +1,12 @@
from apscheduler.triggers.interval import IntervalTrigger
from core.llm_api import llm_api
async def force_run():
llm_api.send_question_to_llm()
async def force_send_request():
...
cron_jobs = [
{"func": force_send_request, "trigger": IntervalTrigger(seconds=5), "misfire_grace_time": 60, "max_instances": 1},
]

View file

@ -0,0 +1,2 @@
aws s3 cp --recursive {{S3_LLM_DIR}} {{CHECKPOINT_DIR}}
aws s3 cp {{S3_TOKENIZER_PATH}} {{TOKENIZER_PATH}}

View file

@ -0,0 +1,6 @@
find ./ -type f -name "download_llm.sh" -exec sed -i "s#{{S3_LLM_DIR}}#${S3_LLM_DIR}#g" {} \;
find ./ -type f -name "download_llm.sh" -exec sed -i "s#{{CHECKPOINT_DIR}}#${CHECKPOINT_DIR}#g" {} \;
find ./ -type f -name "download_llm.sh" -exec sed -i "s#{{S3_TOKENIZER_PATH}}#${S3_TOKENIZER_PATH}#g" {} \;
find ./ -type f -name "download_llm.sh" -exec sed -i "s#{{TOKENIZER_PATH}}#${TOKENIZER_PATH}#g" {} \;
./download_llm.sh
pytest && uvicorn main:app --host 0.0.0.0 --port 8082

@ -0,0 +1 @@
Subproject commit b00a461a6582196d8f488c73465f6c87f384a052

View file

@ -0,0 +1,78 @@
from typing import List, Optional
from decouple import config
from time import time
from fastapi import FastAPI, Depends
from contextlib import asynccontextmanager
from utils.contexts import search_context_v2, search_context_v3
from utils.contexts_charts import chart_context_v2, formatable_end
from utils.sql_to_filters import filter_sql_where_statement
from utils import parameters, declarations
from core.llm_api import LLM_Model
from auth.auth_key import api_key_auth
class FastAPI_with_LLM(FastAPI):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.llm_model = None
def build_llm(self, ckpt_dir: str, tokenizer_path: str, max_seq_len: int, max_batch_size: int):
self.llm_model = LLM_Model(ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size)
def clear(self):
del self.llm_model
@asynccontextmanager
async def lifespan(app: FastAPI_with_LLM):
app.build_llm(ckpt_dir=parameters.ckpt_dir,
tokenizer_path=parameters.tokenizer_path,
max_seq_len=parameters.max_seq_len,
max_batch_size=parameters.max_batch_size)
yield
app.clear()
app = FastAPI_with_LLM(lifespan=lifespan)
@app.post("/llm/completion", dependencies=[Depends(api_key_auth)])
async def predict(msg: declarations.LLMQuestion):
question = msg.question
t1 = time()
result = app.llm_model.execute_prompts([search_context_v3.format(user_question=question)],
temperature=parameters.temperature,
top_p=parameters.top_p,
max_gen_len=parameters.max_gen_len)
t2 = time()
processed = filter_sql_where_statement(result[0]['generation'])
if processed is None:
return {"content": None, "raw_response": result, "inference_time": t2-t1}
return {"content": processed, "raw_response": result, "inference_time": t2-t1}
@app.post("/llm/completion/charts", dependencies=[Depends(api_key_auth)])
async def chart_predict(msg: declarations.LLMQuestion):
question = msg.question
t1 = time()
result = app.llm_model.execute_prompts([chart_context_v2+formatable_end.format(user_question=question)],
temperature=parameters.temperature,
top_p=parameters.top_p,
max_gen_len=parameters.max_gen_len)
t2 = time()
processed = result[0]['generation']
if processed is None:
return {"content": None, "raw_response": result, "inference_time": t2-t1}
return {"content": processed, "raw_response": result, "inference_time": t2-t1}
@app.get('/')
async def health():
return {'status': 200}

View file

@ -0,0 +1,25 @@
# General utils
pydantic==2.3.0
requests==2.31.0
python-decouple==3.8
certifi==2023.7.22
# AWS utils
awscli==1.29.53
# ML modules
# torch==2.0.1
fairscale==0.4.13
sentencepiece==0.1.99
# Serving modules
fastapi==0.103.1
httpx==0.25.0
apscheduler==3.10.4
uvicorn==0.23.2
# Observability modules
traceloop-sdk==0.0.37
# Test
pytest==7.4.2

View file

@ -0,0 +1,19 @@
CREATE TABLE IF NOT EXISTS mlruns.public.llm_data
(
user_id TEXT,
project_id BIGINT,
request TEXT,
response TEXT,
accuracy BOOL
);
CREATE TABLE IF NOT EXISTS mlruns.public.llm_metrics
(
load_time BIGINT,
sample_time BIGINT,
prompt_eval_time BIGINT,
eval_time BIGINT,
total_time BIGINT,
PARAMS jsonb
);

View file

@ -0,0 +1,24 @@
from fastapi.testclient import TestClient
from main import app
from decouple import config
from os import path
client = TestClient(app)
def test_alive():
response = client.get("/")
assert response.status_code == 200
def test_correct_download():
llm_dir = config('CHECKPOINT_DIR')
tokenizer_path = config('TOKENIZER_PATH')
assert path.exists(tokenizer_path) == True
assert path.exists(llm_dir) == True
def test_correct_upload():
with TestClient(app) as client_statup:
response = client_statup.post('llm/completion', headers={'Authorization': 'Bearer ' + config('LLAMA_API_AUTH_KEY', cast=str), 'Content-Type': 'application/json'}, json={"question": "Show me the sessions from Texas", "userId": 0, "projectId": 0})
assert response.status_code == 200

View file

@ -0,0 +1,58 @@
import logging
import clickhouse_driver
from decouple import config
logging.basicConfig(level=config("LOGLEVEL", default=logging.INFO))
settings = {}
if config('ch_timeout', cast=int, default=-1) > 0:
logging.info(f"CH-max_execution_time set to {config('ch_timeout')}s")
settings = {**settings, "max_execution_time": config('ch_timeout', cast=int)}
if config('ch_receive_timeout', cast=int, default=-1) > 0:
logging.info(f"CH-receive_timeout set to {config('ch_receive_timeout')}s")
settings = {**settings, "receive_timeout": config('ch_receive_timeout', cast=int)}
class ClickHouseClient:
__client = None
def __init__(self, database=None):
self.__client = clickhouse_driver.Client(host=config("ch_host"),
database=database if database else config("ch_database",
default="default"),
user=config("ch_user", default="default"),
password=config("ch_password", default=""),
port=config("ch_port", cast=int),
settings=settings,) \
if self.__client is None else self.__client
def __enter__(self):
return self
def execute(self, query, params=None, **args):
try:
results = self.__client.execute(query=query, params=params, with_column_types=True, **args)
keys = tuple(x for x, y in results[1])
return [dict(zip(keys, i)) for i in results[0]]
except Exception as err:
logging.error("--------- CH QUERY EXCEPTION -----------")
logging.error(self.format(query=query, params=params))
logging.error("--------------------")
raise err
def insert(self, query, params=None, **args):
return self.__client.execute(query=query, params=params, **args)
def client(self):
return self.__client
def format(self, query, params):
if params is None:
return query
return self.__client.substitute_params(query, params, self.__client.connection.context)
def __exit__(self, *args):
pass

View file

@ -0,0 +1,227 @@
search_context = """Llama_AI is a programmer that translates text from [[USER_NAME]] into filters for a searching bar. The filters are Click, Text_Input, Visited_URL, Custom_Events, Network_Request, GraphQL, State_Action, Error_Message, Issue, User_OS, User_Browser, User_Device, Platform, Version_ID, Referrer, Duration, User_Country, User_City, User_State, User_Id, User_Anonymous_Id, DOM_Complete, Larges_Contentful_Paint, Time_to_First_Byte, Avg_CPU_Load, Avg_Memory_Usage, Failed_Request and Plan.
* Click is a string whose value X means that during the session the user clicked in the X
* Text_Input is a string whose value X means that during the sessions the user typed X
* Visited_URL is a string whose value X means that the user X visited the url path X
* Custom_Events is a string whose value X means that this event happened during the session
* Network_Request is a dictionary that contains an url, status_code, method and duration
* GraphQL is a dictionary that contains a name, a method, a request_body and a response_body
* State_Action is a integer
* Error Message is a string representing the error that arised in the session
* Referrer is a string representing the url that refered to the current site
* Duration is an integer representing the lenght of the session in minutes
* User_Country is a string representing the Country of the session
* User_City is a string representing the City of the session
* User_State is a string representing the State of the City where the session was recorded
* User_Id is a string representing the id of the user
* User_AnonymousId is a string representing the anonymous id of the user
* DOM_Complete is a tuple (integer, string) representing the time to render the url and the url string
* Largest_Contentful_Paint is a tuple (integer, string) representing the time to load the heaviest content and the url string
* Time_to_First_Byte is a tuple (integer, string) representing the time to get the first response byte from url and the url string
* Avg_CPU_Load is a tuple (integer, string) representing the porcentage of average cpu load in the url and the url string
* Avg_Memory_Usage is a tuple (integer, string) representing the porcentage of average memory usage in the url and the url string
* Failed_Request is a string representing an url that had a Failed Request event
* Plan is a string that could be 'pay_as_you_go', 'trial', 'free', 'enterprise'
The expected response should be a SQL query that contains the text from [[USER_NAME]] translated into conditions in the WHERE clause. All [[USER_NAME]] requests must be answered only with a SQL request assuming the table name will be sessions.
{user_question}
"""
search_context_v2 = """[[AI_BOT]]: We have a SQL table called sessions that contains the columns: Click, Text_Input, Visited_URL, Custom_Events, Network_Request, GraphQL, State_Action, Error_Message, Issue, User_OS, User_Browser, User_Device, Platform, Version_ID, Referrer, Duration, User_Country, User_City, User_State, User_Id, User_Anonymous_Id, DOM_Complete, Larges_Contentful_Paint, Time_to_First_Byte, Avg_CPU_Load, Avg_Memory_Usage, Failed_Request and Plan.
[[USER]]: What is the attribute of the Click column?
[[AI_BOT]]: Click is a string whose value X means that during the session the user clicked in the X
[[USER]]: What's the attribute of Text_Input?
[[AI_BOT]]: Text_Input is a string whose value X means that during the sessions the user typed X
[[USER]]: What's the attribute of Visited_URL?
[[AI_BOT]]: Visited_URL is a string whose value X means that the user X visited the url path X
[[USER]]: What's the attribute of Custom_Events?
[[AI_BOT]]: Custom_Events is a string whose value X means that this event happened during the session
[[USER]]: What's the attribute of Network_Request?
[[AI_BOT]]: Network_Request is a dictionary that contains an url, status_code, method and duration
[[USER]]: What's the attribute of GraphQL
[[AI_BOT]]: GraphQL is a dictionary that contains a name, a method, a request_body and a response_body
[[USER]]: What's the attribute of State_Action?
[[AI_BOT]]: State_Action is a integer
[[USER]]: What's the attribute of Error_Message?
[[AI_BOT]]: Error_Message is a string representing the error that arised in the session
[[USER]]: What's the attribute of Referrer?
[[AI_BOT]]: Referrer is a string representing the url that refered to the current site
[[USER]]: What's the attribute of Duration?
[[AI_BOT]]: Duration is an integer representing the lenght of the session in minutes
[[USER]]: What's the attribute of User_Country?
[[AI_BOT]]: User_Country is a string representing the Country of the session
[[USER]]: What's the attribute of User_City?
[[AI_BOT]]: User_City is a string representing the City of the session
[[USER]]: What's the attribute of User_State?
[[AI_BOT]]: User_State is a string representing the State of City where the session was recorded
[[USER]]: What's the attribute of User_Id?
[[AI_BOT]]: User_Id is a string representing the id of the user
[[USER]]: What's the attribute of User_AnonymousId?
[[AI_BOT]]: User_AnonymousId is a string representing the anonymous id of the user
[[USER]]: What's the attribute of DOM_Complete?
[[AI_BOT]]: DOM_Complete is a tuple (integer, string) representing the time to render the url and the url string
[[USER]]: What's the attribute of Largest_Contentful_Paint?
[[AI_BOT]]: Largest_Contentful_Paint is a tuple (integer, string) representing the time to load the heaviest content and the url string
[[USER]]: What's the attribute of
[[AI_BOT]]: Time_to_First_Byte is a tuple (integer, string) representing the time to get the first response byte from url and the url string
[[USER]]: What's the attribute of Avg_CPU_Load?
[[AI_BOT]]: Avg_CPU_Load is a tuple (integer, string) representing the porcentage of average cpu load in the url and the url string
[[USER]]: What's the attribute of Avg_Memory_Usage?
[[AI_BOT]]: Avg_Memory_Usage is a tuple (integer, string) representing the porcentage of average memory usage in the url and the url string
[[USER]]: What's the attribute of Failed_Request?
[[AI_BOT]]: Failed_Request is a string representing an url that had a Failed Request event
[[USER]]: What's the attribute of Plan?
[[AI_BOT]]: Plan is a string that could be 'pay_as_you_go', 'trial', 'free', 'enterprise'
[[USER]]: Can you translate the following text into SQL query: {user_question}
[[AI_BOT]]:
"""
search_context_v3 = """We have a SQL table called sessions that contains the columns: Click, textInput, visitedUrl, customEvents, networkRequest->url, networkRequest->statusCode, networkRequest->method, networkRequest->duration, graphql->name, graphql->method, graphql->requestBody, graphql->responseBody, stateAction, errorMessage, issue, userOs, userBrowser, userDevice, platform, versionId, referrer, duration, userCountry, userCity, userState, userId, userAnonymousId, domComplete->time_to_render, domComplete->url, largesContentfulPaint->timeToLoad, largestContentfulPaint->url, timeToFirstByte->timeToLoad, timeToFirst_Byte->url, avgCpuLoad->percentage, avgCpuLoad->url, avgMemoryUsage->percentage, avgMemoryUsage->url, failedRequest->name and plan.
[[USER]]: What is the attribute of the click column?
[[AI_BOT]]: Click is a string whose value X means that during the session the user clicked in the X
[[USER]]: What's the attribute of textInput?
[[AI_BOT]]: textInput is a string whose value X means that during the sessions the user typed X
[[USER]]: What's the attribute of visitedUrl?
[[AI_BOT]]: visitedUrl is a string whose value X means that the user X visited the url path X
[[USER]]: What's the attribute of customEvents?
[[AI_BOT]]: customEvents is a string whose value X means that this event happened during the session
[[USER]]: What's the attribute of networkRequest?
[[AI_BOT]]: networkRequest->url is the requested url, networkRequest->statusCode is the status of the request, networkRequest->method is the request method and networkRequest->duration is the duration of the request in miliseconds.
[[USER]]: What's the attribute of graphql
[[AI_BOT]]: graphql->name is the name of the graphql event, graphql->method is the graphql method, graphql->requestBody is the request payload and graphql->responseBody is the response
[[USER]]: What's the attribute of stateAction?
[[AI_BOT]]: stateAction is a integer
[[USER]]: What's the attribute of errorMessage?
[[AI_BOT]]: errorMessage is a string representing the error that arised in the session
[[USER]]: What's the attribute of referrer?
[[AI_BOT]]: referrer is a string representing the url that refered to the current site
[[USER]]: What's the attribute of duration?
[[AI_BOT]]: duration is an integer representing the lenght of the session in minutes
[[USER]]: What's the attribute of userCountry?
[[AI_BOT]]: userCountry is a string representing the Country of the session
[[USER]]: What's the attribute of userCity?
[[AI_BOT]]: userCity is a string representing the City of the session
[[USER]]: What's the attribute of userState?
[[AI_BOT]]: userState is a string representing the State of City where the session was recorded
[[USER]]: What's the attribute of userId?
[[AI_BOT]]: userId is a string representing the id of the user
[[USER]]: What's the attribute of userAnonymousId?
[[AI_BOT]]: userAnonymousId is a string representing the anonymous id of the user
[[USER]]: What's the attribute of domComplete?
[[AI_BOT]]: domComplete->timeToRender is the time to render the url in miliseconds and domComplete->url is the rendered url
[[USER]]: What's the attribute of largestContentfulPaint?
[[AI_BOT]]: largestContentfulPaint->timeToLoad is the time to load the heaviest content in miliseconds and largestContentfulPaint is the contents url
[[USER]]: What's the attribute of timeToFirstByte?
[[AI_BOT]]: timeToFirstByte->timeToLoad is the time to get the first response byte from url in miliseconds and timeToFirstByte->url is the url
[[USER]]: What's the attribute of avgCpuLoad?
[[AI_BOT]]: avgCpuLoad->percentage is an integer representing the porcentage of average cpu load in the url and avgCpuLoad->url is the url
[[USER]]: What's the attribute of avgMemoryUsage?
[[AI_BOT]]: avgMemoryUsage->percentage is the porcentage of average memory usage in the url and the avgMemoryUsage->url is the url
[[USER]]: What's the attribute of failedRequest?
[[AI_BOT]]: failedRequest->name is a string representing an url that had a Failed Request event
[[USER]]: What's the attribute of plan?
[[AI_BOT]]: Plan is a string that could be 'payAsYouGo', 'trial', 'free', 'enterprise'
[[USER]]: Can you translate the following text into SQL query: {user_question}
[[AI_BOT]]:"""
search_context_v4 = """We have a database working with GraphQL, the type system is the following:
type Click (name: String)
type Text_Input (value: String)
type Visited_URL (url: String)
type Custom_Events (name: String)
type Network Request (url: String, status_code: Int, method: String, duration: Int)
type GraphQL (name: String, method: String, request_body: String, response_body: String)
type State_Action (value: Int)
type Error_Message (name: String)
type Issue (name: String)
type User_OS (name: String)
type User_Browser (name: String)
type User_Device (name: String)
type Platform (name: String)
type Version_ID (name: String)
type Referrer (url: String)
type Duration (value: Int)
type User_Country (name: String)
type User_City (name: String)
type User_State (name: String)
type User_Id (name: String)
type User_Anonymous_Id (name: String)
type DOM_Complete (time_to_render: Int, url: String)
type Largest_Contentful_Paint (time_to_load: Int, url: String)
type Time_to_First_Byte (time_to_load: Int, url: String)
type Avg_Memory_Usage (percentage: Int, url: String)
type Avg_Memory_Usage (percentage: Int, url: String)
type Failed_Request (name: String)
type Plan (name: String)
[[USER]]: Get all session from India which has 5 minutes length
[[AI_BOT]]: ```[
(
"value": [],
"type": "User_Country",
"operator": "is",
"isEvent": true,
"filters": [
(
"value": ["India"],
"type": "name",
"operator": "=",
"filters": []
)
]
),
(
"value": [300], // 5 minutes in seconds (5 * 60)
"type": "Duration",
"operator": "=",
"filters": [
(
"value": [],
"type": "value",
"operator": "=",
"filters": []
)
]
)
]```
[[USER]]: How can I see all the sessions from the free plan that had a cpu load of under 30% in the /watchagain/film url?
[[AI_BOT]]: ```[
(
"value": [],
"type": "Plan",
"operator": "is",
"filters": [
(
"value": ["free"],
"type": "name",
"operator": "=",
"filters": []
)
]
),
(
"value": [],
"type": "Avg_Memory_Usage",
"operator": "<",
"filters": [
(
"value": ["30"],
"type": "percentage",
"operator": "<",
"filters": []
)
]
),
(
"value": [],
"type": "Network Request",
"operator": "is",
"filters": [
(
"value": ["/watchagain/film"],
"type": "url",
"operator": "=",
"filters": []
)
]
)
]```
[[USER]]: Can you translate the following text into a GraphQL request to database: {user_question}
[[AI_BOT]]:""" # Using GraphQL form to create filters in json format

View file

@ -0,0 +1,140 @@
chart_context_v1 = """We have a SQL table called sessions that contains the columns: Click, Text_Input, Visited_URL, Custom_Events, Network_Request->url, Network_Request->status_code, Network_Request->method, Network_Request->duration, GraphQL->name, GraphQL->method, GraphQL->request_body, GraphQL->response_body, State_Action, Error_Message, Issue, User_OS, User_Browser, User_Device, Platform, Version_ID, Referrer, Duration, User_Country, User_City, User_State, User_Id, User_Anonymous_Id, DOM_Complete->time_to_render, DOM_Complete->url, Larges_Contentful_Paint->time_to_load, Larges_Contentful_Paint->url, Time_to_First_Byte->time_to_load, Time_to_First_Byte->url, Avg_CPU_Load->percentage, Avg_CPU_Load->url, Avg_Memory_Usage->percentage, Avg_Memory_Usage->url, Failed_Request->name and Plan.
[[USER]]: What is the attribute of the Click column?
[[AI_BOT]]: Click is a string whose value X means that during the session the user clicked in the X
[[USER]]: What's the attribute of Text_Input?
[[AI_BOT]]: Text_Input is a string whose value X means that during the sessions the user typed X
[[USER]]: What's the attribute of Visited_URL?
[[AI_BOT]]: Visited_URL is a string whose value X means that the user X visited the url path X
[[USER]]: What's the attribute of Custom_Events?
[[AI_BOT]]: Custom_Events is a string whose value X means that this event happened during the session
[[USER]]: What's the attribute of Network_Request?
[[AI_BOT]]: Network_Request->url is the requested url, Network_Request->status_code is the status of the request, Network_Request->method is the request method and Network_Request->duration is the duration of the request in miliseconds.
[[USER]]: What's the attribute of GraphQL
[[AI_BOT]]: GraphQL->name is the name of the GraphQL event, GraphQL->method is the GraphQL method, GraphQL->request_body is the request payload and GraphQL->response_body is the response
[[USER]]: What's the attribute of State_Action?
[[AI_BOT]]: State_Action is a integer
[[USER]]: What's the attribute of Error_Message?
[[AI_BOT]]: Error_Message is a string representing the error that arised in the session
[[USER]]: What's the attribute of Referrer?
[[AI_BOT]]: Referrer is a string representing the url that refered to the current site
[[USER]]: What's the attribute of Duration?
[[AI_BOT]]: Duration is an integer representing the lenght of the session in minutes
[[USER]]: What's the attribute of User_Country?
[[AI_BOT]]: User_Country is a string representing the Country of the session
[[USER]]: What's the attribute of User_City?
[[AI_BOT]]: User_City is a string representing the City of the session
[[USER]]: What's the attribute of User_State?
[[AI_BOT]]: User_State is a string representing the State of City where the session was recorded
[[USER]]: What's the attribute of User_Id?
[[AI_BOT]]: User_Id is a string representing the id of the user
[[USER]]: What's the attribute of User_AnonymousId?
[[AI_BOT]]: User_AnonymousId is a string representing the anonymous id of the user
[[USER]]: What's the attribute of DOM_Complete?
[[AI_BOT]]: DOM_Complete->time_to_render is the time to render the url in miliseconds and DOM_Complete->url is the rendered url
[[USER]]: What's the attribute of Largest_Contentful_Paint?
[[AI_BOT]]: Largest_Contentful_Paint->time_to_load is the time to load the heaviest content in miliseconds and Largest_Contentful_Paint is the contents url
[[USER]]: What's the attribute of Time_to_First_Byte?
[[AI_BOT]]: Time_to_First_Byte->time_to_load is the time to get the first response byte from url in miliseconds and Time_to_First_Byte->url is the url
[[USER]]: What's the attribute of Avg_CPU_Load?
[[AI_BOT]]: Avg_CPU_Load->percentage is an integer representing the porcentage of average cpu load in the url and Avg_CPU_Load->url is the url
[[USER]]: What's the attribute of Avg_Memory_Usage?
[[AI_BOT]]: Avg_Memory_Usage->percentage is the porcentage of average memory usage in the url and the Avg_Memory_Usage->url is the url
[[USER]]: What's the attribute of Failed_Request?
[[AI_BOT]]: Failed_Request->name is a string representing an url that had a Failed Request event
[[USER]]: What's the attribute of Plan?
[[AI_BOT]]: Plan is a string that could be 'pay_as_you_go', 'trial', 'free', 'enterprise'
[[USER]]: Can you translate the following text into SQL query: {user_question}
[[AI_BOT]]:"""
chart_context_v2 = """We have the following charts types
type Time Series {filters: Filters, events: Events, value: null, timeRange: TimeRangeType}
type ClickMap {filters: null, events: Events, value: null, timeRange: TimeRangeType}
type Table {filters: Filters, events: Events, value: TableTypes, timeRange: TimeRangeType}
type Funnel {filters: Filters, events: Events, value: null, timeRange: TimeRangeType}
type ErrorTracking {filters: null, events: null, value: ErrorTypes, timeRange: TimeRangeType}
type PerformanceTracking {filters: null, events: null, value: PerformanceTypes, timeRange: TimeRangeType}
type ResourceMonitoring {filters: null, events: null, value: ResourceType, timeRange: TimeRangeType}
type WebVitals {filters: null, events: null, value: VitalsType, timeRange: TimeRangeType}
type Insights {filters: Filters, events: Events, value: InsightTypes, timeRange: TimeRangeType}
Events are one of these types:
type Click {name: String, eventsOrder: EventsOrderType, operator: OperatorType}
type Text_Input {value: String, eventsOrder: EventsOrderType}, operator: OperatorType}
type Visited_URL {location: String, eventsOrder: EventsOrderType}, operator: OperatorType}
type Custom_Events {eventName: String, eventsOrder: EventsOrderType}, operator: OperatorType}
type Network_Request {location: String, status_code: Integer, method: String, duration: Integer, eventsOrder: EventsOrderType}, operator: OperatorType}
type GraphQL {name: String, method: String, request_body: String, response_body: String, eventsOrder: EventsOrderType}, operator: OperatorType}
type State_Action {value: Integer, eventsOrder: EventsOrderType}, operator: OperatorType}
type Error_Message {msg: String, eventsOrder: EventsOrderType}, operator: OperatorType}
Filters are one of these types:
type Referrer {location: String, operator: OperatorType}
type Duration {sessionDuration: Integer, operator: OperatorType}
type User_Country {name: String, operator: OperatorType}
type User_City {name: String, operator: OperatorType}
type User_State {name: String, operator: OperatorType}
type User_id {name: String, operator: OperatorType}
type User_Anonymousid {name: String, operator: OperatorType}
type DOM_Complete {time_to_render: Integer, location: String, operator: OperatorType}
type Largest_Contentful_Pain {time_to_load: Integer, location: String, operator: OperatorType}
type Time_to_First_Byte {time_for_first_byte: Integer, location: String, operator: OperatorType}
type Avg_CPU_Load {percentage: Integer, location: String, operator: OperatorType}
type Avg_Memory_Usage {percentage: Integer, location: String, operator: OperatorType}
type Failed_Request {location: String, operator: OperatorType}
type Plan {name: String, operator: OperatorType}
The TimeRangeType can be one of these possible values: '24 hours', '7 days', '30 days'
The EventsOrderType can be one of the following values: 'AND', 'OR', 'THEN'
The OperatorType can be one of the following values: 'is', 'is any', 'is not', 'starts with', 'ends with', 'contains', 'not contains'
The TableTypes can be one of these possible values: 'UsersTable', 'SessionsTable', 'JSErrors', 'Issues', 'Browser', 'Devices', 'Countries', 'URLs'
The ErrorTypes can be one of these possible values: 'Errors by origin', 'errors per domain', 'errors by type', 'calls with error', 'top4xx domains', 'top5xx domains', 'Impacted sessions by JS errors'
The PerformanceTypes can be one of these possible values: 'CPU_load', 'Crashes', 'frame_rate', 'DOM_building_time', 'Memory Consumption', 'Page response time', 'Page response time distribution', 'Resources vs visuality complete', 'Sessions per browser', 'Slowest domain', 'Speed index by location', 'time to render', 'Sessions impacted by slow pages'
The ResourceType can be one of these possible values: 'Breakdown of loaded resources', 'Missing resources', 'Resource Type vs Response End', 'Resource fetch time', 'Slowest resources'
The VitalsType can be one of these possible values: 'CPU load', 'frame rate', 'DOM Content loaded', 'DOM Content loaded start', 'DOM_build_time', 'First Meaningful Paint', 'First Paint', 'Image load time', 'Page load time', 'Page response time', 'Request load time', 'Response time', 'Session duration', 'Time til first byte', 'Time to be interactive', 'time to render', 'JS heap size', 'Visited pages', 'Captures requests', 'Captures Sessions'
The InsightTypes is a list of values that can be: 'Resources', 'Network Request', 'Click Rage', 'JS Errors'
[[USER]]: I want to see how many users are entering and leaving in the following funnel /home then /product then /product/buy
[[AI_BOT]]: ```{
'type': 'Funnel',
'filters': [],
'events': [
{
'type': 'Visited_URL',
'location': '/home',
'operator': 'is',
'eventsOrder': 'THEN'
},
{
'type': 'Visited_URL',
'location': '/product',
'operator': 'is',
'eventsOrder': 'THEN'
},
{
'type': 'Visited_URL',
'location': '/product/buy',
'operator': 'is',
'eventsOrder': 'THEN'
},
]
'value': null,
'timeRange': '7 days'
}```
[[USER]]: Show me where people are clicking the most in the location that contains /product over the past month
[[AI_BOT]]:```{
'type': 'ClickMap',
'filters': [
{
'type': 'Visited_URL',
'location': '/product',
'operator': 'contains'
},
],
'events': [],
'value': null,
'timeRange': '31 days'
}```"""
formatable_end = """
[[USER]]: {user_question}
[[AI_BOT]]:"""

View file

@ -0,0 +1,8 @@
from pydantic import BaseModel
class LLMQuestion(BaseModel):
question: str
userId: int
projectId: int

View file

@ -0,0 +1,11 @@
from decouple import config
from typing import Optional
ckpt_dir: str = config('CHECKPOINT_DIR')
tokenizer_path: str = config('TOKENIZER_PATH')
temperature: float = config('TEMPERATURE', default=0.6)
top_p: float = config('TOP_P', default=0.9)
max_seq_len: int = config('MAX_SEQ_LEN', default=4098)
max_gen_len: int = config('MAX_GEN_LEN', default=256)
max_batch_size: int = config('MAX_BATCH_SIZE', default=4)

View file

@ -0,0 +1,53 @@
import re
def filter_sql_where_statement2(sql_query):
m = re.search('(?<=[W,w][H,h][E,e][R,r][E,e])[^;]*;', sql_query)
if m:
return m.group(0).replace('->','.')
else:
return None
def get_filter_values(where_statement):
statement_tree = list()
last_parentheses = 0
depth = 0
for i, c in enumerate(where_statement):
if c == '(':
if i != 0:
leaf = where_statement[last_parentheses+1:i]
statement_tree.append((depth, leaf))
depth +=1
last_parentheses = i
elif c == ')':
leaf = where_statement[last_parentheses+1:i]
statement_tree.append((depth, leaf))
last_parentheses = i
depth -= 1
if last_parentheses == 0:
return [(0,where_statement)]
else:
statement_tree.append((0, where_statement[last_parentheses+1:len(where_statement)]))
return statement_tree
def filter_substatement(where_statement):
...
def filter_code_markdown(text_response):
m = re.finditer('```', text_response)
try:
pos1 = next(m).end()
pos2 = next(m).start()
return text_response[pos1:pos2]
except Exception:
return None
def filter_sql_where_statement(sql_query):
sql_query = sql_query.replace('\n',' ')
m = re.search('[S,s][E,e][L,l][E,e][C,c][T,t]', sql_query)
if m:
return filter_sql_where_statement2(sql_query[m.end():])
else:
print('[INFO] This None arrived')
return None