feat(recommendations): Added recommendation service (ml_service) and trainer (ml_trainer) (#1275)

* Created two services: recommendation training and recommendation serving

* Deleted temporary Docker files

* Added features based on signals information

* Added method to get session features using PG

* Added the same utils and core elements to ml_trainer

* Added checks before training models, added handler for model serving

* Updated serving API and recommendation functions to use frontend signals features

* Reorganized modules into a base image shared by both serving and training

* Added Dockerfiles and base Dockerfile

* Fixed issue when ordering sessions by relevance

* Added method to save user feedback on recommendations

* Added security authorization

* Updated Dockerfile

* Fixed issues with secret injection into the API

* Updated feedback structure

* Added git for DAGs

* Fixed insertion issue for recommendation feedback

* Changed update method from def to async def; it is now called during startup

* Fixed issues with Airflow running MLflow in a DAG

* Changed sanity checks and added middleware params

* Renamed base path

* Changed update method to an interval job that loads one model every 10s when there are models to download

* Added sql files for recommendation service and trainer

* Cleaned files and added documentation for methods and classes

* Added README file

* Renamed endpoints, changed None to an empty array and updated README

* refactor(recommendation): optimized query

* style(recommendation): moved imports to top of file, renamed endpoint parameters, optimized functions

* refactor(recommendation): .gitignore

* refactor(recommendation): .gitignore

* refactor(recommendation): Optimized Dockerfiles

* refactor(recommendation): changed imports

* refactor(recommendation): optimized requests

* refactor(recommendation): optimized requests

* Fixed boot for fastapi, updated some queries

* Fixed issues while downloading models and while returning json response from API

* Limited the number of recommendations and set a minimum score for presenting them

* fix(recommendation): fixed some queries and updated prediction method

* Added env value to control number of predictions to make

* docs(recommendation): Added third party libraries used in recommendation service

* Froze requirements

* Update base_crons.py

added `misfire_grace_time` to recommendation crons

---------

Co-authored-by: Taha Yassine Kraiem <tahayk2@gmail.com>
MauricioGarciaS 2023-06-07 15:58:33 +02:00 committed by GitHub
parent 78e3e8a554
commit cea5eda985
43 changed files with 2702 additions and 709 deletions

ee/recommendation/.gitignore (vendored, new file)

@@ -0,0 +1,170 @@
### Example user template template
### Example user template
# IntelliJ project files
.idea
*.iml
out
gen
### Python template
# Byte-compiled / optimized / DLL files
./**/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

@@ -1,14 +1,14 @@
-FROM apache/airflow:2.4.3
-COPY requirements.txt .
-USER root
-RUN apt-get update \
-    && apt-get install -y \
-    vim \
-    && apt-get install gcc libc-dev g++ -y \
-    && apt-get install -y pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl
-USER airflow
-RUN pip install --upgrade pip
-RUN pip install -r requirements.txt
+FROM python:3.10-slim-buster
+RUN apt-get update \
+    && apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl \
+    && apt-get clean
+WORKDIR app
+COPY requirements_base.txt .
+RUN pip install --no-cache-dir -r requirements_base.txt
+COPY core core
+COPY utils utils

@@ -0,0 +1,78 @@
# Recommendations
## Index
1. [Build image](#build-image)
   1. [Recommendations service image](#recommendations-service-image)
   2. [Trainer service image](#trainer-service-image)
2. [Trainer](#trainer-service)
   1. [Env params](#trainer-env-params)
3. [Recommendations](#recommendation-service)
   1. [Env params](#recommendation-env-params)
## Build image
To build both the recommendation image and the trainer image, first create a base image by running the following command:
```bash
docker build -t recommendations_base .
```
This adds the files from `core` and `utils`, which are shared between `ml_service` and `ml_trainer`, and installs the common dependencies.
### Recommendations service image
Inside `ml_service`, run docker build to create the recommendation service image:
```bash
cd ml_service/
docker build -t recommendations .
cd ../
```
### Trainer service image
Inside `ml_trainer`, run docker build to create the trainer service image:
```bash
cd ml_trainer/
docker build -t trainer .
cd ../
```
## Trainer service
The trainer is an orchestration service in charge of training models and saving them to S3.
It uses Directed Acyclic Graphs (DAGs) in [Airflow](https://airflow.apache.org) for orchestration
and [MLflow](https://mlflow.org) to track training runs and maintain a model registry backed by S3.
### Trainer env params
```bash
pg_host=
pg_port=
pg_user=
pg_password=
pg_dbname=
pg_host_ml=
pg_port_ml=
pg_user_ml=
pg_password_ml=
pg_dbname_ml='mlruns'
PG_POOL='true'
MODELS_S3_BUCKET= #'s3://path/to/bucket'
pg_user_airflow=
pg_password_airflow=
pg_dbname_airflow='airflow'
pg_host_airflow=
pg_port_airflow=
AIRFLOW_HOME=/app/airflow
airflow_secret_key=
airflow_admin_password=
crons_train='0 0 * * *'
```
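For context, a minimal sketch of how the `crons_train` value is consumed; the DAG id and default args here are illustrative, while the real DAG ships under `ml_trainer` later in this diff and reads the schedule the same way.
```python
# Illustrative only: the training DAG takes its schedule from the crons_train env value.
from datetime import timedelta

import pendulum
from airflow import DAG
from decouple import config

dag = DAG(
    "train_recommendation_models",  # hypothetical DAG id
    default_args={"retries": 1, "retry_delay": timedelta(minutes=3)},
    start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
    schedule=config("crons_train", default="@daily"),  # cron expression, e.g. '0 0 * * *'
    catchup=False,
)
```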
## Recommendation service
The recommendation service is a [FastAPI](https://fastapi.tiangolo.com) server that uses MLflow to read models from S3
and serve them. It also collects user feedback on recommendations and saves it to PostgreSQL for retraining purposes.
### Recommendation env params
```bash
pg_host=
pg_port=
pg_user=
pg_password=
pg_dbname=
pg_host_ml=
pg_port_ml=
pg_user_ml=
pg_password_ml=
pg_dbname_ml='mlruns'
PG_POOL='true'
API_AUTH_KEY=
```
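For a quick smoke test, a hedged example of querying the serving API with the `API_AUTH_KEY` sent as a bearer token; the host is an assumption, and the port comes from the container's default (7001).
```python
# Assumed deployment address; the service listens on port 7001 by default.
import os

import requests

BASE_URL = "http://localhost:7001"
headers = {"Authorization": f"Bearer {os.environ['API_AUTH_KEY']}"}

# GET /recommendations/{user_id}/{project_id} returns the recommended session ids
resp = requests.get(f"{BASE_URL}/recommendations/1/1", headers=headers)
print(resp.json())  # {'userId': 1, 'projectId': 1, 'recommendations': [...]}
```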

@@ -1 +0,0 @@
docker-compose down --volumes --rmi all

@@ -0,0 +1,128 @@
import mlflow.pyfunc
import random
import numpy as np
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.feature_selection import SequentialFeatureSelector as sfs
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier as knc
def select_features(X, y):
"""
Dimensional reduction of X using k-nearest neighbors and sequential feature selector.
Final dimension set to three features.
Params:
X: Array which will be reduced in dimension (batch_size, n_features).
y: Array of labels (batch_size,).
Output: Tuple (transform function that reduces the dimension of an array, transformed X).
"""
knn = knc(n_neighbors=3)
selector = sfs(knn, n_features_to_select=3)
X_transformed = selector.fit_transform(X, y)
def transform(input):
return selector.transform(input)
return transform, X_transformed
def sort_database(X, y):
"""
Random shuffle of training values with its respective labels.
Params:
X: Array of features.
y: Array of labels.
Output: Tuple (X_rand_sorted, y_rand_sorted).
"""
sort_list = list(range(len(y)))
random.shuffle(sort_list)
return X[sort_list], y[sort_list]
def preprocess(X):
"""
Preprocessing of features (no dimensional reduction) using principal component analysis.
Params:
X: Array of features.
Output: Tuple (processed array of features, transform function that applies the same preprocessing).
"""
_, n = X.shape
pca = PCA(n_components=n)
x = pca.fit_transform(normalize(X))
def transform(input):
return pca.transform(normalize(input))
return x, transform
class SVM_recommendation(mlflow.pyfunc.PythonModel):
def __init__(self, test=False, **params):
f"""{SVC.__doc__}"""
params['probability'] = True
self.svm = SVC(**params)
self.transforms = []
self.score = 0
self.confusion_matrix = None
if test:
knn = knc(n_neighbors=3)
self.transform = [PCA(n_components=3), sfs(knn, n_features_to_select=2)]
def fit(self, X, y):
"""
Trains the preprocessing function, feature selection and Support Vector Machine model.
Params:
X: Array of features.
y: Array of labels.
"""
assert X.shape[0] == y.shape[0], 'X and y must have same length'
assert len(X.shape) == 2, 'X must be a two dimension vector'
X, t1 = preprocess(X)
t2, X = select_features(X, y)
self.transforms = [t1, t2]
self.svm.fit(X, y)
pred = self.svm.predict(X)
z = y + 2 * pred
n = len(z)
false_neg = np.count_nonzero(z == 1) / n
false_pos = np.count_nonzero(z == 2) / n
true_pos = np.count_nonzero(z == 3) / n
true_neg = 1 - false_neg - false_pos - true_pos
self.confusion_matrix = np.array([[true_neg, false_pos], [false_neg, true_pos]])
self.score = true_pos + true_neg
def predict(self, x):
"""
Transforms input features and predicts the probability of the positive class.
Params:
X: Array of features.
Output: prediction probability for True (1).
"""
for t in self.transforms:
x = t(x)
return self.svm.predict_proba(x)[:, 1]
def recommendation_order(self, x):
"""
Transforms input features, predicts probabilities and sorts indices by probability.
Params:
X: Array of features.
Output: Tuple (sorted_features, predictions).
"""
for t in self.transforms:
x = t(x)
pred = self.svm.predict_proba(x)
return sorted(range(len(pred)), key=lambda k: pred[k][1], reverse=True), pred
def plots(self):
"""
Returns the plots in a dict format.
{
'confusion_matrix': confusion matrix figure,
}
"""
display = metrics.ConfusionMatrixDisplay(confusion_matrix=self.confusion_matrix, display_labels=[False, True])
return {'confusion_matrix': display.plot().figure_}
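
A usage sketch of the class above on synthetic data, assuming it is importable as `core.recommendation_model` (the import path the trainer uses); the feature count and labels are illustrative, the real features come from `core/user_features.py`.
```python
import numpy as np

from core.recommendation_model import SVM_recommendation

X = np.random.rand(200, 9)               # 200 sessions, 9 numeric features (assumed shape)
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # synthetic "interesting" label

model = SVM_recommendation(kernel="rbf")
model.fit(X, y)                          # fits PCA + sequential feature selection + SVC
print("training score:", model.score)
print(model.predict(X[:5]))              # probability of the positive class
order, probs = model.recommendation_order(X[:5])
print("sessions ordered by relevance:", order)
```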

@@ -0,0 +1,137 @@
from utils.pg_client import PostgresClient
from decouple import config
from utils.df_utils import _process_pg_response
import numpy as np
def get_training_database(projectId, max_timestamp=None, favorites=False):
"""
Gets the training dataset for projectId, optionally bounded by max_timestamp and extended with favorites.
Params:
projectId: project id of all sessions to be selected.
max_timestamp: maximum timestamp an unseen session can have in order to be considered not interesting.
favorites: True to use favorite sessions as the reference for interesting sessions.
Output: Tuple (array of features, array of labels, dict of indexes of each project_id, user_id, session_id in the set)
"""
args = {"projectId": projectId, "max_timestamp": max_timestamp, "limit": 20}
with PostgresClient() as conn:
x1 = signals_features(conn, **args)
if favorites:
x2 = user_favorite_sessions(args['projectId'], conn)
if max_timestamp is not None:
x3 = user_not_seen_sessions(args['projectId'], args['limit'], conn)
X_project_ids = dict()
X_users_ids = dict()
X_sessions_ids = dict()
_X = list()
_Y = list()
_process_pg_response(x1, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=None)
if favorites:
_process_pg_response(x2, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=1)
if max_timestamp:
_process_pg_response(x3, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=0)
return np.array(_X), np.array(_Y), \
{'project_id': X_project_ids,
'user_id': X_users_ids,
'session_id': X_sessions_ids}
def signals_features(conn, **kwargs):
"""
Selects features from the frontend_signals table and marks a session as interesting when either of the following holds:
* the number of events is greater than events_threshold (env value, default=10);
* the session has been replayed more than once.
"""
assert 'projectId' in kwargs.keys(), 'projectId should be provided in kwargs'
projectId = kwargs['projectId']
events_threshold = config('events_threshold', default=10, cast=int)
query = conn.mogrify("""SELECT T.project_id,
T.session_id,
T.user_id,
T2.viewer_id,
T.events_count,
T.errors_count,
T.duration,
T.country,
T.issue_score,
T.device_type,
T2.interesting as train_label
FROM (SELECT project_id,
user_id as viewer_id,
session_id,
count(CASE WHEN source = 'replay' THEN 1 END) > 1 OR COUNT(1) > %(events_threshold)s as interesting
FROM frontend_signals
WHERE project_id = %(projectId)s
AND session_id is not null
GROUP BY project_id, viewer_id, session_id) as T2
INNER JOIN (SELECT project_id,
session_id,
user_id as viewer_id,
user_id,
events_count,
errors_count,
duration,
user_country as country,
issue_score,
user_device_type as device_type
FROM sessions
WHERE project_id = %(projectId)s
AND duration IS NOT NULL) as T
USING (session_id);""",
{"projectId": projectId, "events_threshold": events_threshold})
conn.execute(query)
res = conn.fetchall()
return res
def user_favorite_sessions(projectId, conn):
"""
Selects features from user_favorite_sessions table.
"""
query = """SELECT project_id,
session_id,
T1.user_id,
events_count,
errors_count,
duration,
user_country as country,
issue_score,
user_device_type as device_type,
T2.user_id AS viewer_id
FROM sessions AS T1
INNER JOIN user_favorite_sessions as T2
USING (session_id)
WHERE project_id = %(projectId)s;"""
conn.execute(
conn.mogrify(query, {"projectId": projectId})
)
res = conn.fetchall()
return res
def user_not_seen_sessions(projectId, limit, conn):
"""
Selects features from user_viewed_sessions table.
"""
# TODO: fetch un-viewed sessions alone, and the users list alone, then cross join them in python
# and ignore deleted users (WHERE users.deleted_at ISNULL)
query = """SELECT project_id, session_id, user_id, viewer_id, events_count, errors_count, duration, user_country as country, issue_score, user_device_type as device_type
FROM (
(SELECT sessions.*
FROM sessions LEFT JOIN user_viewed_sessions USING(session_id)
WHERE project_id = %(projectId)s
AND duration IS NOT NULL
AND user_viewed_sessions.session_id ISNULL
LIMIT %(limit)s) AS T1
LEFT JOIN
(SELECT user_id as viewer_id
FROM users
WHERE tenant_id = (SELECT tenant_id FROM projects WHERE project_id = %(projectId)s)) AS T2 ON true
)"""
conn.execute(
conn.mogrify(query, {"projectId": projectId, "limit": limit})
)
res = conn.fetchall()
return res
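
A sketch of how the trainer consumes this module, mirroring `ml_trainer/main.py` in this PR; the project id is a placeholder and the timestamp is the same literal main.py uses.
```python
import asyncio

from utils import pg_client
from core.user_features import get_training_database

asyncio.run(pg_client.init())
X, y, index = get_training_database(projectId=1, max_timestamp=1680248412284, favorites=True)
print(X.shape, y.shape)       # (n_sessions, n_features), (n_sessions,)
print(list(index.keys()))     # ['project_id', 'user_id', 'session_id']
asyncio.run(pg_client.terminate())
```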

@@ -1,46 +0,0 @@
from datetime import datetime, timedelta
from textwrap import dedent
import pendulum
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
import os
_work_dir = os.getcwd()
def my_function():
l = os.listdir('scripts')
print(l)
return l
dag = DAG(
"first_test",
default_args={
"depends_on_past": True,
"retries": 1,
"retry_delay": timedelta(minutes=3),
},
start_date=pendulum.datetime(2015, 12, 1, tz="UTC"),
description="My first test",
schedule="@daily",
catchup=False,
)
#assigning the task for our dag to do
with dag:
first_world = PythonOperator(
task_id='FirstTest',
python_callable=my_function,
)
hello_world = BashOperator(
task_id='OneTest',
bash_command=f'python {_work_dir}/scripts/processing.py --batch_size 500',
# provide_context=True
)
this_world = BashOperator(
task_id='ThisTest',
bash_command=f'python {_work_dir}/scripts/task.py --mode train --kernel linear',
)
first_world >> hello_world >> this_world

@@ -1,285 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Basic Airflow cluster configuration for CeleryExecutor with Redis and PostgreSQL.
#
# WARNING: This configuration is for local development. Do not use it in a production deployment.
#
# This configuration supports basic configuration using environment variables or an .env file
# The following variables are supported:
#
# AIRFLOW_IMAGE_NAME - Docker image name used to run Airflow.
# Default: apache/airflow:2.4.3
# AIRFLOW_UID - User ID in Airflow containers
# Default: 50000
# Those configurations are useful mostly in case of standalone testing/running Airflow in test/try-out mode
#
# _AIRFLOW_WWW_USER_USERNAME - Username for the administrator account (if requested).
# Default: airflow
# _AIRFLOW_WWW_USER_PASSWORD - Password for the administrator account (if requested).
# Default: airflow
# _PIP_ADDITIONAL_REQUIREMENTS - Additional PIP requirements to add when starting all containers.
# Default: ''
#
# Feel free to modify this file to suit your needs.
---
version: '3'
x-airflow-common:
&airflow-common
# In order to add custom dependencies or upgrade provider packages you can use your extended image.
# Comment the image line, place your Dockerfile in the directory where you placed the docker-compose.yaml
# and uncomment the "build" line below, Then run `docker-compose build` to build the images.
# image: ${AIRFLOW_IMAGE_NAME:-apache/airflow:2.4.3}
build: .
environment:
&airflow-common-env
AIRFLOW__CORE__EXECUTOR: CeleryExecutor
AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow
# For backward compatibility, with Airflow <2.3
AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow
AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres/airflow
AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0
AIRFLOW__CORE__FERNET_KEY: ''
AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: 'true'
AIRFLOW__CORE__LOAD_EXAMPLES: 'false'
AIRFLOW__API__AUTH_BACKENDS: 'airflow.api.auth.backend.basic_auth'
_PIP_ADDITIONAL_REQUIREMENTS: 'argcomplete'
AIRFLOW__CODE_EDITOR__ENABLED: 'true'
AIRFLOW__CODE_EDITOR__GIT_ENABLED: 'false'
AIRFLOW__CODE_EDITOR__STRING_NORMALIZATION: 'true'
AIRFLOW__CODE_EDITOR__MOUNT: '/opt/airflow/dags'
pg_user: "${pg_user}"
pg_password: "${pg_password}"
pg_dbname: "${pg_dbname}"
pg_host: "${pg_host}"
pg_port: "${pg_port}"
PG_TIMEOUT: "${PG_TIMEOUT}"
PG_POOL: "${PG_POOL}"
volumes:
- ./dags:/opt/airflow/dags
- ./logs:/opt/airflow/logs
- ./plugins:/opt/airflow/plugins
- ./scripts:/opt/airflow/scripts
- ./cache:/opt/airflow/cache
user: "${AIRFLOW_UID:-50000}:0"
depends_on:
&airflow-common-depends-on
redis:
condition: service_healthy
postgres:
condition: service_healthy
services:
postgres:
image: postgres:13
environment:
POSTGRES_USER: airflow
POSTGRES_PASSWORD: airflow
POSTGRES_DB: airflow
volumes:
- postgres-db-volume:/var/lib/postgresql/data
healthcheck:
test: ["CMD", "pg_isready", "-U", "airflow"]
interval: 5s
retries: 5
restart: always
redis:
image: redis:latest
expose:
- 6379
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 30s
retries: 50
restart: always
airflow-webserver:
<<: *airflow-common
command: webserver
ports:
- 8080:8080
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:8080/health"]
interval: 10s
timeout: 10s
retries: 5
restart: always
depends_on:
<<: *airflow-common-depends-on
airflow-init:
condition: service_completed_successfully
airflow-scheduler:
<<: *airflow-common
command: scheduler
healthcheck:
test: ["CMD-SHELL", 'airflow jobs check --job-type SchedulerJob --hostname "$${HOSTNAME}"']
interval: 10s
timeout: 10s
retries: 5
restart: always
depends_on:
<<: *airflow-common-depends-on
airflow-init:
condition: service_completed_successfully
airflow-worker:
<<: *airflow-common
command: celery worker
healthcheck:
test:
- "CMD-SHELL"
- 'celery --app airflow.executors.celery_executor.app inspect ping -d "celery@$${HOSTNAME}"'
interval: 10s
timeout: 10s
retries: 5
environment:
<<: *airflow-common-env
# Required to handle warm shutdown of the celery workers properly
# See https://airflow.apache.org/docs/docker-stack/entrypoint.html#signal-propagation
DUMB_INIT_SETSID: "0"
restart: always
depends_on:
<<: *airflow-common-depends-on
airflow-init:
condition: service_completed_successfully
airflow-triggerer:
<<: *airflow-common
command: triggerer
healthcheck:
test: ["CMD-SHELL", 'airflow jobs check --job-type TriggererJob --hostname "$${HOSTNAME}"']
interval: 10s
timeout: 10s
retries: 5
restart: always
depends_on:
<<: *airflow-common-depends-on
airflow-init:
condition: service_completed_successfully
airflow-init:
<<: *airflow-common
entrypoint: /bin/bash
# yamllint disable rule:line-length
command:
- -c
- |
function ver() {
printf "%04d%04d%04d%04d" $${1//./ }
}
register-python-argcomplete airflow >> ~/.bashrc
airflow_version=$$(AIRFLOW__LOGGING__LOGGING_LEVEL=INFO && gosu airflow airflow version)
airflow_version_comparable=$$(ver $${airflow_version})
min_airflow_version=2.2.0
min_airflow_version_comparable=$$(ver $${min_airflow_version})
if [[ -z "${AIRFLOW_UID}" ]]; then
echo
echo -e "\033[1;33mWARNING!!!: AIRFLOW_UID not set!\e[0m"
echo "If you are on Linux, you SHOULD follow the instructions below to set "
echo "AIRFLOW_UID environment variable, otherwise files will be owned by root."
echo "For other operating systems you can get rid of the warning with manually created .env file:"
echo " See: https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#setting-the-right-airflow-user"
echo
fi
one_meg=1048576
mem_available=$$(($$(getconf _PHYS_PAGES) * $$(getconf PAGE_SIZE) / one_meg))
cpus_available=$$(grep -cE 'cpu[0-9]+' /proc/stat)
disk_available=$$(df / | tail -1 | awk '{print $$4}')
warning_resources="false"
if (( mem_available < 4000 )) ; then
echo
echo -e "\033[1;33mWARNING!!!: Not enough memory available for Docker.\e[0m"
echo "At least 4GB of memory required. You have $$(numfmt --to iec $$((mem_available * one_meg)))"
echo
warning_resources="true"
fi
if (( cpus_available < 2 )); then
echo
echo -e "\033[1;33mWARNING!!!: Not enough CPUS available for Docker.\e[0m"
echo "At least 2 CPUs recommended. You have $${cpus_available}"
echo
warning_resources="true"
fi
if (( disk_available < one_meg * 10 )); then
echo
echo -e "\033[1;33mWARNING!!!: Not enough Disk space available for Docker.\e[0m"
echo "At least 10 GBs recommended. You have $$(numfmt --to iec $$((disk_available * 1024 )))"
echo
warning_resources="true"
fi
if [[ $${warning_resources} == "true" ]]; then
echo
echo -e "\033[1;33mWARNING!!!: You have not enough resources to run Airflow (see above)!\e[0m"
echo "Please follow the instructions to increase amount of resources available:"
echo " https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#before-you-begin"
echo
fi
mkdir -p /sources/logs /sources/dags /sources/plugins
chown -R "${AIRFLOW_UID}:0" /sources/{logs,dags,plugins}
exec /entrypoint airflow version
# yamllint enable rule:line-length
environment:
<<: *airflow-common-env
_AIRFLOW_DB_UPGRADE: 'true'
_AIRFLOW_WWW_USER_CREATE: 'true'
_AIRFLOW_WWW_USER_USERNAME: ${_AIRFLOW_WWW_USER_USERNAME:-airflow}
_AIRFLOW_WWW_USER_PASSWORD: ${_AIRFLOW_WWW_USER_PASSWORD:-airflow}
_PIP_ADDITIONAL_REQUIREMENTS: ''
user: "0:0"
volumes:
- .:/sources
airflow-cli:
<<: *airflow-common
profiles:
- debug
environment:
<<: *airflow-common-env
CONNECTION_CHECK_MAX_COUNT: "0"
# Workaround for entrypoint issue. See: https://github.com/apache/airflow/issues/16252
command:
- bash
- -c
- airflow
# You can enable flower by adding "--profile flower" option e.g. docker-compose --profile flower up
# or by explicitly targeted on the command line e.g. docker-compose up flower.
# See: https://docs.docker.com/compose/profiles/
flower:
<<: *airflow-common
command: celery flower
profiles:
- flower
ports:
- 5555:5555
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:5555/"]
interval: 10s
timeout: 10s
retries: 5
restart: always
depends_on:
<<: *airflow-common-depends-on
airflow-init:
condition: service_completed_successfully
volumes:
postgres-db-volume:

@@ -0,0 +1,17 @@
FROM recommendations_base
RUN apt-get update \
&& apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl \
&& apt-get clean
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY auth auth
COPY crons crons
COPY core core
COPY entrypoint.sh .
COPY api.py .
EXPOSE 7001
ENTRYPOINT ./entrypoint.sh

@@ -0,0 +1,83 @@
from fastapi import FastAPI, Depends
from contextlib import asynccontextmanager
from utils import pg_client
from core.model_handler import recommendation_model
from utils.declarations import FeedbackRecommendation
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from crons.base_crons import cron_jobs
from auth.auth_key import api_key_auth
from core import feedback
from fastapi.middleware.cors import CORSMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
await pg_client.init()
await feedback.init()
await recommendation_model.update()
app.schedule.start()
for job in cron_jobs:
app.schedule.add_job(id=job['func'].__name__, **job)
yield
app.schedule.shutdown(wait=False)
await feedback.terminate()
await pg_client.terminate()
app = FastAPI(lifespan=lifespan)
app.schedule = AsyncIOScheduler()
origins = [
"*"
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# @app.on_event('startup')
# async def startup():
# await pg_client.init()
# await feedback.init()
# await recommendation_model.update()
# app.schedule.start()
# for job in cron_jobs:
# app.schedule.add_job(id=job['func'].__name__, **job)
#
#
# @app.on_event('shutdown')
# async def shutdown():
# app.schedule.shutdown(wait=False)
# await feedback.terminate()
# await pg_client.terminate()
@app.get('/recommendations/{user_id}/{project_id}', dependencies=[Depends(api_key_auth)])
async def get_recommended_sessions(user_id: int, project_id: int):
recommendations = recommendation_model.get_recommendations(user_id, project_id)
return {'userId': user_id,
'projectId': project_id,
'recommendations': recommendations
}
@app.get('/recommendations/{projectId}/{viewerId}/{sessionId}', dependencies=[Depends(api_key_auth)])
async def already_gave_feedback(projectId: int, viewerId: int, sessionId: int):
return feedback.has_feedback((viewerId, sessionId, projectId))
@app.post('/recommendations/feedback', dependencies=[Depends(api_key_auth)])
async def get_feedback(data: FeedbackRecommendation):
try:
feedback.global_queue.put(tuple(data.dict().values()))
except Exception as e:
return {'error': e}
return {'success': 1}
@app.get('/')
async def health():
return {'status': 200}

@@ -0,0 +1,33 @@
from fastapi.security import OAuth2PasswordBearer
from fastapi import HTTPException, Depends, status
from decouple import config
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class AuthHandler:
def __init__(self):
"""
Authorization method using an API key.
"""
self.__api_keys = [config("API_AUTH_KEY")]
def __contains__(self, api_key):
return api_key in self.__api_keys
def add_key(self, key):
"""Adds new key for authentication."""
self.__api_keys.append(key)
auth_method = AuthHandler()
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
"""Method to verify auth."""
global auth_method
if api_key not in auth_method:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Forbidden"
)

@@ -0,0 +1,2 @@
cp ../../api/chalicelib/utils/ch_client.py utils
cp ../../../api/chalicelib/utils/pg_client.py utils

@@ -0,0 +1,132 @@
import json
import queue
import logging
from decouple import config
from time import time
from mlflow.store.db.utils import create_sqlalchemy_engine
from sqlalchemy.orm import sessionmaker, session
from sqlalchemy import text
from contextlib import contextmanager
global_queue = None
class ConnectionHandler:
_sessions = sessionmaker()
def __init__(self, uri):
"""Connects into mlflow database."""
self.engine = create_sqlalchemy_engine(uri)
@contextmanager
def get_live_session(self) -> session:
"""
This is a session that can be committed.
Changes will be reflected in the database.
"""
# Automatic transaction and connection handling in session
connection = self.engine.connect()
my_session = type(self)._sessions(bind=connection)
yield my_session
my_session.close()
connection.close()
class EventQueue:
def __init__(self, queue_max_length=50):
"""Saves all recommendations until queue_max_length (default 50) is reached
or max_retention_time surpassed (env value, default 1 hour)."""
self.events = queue.Queue()
self.events.maxsize = queue_max_length
host = config('pg_host_ml')
port = config('pg_port_ml')
user = config('pg_user_ml')
dbname = config('pg_dbname_ml')
password = config('pg_password_ml')
tracking_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}"
self.connection_handler = ConnectionHandler(tracking_uri)
self.last_flush = time()
self.max_retention_time = config('max_retention_time', default=60*60)
self.feedback_short_mem = list()
def flush(self, conn):
"""Insert recommendations into table recommendation_feedback from mlflow database."""
events = list()
params = dict()
i = 0
insertion_time = time()
while not self.events.empty():
user_id, session_id, project_id, payload = self.events.get()
params[f'user_id_{i}'] = user_id
params[f'session_id_{i}'] = session_id
params[f'project_id_{i}'] = project_id
params[f'payload_{i}'] = json.dumps(payload)
events.append(
f"(%(user_id_{i})s, %(session_id_{i})s, %(project_id_{i})s, %(payload_{i})s::jsonb, {insertion_time})")
i += 1
self.last_flush = time()
self.feedback_short_mem = list()
if i == 0:
return 0
cur = conn.connection().connection.cursor()
query = cur.mogrify(f"""INSERT INTO recommendation_feedback (user_id, session_id, project_id, payload, insertion_time) VALUES {' , '.join(events)};""", params)
conn.execute(text(query.decode("utf-8")))
conn.commit()
return 1
def force_flush(self):
"""Force method flush."""
if not self.events.empty():
try:
with self.connection_handler.get_live_session() as conn:
self.flush(conn)
except Exception as e:
print(f'Error: {e}')
def put(self, element):
"""Adds recommendation into the queue."""
current_time = time()
if self.events.full() or current_time - self.last_flush > self.max_retention_time:
try:
with self.connection_handler.get_live_session() as conn:
self.flush(conn)
except Exception as e:
print(f'Error: {e}')
self.events.put(element)
self.feedback_short_mem.append(element[:3])
self.events.task_done()
def already_has_feedback(self, element):
""""This method verifies if a feedback is already send for the current user-project-sessionId."""
if element[:3] in self.feedback_short_mem:
return True
else:
with self.connection_handler.get_live_session() as conn:
cur = conn.connection().connection.cursor()
query = cur.mogrify("SELECT * FROM recommendation_feedback WHERE user_id=%(user_id)s AND session_id=%(session_id)s AND project_id=%(project_id)s LIMIT 1",
{'user_id': element[0], 'session_id': element[1], 'project_id': element[2]})
cur_result = conn.execute(text(query.decode('utf-8')))
res = cur_result.fetchall()
return len(res) == 1
def has_feedback(data):
global global_queue
assert global_queue is not None, 'Global queue is not yet initialized'
return global_queue.already_has_feedback(data)
async def init():
global global_queue
global_queue = EventQueue()
print("> queue initialized")
async def terminate():
global global_queue
if global_queue is not None:
global_queue.force_flush()
print('> queue flushed')
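
For reference, a sketch of how the serving API drives this queue (compare `api.py` in this PR); the ids and the payload dict are placeholders.
```python
import asyncio

from core import feedback

asyncio.run(feedback.init())                      # creates the global EventQueue
entry = (1, 42, 7, {"reason": "useful"})          # (user_id, session_id, project_id, payload) - placeholder values
if not feedback.has_feedback(entry[:3]):
    feedback.global_queue.put(entry)              # buffered; flushed in batches or after the retention window
asyncio.run(feedback.terminate())                 # force-flush on shutdown
```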

@@ -0,0 +1,164 @@
import mlflow
import hashlib
import numpy as np
from decouple import config
from utils import pg_client
from utils.df_utils import _process_pg_response
from time import time
host = config('pg_host_ml')
port = config('pg_port_ml')
user = config('pg_user_ml')
dbname = config('pg_dbname_ml')
password = config('pg_password_ml')
tracking_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}"
mlflow.set_tracking_uri(tracking_uri)
batch_download_size = config('batch_download_size', default=10, cast=int)
def get_latest_uri(projectId, tenantId):
client = mlflow.MlflowClient()
hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest()
model_name = f'{hashed}-RecModel'
model_versions = client.search_model_versions(f"name='{model_name}'")
latest = -1
for i in range(len(model_versions)):
_v = model_versions[i].version
if _v < latest:
continue
else:
latest = _v
return f"runs:/{model_name}/{latest}"
def get_tenant(projectId):
with pg_client.PostgresClient() as conn:
conn.execute(
conn.mogrify("SELECT tenant_id FROM projects WHERE project_id=%(projectId)s", {'projectId': projectId})
)
res = conn.fetchone()
return res['tenant_id']
class ServedModel:
def __init__(self):
"""Handler of mlflow model."""
self.model = None
def load_model(self, model_name, model_version=1):
"""Load model from mlflow given the model version.
Params:
model_name: model name in mlflow repository
model_version: version of model to be downloaded"""
self.model = mlflow.pyfunc.load_model(f'models:/{model_name}/{model_version}')
def predict(self, X) -> np.ndarray:
"""Make prediction for batch X."""
assert self.model is not None, 'Model has to be loaded before predicting. See load_model.__doc__'
return self.model.predict(X)
def _sort_by_recommendation(self, sessions, sessions_features) -> np.ndarray:
"""Make prediction for sessions_features and sort them by relevance."""
pred = self.predict(sessions_features)
threshold = config('threshold_prediction', default=0.6, cast=float)
over_threshold = pred > threshold
pred = pred[over_threshold]
if len(pred) == 0:
return np.array([])
sorted_idx = np.argsort(pred)[::-1]
return sessions[over_threshold][sorted_idx]
def get_recommendations(self, userId, projectId):
"""Gets recommendations for userId for a given projectId.
Selects up to unseen_selection_limit unseen sessions (env value, default 100)
and sorts them by relevance using the ML model."""
limit = config('unseen_selection_limit', default=100, cast=int)
oldest_limit = 1000*(time() - config('unseen_max_days_ago_selection', default=30, cast=int)*60*60*24)
with pg_client.PostgresClient() as conn:
query = conn.mogrify(
"""SELECT project_id, session_id, user_id, %(userId)s as viewer_id, events_count, errors_count, duration, user_country as country, issue_score, user_device_type as device_type
FROM sessions
WHERE project_id = %(projectId)s AND session_id NOT IN (SELECT session_id FROM user_viewed_sessions WHERE user_id = %(userId)s) AND duration IS NOT NULL AND start_ts > %(oldest_limit)s LIMIT %(limit)s""",
{'userId': userId, 'projectId': projectId, 'limit': limit, 'oldest_limit': oldest_limit}
)
conn.execute(query)
res = conn.fetchall()
_X = list()
_Y = list()
X_project_ids = dict()
X_users_ids = dict()
X_sessions_ids = dict()
_process_pg_response(res, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=0)
return self._sort_by_recommendation(np.array(list(X_sessions_ids.keys())), _X).tolist()
class Recommendations:
def __init__(self):
"""Handler for multiple models.
Properties:
* names [dict]: names of currently available models and their versions (model name as key).
* models [dict]: ServedModels objects (model name as key).
* to_download [list]: list of model name and version to be downloaded from mlflow server (in s3).
"""
self.names = dict()
self.models = dict()
self.to_download = list()
async def update(self):
"""Fill to_download list with new models or new version for saved models."""
r_models = mlflow.search_registered_models()
new_names = {m.name: max(m.latest_versions).version for m in r_models}
for name, version in new_names.items():
if (name, version) in self.names.items():
continue
self.to_download.append((name, version))
# self.download_model(name, version)
self.names = new_names
async def download_next(self):
"""Pop element from to_download, download and add it into models."""
download_loop_number = 0
if self.to_download:
while download_loop_number < batch_download_size:
try:
name, version = self.to_download.pop(0)
s_model = ServedModel()
s_model.load_model(name, version)
self.models[name] = s_model
download_loop_number += 1
except IndexError:
break
except Exception as e:
print('[Error] Found exception')
print(repr(e))
break
def download_model(self, name, version):
model = ServedModel()
model.load_model(name, version)
self.models[name] = model
def info(self):
"""Show current loaded models."""
print('Current models inside:')
for model_name, model in self.models.items():
print('Name:', model_name)
print(model.model)
def get_recommendations(self, userId, projectId, n_recommendations=5):
"""Gets recommendation for userId given the projectId.
This method selects the corresponding model and gets recommended sessions ordered by relevance."""
tenantId = get_tenant(projectId)
hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest()
model_name = f'{hashed}-RecModel'
n_recommendations = config('number_of_recommendations', default=5, cast=int)
try:
model = self.models[model_name]
except KeyError:
return []
return model.get_recommendations(userId, projectId)[:n_recommendations]
recommendation_model = Recommendations()

@@ -0,0 +1,18 @@
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from core.model_handler import recommendation_model
async def update_model():
"""Update list of models to download."""
await recommendation_model.update()
async def download_model():
"""Download next model in list."""
await recommendation_model.download_next()
cron_jobs = [
{"func": update_model, "trigger": CronTrigger(hour=0), "misfire_grace_time": 60, "max_instances": 1},
{"func": download_model, "trigger": IntervalTrigger(seconds=10), "misfire_grace_time": 60, "max_instances": 1},
]

@@ -0,0 +1,2 @@
export MLFLOW_TRACKING_URI=postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml}
uvicorn api:app --host 0.0.0.0 --port 7001 --proxy-headers

@@ -0,0 +1,4 @@
fastapi==0.95.2
apscheduler==3.10.1
uvicorn==0.22.0
SQLAlchemy==2.0.15

@@ -0,0 +1 @@
docker run -e AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} -e AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} -e AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} -p 7001:7001 -t recommendation_service

@@ -0,0 +1,20 @@
DO
$do$
BEGIN
IF EXISTS (SELECT FROM pg_database WHERE datname = 'mlruns') THEN
RAISE NOTICE 'Database already exists'; -- optional
ELSE
PERFORM dblink_exec('dbname=' || current_database() -- current db
, 'CREATE DATABASE mlruns');
END IF;
END
$do$;
CREATE TABLE IF NOT EXISTS mlruns.public.recommendation_feedback
(
user_id BIGINT,
session_id BIGINT,
project_id BIGINT,
payload jsonb,
insertion_time BIGINT
);

@@ -0,0 +1,16 @@
from core.model_handler import Recommendations
from utils import pg_client
import asyncio
async def main():
await pg_client.init()
R = Recommendations()
R.to_download = [('****************************************************************-RecModel', 1)]
await R.download_next()
L = R.get_recommendations(000000000, 000000000)
print(L)
if __name__ == '__main__':
asyncio.run(main())

@@ -0,0 +1,18 @@
FROM recommendations_base
RUN apt-get update \
&& apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl git\
&& apt-get clean
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY airflow airflow
COPY entrypoint.sh .
COPY mlflow_server.sh .
COPY main.py .
EXPOSE 8080
EXPOSE 5000
ENTRYPOINT ./entrypoint.sh

File diff suppressed because it is too large

@@ -0,0 +1,128 @@
import asyncio
import hashlib
import mlflow
import os
import pendulum
import sys
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator, ShortCircuitOperator
from datetime import timedelta
from decouple import config
from textwrap import dedent
_work_dir = os.getcwd()
sys.path.insert(1, _work_dir)
from utils import pg_client
client = mlflow.MlflowClient()
models = [model.name for model in client.search_registered_models()]
def split_training(ti):
global models
projects = ti.xcom_pull(key='project_data').split(' ')
tenants = ti.xcom_pull(key='tenant_data').split(' ')
new_projects = list()
old_projects = list()
new_tenants = list()
old_tenants = list()
for i in range(len(projects)):
hashed = hashlib.sha256(bytes(f'{projects[i]}-{tenants[i]}'.encode('utf-8'))).hexdigest()
_model_name = f'{hashed}-RecModel'
if _model_name in models:
old_projects.append(projects[i])
old_tenants.append(tenants[i])
else:
new_projects.append(projects[i])
new_tenants.append(tenants[i])
ti.xcom_push(key='new_project_data', value=' '.join(new_projects))
ti.xcom_push(key='new_tenant_data', value=' '.join(new_tenants))
ti.xcom_push(key='old_project_data', value=' '.join(old_projects))
ti.xcom_push(key='old_tenant_data', value=' '.join(old_tenants))
def continue_new(ti):
L = ti.xcom_pull(key='new_project_data')
return len(L) > 0
def continue_old(ti):
L = ti.xcom_pull(key='old_project_data')
return len(L) > 0
def select_from_db(ti):
os.environ['PG_POOL'] = 'true'
asyncio.run(pg_client.init())
with pg_client.PostgresClient() as conn:
conn.execute("""SELECT tenant_id, project_id as project_id
FROM ((SELECT project_id
FROM frontend_signals
GROUP BY project_id
HAVING count(1) > 10) AS T1
INNER JOIN projects AS T2 USING (project_id));""")
res = conn.fetchall()
projects = list()
tenants = list()
for e in res:
projects.append(str(e['project_id']))
tenants.append(str(e['tenant_id']))
asyncio.run(pg_client.terminate())
ti.xcom_push(key='project_data', value=' '.join(projects))
ti.xcom_push(key='tenant_data', value=' '.join(tenants))
dag = DAG(
"first_test",
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=3),
},
start_date=pendulum.datetime(2015, 12, 1, tz="UTC"),
description="My first test",
schedule=config('crons_train', default='@daily'),
catchup=False,
)
# assigning the task for our dag to do
with dag:
split = PythonOperator(
task_id='Split_Create_and_Retrain',
provide_context=True,
python_callable=split_training,
do_xcom_push=True
)
select_vp = PythonOperator(
task_id='Select_Valid_Projects',
provide_context=True,
python_callable=select_from_db,
do_xcom_push=True
)
dag_split1 = ShortCircuitOperator(
task_id='Create_Condition',
python_callable=continue_new,
)
dag_split2 = ShortCircuitOperator(
task_id='Retrain_Condition',
python_callable=continue_old,
)
new_models = BashOperator(
task_id='Create_Models',
bash_command=f"python {_work_dir}/main.py " + "--projects {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='new_project_data')}} " +
"--tenants {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='new_tenant_data')}}",
)
old_models = BashOperator(
task_id='Retrain_Models',
bash_command=f"python {_work_dir}/main.py " + "--projects {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='old_project_data')}} " +
"--tenants {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='old_tenant_data')}}",
)
select_vp >> split >> [dag_split1, dag_split2]
dag_split1 >> new_models
dag_split2 >> old_models

@@ -0,0 +1,2 @@
cp ../../api/chalicelib/utils/ch_client.py utils
cp ../../../api/chalicelib/utils/pg_client.py utils

@@ -0,0 +1,20 @@
# Values setup
find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_user_airflow}}/${pg_user_airflow}/g" {} \;
find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_password_airflow}}/${pg_password_airflow}/g" {} \;
find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_host_airflow}}/${pg_host_airflow}/g" {} \;
find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_port_airflow}}/${pg_port_airflow}/g" {} \;
find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_dbname_airflow}}/${pg_dbname_airflow}/g" {} \;
find airflow/ -type f -name "*.cfg" -exec sed -i "s#{{airflow_secret_key}}#${airflow_secret_key}#g" {} \;
export MLFLOW_TRACKING_URI=postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml}
git init airflow/dags
# Airflow setup
airflow db init
airflow users create \
--username admin \
--firstname admin \
--lastname admin \
--role Admin \
--email admin@admin.admin \
-p ${airflow_admin_password}
# Run services
airflow webserver --port 8080 & airflow scheduler & ./mlflow_server.sh

@@ -0,0 +1,118 @@
import mlflow
import hashlib
import argparse
import numpy as np
from decouple import config
from datetime import datetime
from core.user_features import get_training_database
from core.recommendation_model import SVM_recommendation, sort_database
mlflow.set_tracking_uri(config('MLFLOW_TRACKING_URI'))
def handle_database(x_train, y_train):
"""
Verifies that the dataset is well balanced; if not, and when possible, rebalances it.
"""
total = len(y_train)
if total < 13:
return None, None
train_balance = y_train.sum() / total
if train_balance < 0.4:
positives = y_train[y_train == 1]
n_positives = len(positives)
x_positive = x_train[y_train == 1]
if n_positives < 7:
return None, None
else:
n_negatives_expected = min(int(n_positives/0.4), total-y_train.sum())
negatives = y_train[y_train == 0][:n_negatives_expected]
x_negative = x_train[y_train == 0][:n_negatives_expected]
return np.concatenate((x_positive, x_negative), axis=0), np.concatenate((positives, negatives), axis=0)
elif train_balance > 0.6:
negatives = y_train[y_train == 0]
n_negatives = len(negatives)
x_negative = x_train[y_train == 0]
if n_negatives < 7:
return None, None
else:
n_positives_expected = min(int(n_negatives / 0.4), y_train.sum())
positives = y_train[y_train == 1][:n_positives_expected]
x_positive = x_train[y_train == 1][:n_positives_expected]
return np.concatenate((x_positive, x_negative), axis=0), np.concatenate((positives, negatives), axis=0)
else:
return x_train, y_train
def main(experiment_name, projectId, tenantId):
"""
Main training method using mlflow for tracking and S3 for storage.
Params:
experiment_name: experiment name for mlflow repo.
projectId: project id of sessions.
tenantId: tenant of the project id (used mainly as salt for hashing).
"""
hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest()
x_, y_, d = get_training_database(projectId, max_timestamp=1680248412284, favorites=True)
x, y = handle_database(x_, y_)
if x is None:
print(f'[INFO] Project {projectId}: Not enough data to train model - {y_.sum()}/{len(y_)-y_.sum()}')
return
x, y = sort_database(x, y)
_experiment = mlflow.get_experiment_by_name(experiment_name)
if _experiment is None:
artifact_uri = config('MODELS_S3_BUCKET', default='./mlruns')
mlflow.create_experiment(experiment_name, artifact_uri)
mlflow.set_experiment(experiment_name)
with mlflow.start_run(run_name=f'{hashed}-{datetime.now().strftime("%Y-%m-%d_%H:%M")}'):
reg_model_name = f"{hashed}-RecModel"
best_meta = {'score': 0, 'model': None, 'name': 'NoName'}
for kernel in ['linear', 'poly', 'rbf', 'sigmoid']:
with mlflow.start_run(run_name=f'sub_run_with_{kernel}', nested=True):
print("--")
model = SVM_recommendation(kernel=kernel, test=True)
model.fit(x, y)
mlflow.sklearn.log_model(model, "sk_learn",
serialization_format="cloudpickle")
mlflow.log_param("kernel", kernel)
mlflow.log_metric("score", model.score)
for _name, displ in model.plots().items():
#TODO: Close displays not to overload memory
mlflow.log_figure(displ, f'{_name}.png')
if model.score > best_meta['score']:
best_meta['score'] = model.score
best_meta['model'] = model
best_meta['name'] = kernel
mlflow.log_metric("score", best_meta['score'])
mlflow.log_param("name", best_meta['name'])
mlflow.sklearn.log_model(best_meta['model'], "sk_learn",
serialization_format="cloudpickle",
registered_model_name=reg_model_name,
)
if __name__ == '__main__':
import asyncio
import os
os.environ['PG_POOL'] = 'true'
from utils import pg_client
asyncio.run(pg_client.init())
parser = argparse.ArgumentParser(
prog='Recommendation Trainer',
description='This python script aims to create a model able to predict which sessions may be most interesting to replay for the users',
)
parser.add_argument('--projects', type=int, nargs='+')
parser.add_argument('--tenants', type=int, nargs='+')
args = parser.parse_args()
projects = args.projects
tenants = args.tenants
for i in range(len(projects)):
print(f'Processing project {projects[i]}...')
main(experiment_name='s3-recommendations', projectId=projects[i], tenantId=tenants[i])
asyncio.run(pg_client.terminate())
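
A quick numeric check of the rebalancing done by `handle_database` above; the import path is an assumption, and importing the trainer's main module requires `MLFLOW_TRACKING_URI` to be set since it configures MLflow at import time.
```python
import numpy as np

from main import handle_database  # assumed import path for the trainer's main module

y = np.array([1] * 8 + [0] * 92)        # 8% positives -> imbalanced
x = np.arange(100).reshape(-1, 1)
x_bal, y_bal = handle_database(x, y)
# min(int(8 / 0.4), 92) = 20 negatives are kept alongside the 8 positives
print(len(y_bal), int(y_bal.sum()))     # 28 8
```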

@@ -0,0 +1 @@
mlflow server --backend-store-uri postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml} --default-artifact-root ${MODELS_S3_BUCKET} --host 0.0.0.0 --port 5000

@@ -0,0 +1,3 @@
argcomplete==3.0.8
apache-airflow==2.6.1
airflow-code-editor==7.2.1

@@ -0,0 +1,11 @@
DO
$do$
BEGIN
IF EXISTS (SELECT FROM pg_database WHERE datname = 'airflow') THEN
RAISE NOTICE 'Database already exists'; -- optional
ELSE
PERFORM dblink_exec('dbname=' || current_database() -- current db
, 'CREATE DATABASE airflow');
END IF;
END
$do$;

@@ -1,22 +0,0 @@
requests==2.28.1
urllib3==1.26.12
pyjwt==2.5.0
psycopg2-binary==2.9.3
numpy
threadpoolctl==3.1.0
joblib==1.2.0
scipy
scikit-learn
mlflow
airflow-code-editor
pydantic[email]==1.10.2
clickhouse-driver==0.2.4
python3-saml==1.14.0
python-multipart==0.0.5
python-decouple
argcomplete

@@ -0,0 +1,19 @@
requests==2.28.2
urllib3==1.26.12
pyjwt==2.6.0
SQLAlchemy==2.0.10
alembic==1.11.1
psycopg2-binary==2.9.5
joblib==1.2.0
scipy==1.10.1
scikit-learn==1.2.2
mlflow==2.3
clickhouse-driver==0.2.5
python3-saml==1.14.0
python-multipart==0.0.5
python-decouple==3.8
pydantic==1.10.8
boto3==1.26.100

@@ -1,11 +0,0 @@
echo 'Setting up required modules..'
mkdir scripts
mkdir plugins
mkdir logs
mkdir scripts/utils
cp ../../api/chalicelib/utils/pg_client.py scripts/utils
cp ../api/chalicelib/utils/ch_client.py scripts/utils
echo 'Building containers...'
docker-compose up airflow-init
echo 'Running containers...'
docker-compose up

@@ -1,161 +0,0 @@
from utils.ch_client import ClickHouseClient
from utils.pg_client import PostgresClient
def get_features_clickhouse(**kwargs):
"""Gets features from ClickHouse database"""
if 'limit' in kwargs:
limit = kwargs['limit']
else:
limit = 500
query = f"""SELECT session_id, project_id, user_id, events_count, errors_count, duration, country, issue_score, device_type, rage, jsexception, badrequest FROM (
SELECT session_id, project_id, user_id, events_count, errors_count, duration, toInt8(user_country) as country, issue_score, toInt8(user_device_type) as device_type FROM experimental.sessions WHERE user_id IS NOT NULL) as T1
INNER JOIN (SELECT session_id, project_id, sum(issue_type = 'click_rage') as rage, sum(issue_type = 'js_exception') as jsexception, sum(issue_type = 'bad_request') as badrequest FROM experimental.events WHERE event_type = 'ISSUE' AND session_id > 0 GROUP BY session_id, project_id LIMIT {limit}) as T2
ON T1.session_id = T2.session_id AND T1.project_id = T2.project_id;"""
with ClickHouseClient() as conn:
res = conn.execute(query)
return res
def get_features_postgres(**kwargs):
with PostgresClient() as conn:
funnels = query_funnels(conn, **kwargs)
metrics = query_metrics(conn, **kwargs)
filters = query_with_filters(conn, **kwargs)
#clean_filters(funnels)
#clean_filters(filters)
return clean_filters_split(funnels, isfunnel=True), metrics, clean_filters_split(filters)
def query_funnels(conn, **kwargs):
"""Gets Funnels (PG database)"""
# If public.funnel is empty
funnels_query = f"""SELECT project_id, user_id, filter FROM (SELECT project_id, user_id, metric_id FROM public.metrics WHERE metric_type='funnel'
) as T1 LEFT JOIN (SELECT filter, metric_id FROM public.metric_series) as T2 ON T1.metric_id = T2.metric_id"""
# Else
# funnels_query = "SELECT project_id, user_id, filter FROM public.funnels"
conn.execute(funnels_query)
res = conn.fetchall()
return res
def query_metrics(conn, **kwargs):
"""Gets Metrics (PG_database)"""
metrics_query = """SELECT metric_type, metric_of, metric_value, metric_format FROM public.metrics"""
conn.execute(metrics_query)
res = conn.fetchall()
return res
def query_with_filters(conn, **kwargs):
"""Gets Metrics with filters (PG database)"""
filters_query = """SELECT T1.metric_id as metric_id, project_id, name, metric_type, metric_of, filter FROM (
SELECT metric_id, project_id, name, metric_type, metric_of FROM metrics) as T1 INNER JOIN
(SELECT metric_id, filter FROM metric_series WHERE filter != '{}') as T2 ON T1.metric_id = T2.metric_id"""
conn.execute(filters_query)
res = conn.fetchall()
return res
def transform_funnel(project_id, user_id, data):
res = list()
for k in range(len(data)):
_tmp = data[k]
if _tmp['project_id'] != project_id or _tmp['user_id'] != user_id:
continue
else:
_tmp = _tmp['filter']['events']
res.append(_tmp)
return res
def transform_with_filter(data, *kwargs):
res = list()
for k in range(len(data)):
_tmp = data[k]
jump = False
for _key in kwargs.keys():
if data[_key] != kwargs[_key]:
jump = True
break
if jump:
continue
_type = data['metric_type']
if _type == 'funnel':
res.append(['funnel', _tmp['filter']['events']])
elif _type == 'timeseries':
res.append(['timeseries', _tmp['filter']['filters'], _tmp['filter']['events']])
elif _type == 'table':
res.append(['table', _tmp['metric_of'], _tmp['filter']['events']])
return res
def transform(element):
key_ = element.pop('user_id')
secondary_key_ = element.pop('session_id')
context_ = element.pop('project_id')
features_ = element
del element
return {(key_, context_): {secondary_key_: list(features_.values())}}
def get_by_project(data, project_id):
head_ = [list(d.keys())[0][1] for d in data]
index_ = [k for k in range(len(head_)) if head_[k] == project_id]
return [data[k] for k in index_]
def get_by_user(data, user_id):
head_ = [list(d.keys())[0][0] for d in data]
index_ = [k for k in range(len(head_)) if head_[k] == user_id]
return [data[k] for k in index_]
def clean_filters(data):
for j in range(len(data)):
_filter = data[j]['filter']
_tmp = list()
for i in range(len(_filter['filters'])):
if 'value' in _filter['filters'][i].keys():
_tmp.append({'type': _filter['filters'][i]['type'],
'value': _filter['filters'][i]['value'],
'operator': _filter['filters'][i]['operator']})
data[j]['filter'] = _tmp
def clean_filters_split(data, isfunnel=False):
_data = list()
for j in range(len(data)):
_filter = data[j]['filter']
_tmp = list()
for i in range(len(_filter['filters'])):
if 'value' in _filter['filters'][i].keys():
_type = _filter['filters'][i]['type']
_value = _filter['filters'][i]['value']
if isinstance(_value, str):
_value = [_value]
_operator = _filter['filters'][i]['operator']
if isfunnel:
_data.append({'project_id': data[j]['project_id'], 'user_id': data[j]['user_id'],
'type': _type,
'value': _value,
'operator': _operator
})
else:
_data.append({'metric_id': data[j]['metric_id'], 'project_id': data[j]['project_id'],
'name': data[j]['name'], 'metric_type': data[j]['metric_type'],
'metric_of': data[j]['metric_of'],
'type': _type,
'value': _value,
'operator': _operator
})
return _data
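# Hedged usage sketch (the 'filter' payload is a minimal invented example of the
# metric_series JSON shape this helper expects):
def _example_clean_filters_split():
    funnels = [{'project_id': 7, 'user_id': 'u-1',
                'filter': {'filters': [{'type': 'userCountry', 'value': 'FR', 'operator': 'is'},
                                       {'type': 'referrer', 'operator': 'is'}]}}]
    # Only entries carrying a 'value' are kept, and scalar values are wrapped in a list:
    # [{'project_id': 7, 'user_id': 'u-1', 'type': 'userCountry', 'value': ['FR'], 'operator': 'is'}]
    return clean_filters_split(funnels, isfunnel=True)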
def test():
print('One test')
if __name__ == '__main__':
print('Just a test')
#data = get_features_clickhouse()
#print('Data length:', len(data))


@ -1,15 +0,0 @@
from sklearn.svm import SVC
class SVM_recommendation:
    """Thin wrapper around sklearn's SVC; see SVC.__doc__ for the accepted params."""

    def __init__(self, **params):
        self.svm = SVC(**params)

    def fit(self, X=None, y=None):
        assert X is not None and y is not None, 'X (features) and y (labels) must be given'
        self.svm.fit(X, y)

    def predict(self, X):
        return self.svm.predict(X)
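# Hedged usage sketch of the wrapper above (toy data invented for illustration):
def _example_svm_recommendation():
    X = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
    y = [1, 0, 1, 0]
    model = SVM_recommendation(kernel='linear')
    model.fit(X, y)
    return model.predict([[0.5, 1.0]])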


@ -1,60 +0,0 @@
import mlflow
##
import numpy as np
import pickle
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
# source: https://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html
# Load the diabetes dataset
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
# Use only one feature
diabetes_X = diabetes_X[:, np.newaxis, 2]
# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
# Split the targets into training/testing sets
diabetes_y_train = diabetes_y[:-20]
diabetes_y_test = diabetes_y[-20:]
def print_predictions(m, y_pred):
# The coefficients
print('Coefficients: \n', m.coef_)
# The mean squared error
print('Mean squared error: %.2f'
% mean_squared_error(diabetes_y_test, y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'
% r2_score(diabetes_y_test, y_pred))
# Create linear regression object
lr_model = linear_model.LinearRegression()
# Train the model using the training sets
lr_model.fit(diabetes_X_train, diabetes_y_train)
# Make predictions using the testing set
diabetes_y_pred = lr_model.predict(diabetes_X_test)
print_predictions(lr_model, diabetes_y_pred)
# save the model in the native sklearn format
filename = 'lr_model.pkl'
pickle.dump(lr_model, open(filename, 'wb'))
##
# load the model into memory
loaded_model = pickle.load(open(filename, 'rb'))
# log and register the model using MLflow scikit-learn API
mlflow.set_tracking_uri("postgresql+psycopg2://airflow:airflow@postgres/mlruns")
reg_model_name = "SklearnLinearRegression"
print("--")
mlflow.sklearn.log_model(loaded_model, "sk_learn",
serialization_format="cloudpickle",
registered_model_name=reg_model_name)
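# Hedged follow-up sketch: loading the registered model back from the same registry.
# Version "1" is an assumption for a freshly registered model; later runs create new versions.
# loaded = mlflow.sklearn.load_model(f"models:/{reg_model_name}/1")
# print(loaded.predict(diabetes_X_test[:3]))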


@ -1,42 +0,0 @@
import time
import argparse
from core import features
from utils import pg_client
import multiprocessing as mp
from decouple import config
import asyncio
import pandas
def features_ch(q):
q.put(features.get_features_clickhouse())
def features_pg(q):
q.put(features.get_features_postgres())
def get_features():
#mp.set_start_method('spawn')
#q = mp.Queue()
#p1 = mp.Process(target=features_ch, args=(q,))
#p1.start()
pg_features = features.get_features_postgres()
ch_features = []#p1.join()
return [pg_features, ch_features]
parser = argparse.ArgumentParser(description='Gets and process data from Postgres and ClickHouse.')
parser.add_argument('--batch_size', type=int, required=True, help='--batch_size max size of columns per file to be saved in opt/airflow/cache')
args = parser.parse_args()
if __name__ == '__main__':
asyncio.run(pg_client.init())
print(args)
t1 = time.time()
data = get_features()
#print(data)
cache_dir = config("data_dir", default=f"/opt/airflow/cache")
for d in data[0]:
pandas.DataFrame(d).to_csv(f'{cache_dir}/tmp-{hash(time.time())}', sep=',')
t2 = time.time()
print(f'DONE! information retrieved in {t2-t1: .2f} seconds')


@ -1,41 +0,0 @@
import time
import argparse
from decouple import config
from core import recommendation_model
import pandas
import json
import os
def transform_dict_string(s_dicts):
data = list()
for s_dict in s_dicts:
data.append(json.loads(s_dict.replace("'", '"').replace('None','null').replace('False','false')))
return data
def process_file(file_name):
return pandas.read_csv(file_name, sep=",")
def read_batches():
base_dir = config('dir_path', default='/opt/airflow/cache')
files = os.listdir(base_dir)
for file in files:
yield process_file(f'{base_dir}/{file}')
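# Hedged usage sketch (the CSV cell below is an invented example of a Python-repr dict
# that transform_dict_string() coerces into JSON before parsing):
def _example_transform_dict_string():
    cells = ["{'type': 'click_rage', 'value': None, 'active': False}"]
    # -> [{'type': 'click_rage', 'value': None, 'active': False}]
    return transform_dict_string(cells)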
parser = argparse.ArgumentParser(description='Handle machine learning inputs.')
parser.add_argument('--mode', choices=['train', 'test'], required=True, help='--mode sets the model in train or test mode')
parser.add_argument('--kernel', default='linear', help='--kernel set the kernel to be used for SVM')
args = parser.parse_args()
if __name__ == '__main__':
print(args)
t1 = time.time()
buff = read_batches()
for b in buff:
print(b.head())
t2 = time.time()
print(f'DONE! information retrieved in {t2-t1: .2f} seconds')


@ -1,11 +0,0 @@
CREATE TABLE IF NOT EXISTS frontend_signals
(
project_id bigint NOT NULL,
user_id text NOT NULL,
timestamp bigint NOT NULL,
action text NOT NULL,
source text NOT NULL,
category text NOT NULL,
data json
);
CREATE INDEX IF NOT EXISTS frontend_signals_user_id_idx ON frontend_signals (user_id);
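-- Hedged usage sketch: one sample signal row as the service might record it
-- (all values below are invented for illustration, not real data).
-- INSERT INTO frontend_signals (project_id, user_id, timestamp, action, source, category, data)
-- VALUES (7, 'u-1', 1686000000000, 'click', 'session_replay', 'recommendation', '{"session_id": 42}');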


@ -0,0 +1,59 @@
from pydantic import BaseModel, Field
class FeedbackRecommendation(BaseModel):
viewerId: int = Field(...)
sessionId: int = Field(...)
projectId: int = Field(...)
payload: dict = Field(default=dict())
class DeviceValue:
device_types = ['other', 'desktop', 'mobile']
def __init__(self, device_type):
if isinstance(device_type, str):
try:
self.id = self.device_types.index(device_type)
except ValueError:
self.id = 0
self.name = device_type
elif isinstance(device_type, int):
self.id = device_type
self.name = self.device_types[device_type]
def __repr__(self):
return str(self.id)
def __str__(self):
return self.name
def get_int_val(self):
return self.id
def get_str_val(self):
return self.name
class CountryValue:
countries = ['UN', 'RW', 'SO', 'YE', 'IQ', 'SA', 'IR', 'CY', 'TZ', 'SY', 'AM', 'KE', 'CD', 'DJ', 'UG', 'CF', 'SC', 'JO', 'LB', 'KW', 'OM', 'QA', 'BH', 'AE', 'IL', 'TR', 'ET', 'ER', 'EG', 'SD', 'GR', 'BI', 'EE', 'LV', 'AZ', 'LT', 'SJ', 'GE', 'MD', 'BY', 'FI', 'AX', 'UA', 'MK', 'HU', 'BG', 'AL', 'PL', 'RO', 'XK', 'ZW', 'ZM', 'KM', 'MW', 'LS', 'BW', 'MU', 'SZ', 'RE', 'ZA', 'YT', 'MZ', 'MG', 'AF', 'PK', 'BD', 'TM', 'TJ', 'LK', 'BT', 'IN', 'MV', 'IO', 'NP', 'MM', 'UZ', 'KZ', 'KG', 'TF', 'HM', 'CC', 'PW', 'VN', 'TH', 'ID', 'LA', 'TW', 'PH', 'MY', 'CN', 'HK', 'BN', 'MO', 'KH', 'KR', 'JP', 'KP', 'SG', 'CK', 'TL', 'RU', 'MN', 'AU', 'CX', 'MH', 'FM', 'PG', 'SB', 'TV', 'NR', 'VU', 'NC', 'NF', 'NZ', 'FJ', 'LY', 'CM', 'SN', 'CG', 'PT', 'LR', 'CI', 'GH', 'GQ', 'NG', 'BF', 'TG', 'GW', 'MR', 'BJ', 'GA', 'SL', 'ST', 'GI', 'GM', 'GN', 'TD', 'NE', 'ML', 'EH', 'TN', 'ES', 'MA', 'MT', 'DZ', 'FO', 'DK', 'IS', 'GB', 'CH', 'SE', 'NL', 'AT', 'BE', 'DE', 'LU', 'IE', 'MC', 'FR', 'AD', 'LI', 'JE', 'IM', 'GG', 'SK', 'CZ', 'NO', 'VA', 'SM', 'IT', 'SI', 'ME', 'HR', 'BA', 'AO', 'NA', 'SH', 'BV', 'BB', 'CV', 'GY', 'GF', 'SR', 'PM', 'GL', 'PY', 'UY', 'BR', 'FK', 'GS', 'JM', 'DO', 'CU', 'MQ', 'BS', 'BM', 'AI', 'TT', 'KN', 'DM', 'AG', 'LC', 'TC', 'AW', 'VG', 'VC', 'MS', 'MF', 'BL', 'GP', 'GD', 'KY', 'BZ', 'SV', 'GT', 'HN', 'NI', 'CR', 'VE', 'EC', 'CO', 'PA', 'HT', 'AR', 'CL', 'BO', 'PE', 'MX', 'PF', 'PN', 'KI', 'TK', 'TO', 'WF', 'WS', 'NU', 'MP', 'GU', 'PR', 'VI', 'UM', 'AS', 'CA', 'US', 'PS', 'RS', 'AQ', 'SX', 'CW', 'BQ', 'SS', 'BU', 'VD', 'YD', 'DD']
def __init__(self, country):
if isinstance(country, str):
self.id = -128+self.countries.index(country)
self.name = country
elif isinstance(country, int):
self.id = country
self.name = self.countries[128+country]
def __repr__(self):
return str(self.id)
def __str__(self):
return self.name
def get_int_val(self):
return self.id
def get_str_val(self):
return self.name
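# Hedged round-trip sketch of the two encoders above (sample values invented):
def _example_declarations():
    assert DeviceValue('desktop').get_int_val() == 1
    assert DeviceValue(1).get_str_val() == 'desktop'
    assert DeviceValue('tablet').get_int_val() == 0  # unknown strings fall back to 'other' (id 0)
    fr_code = CountryValue('FR').get_int_val()       # ids are offset by -128
    assert CountryValue(fr_code).get_str_val() == 'FR'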


@ -0,0 +1,24 @@
from utils.declarations import CountryValue, DeviceValue
def _add_to_dict(element, index, dictionary):
if element not in dictionary.keys():
dictionary[element] = [index]
else:
dictionary[element].append(index)
def _process_pg_response(res, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=None):
for i in range(len(res)):
x = res[i]
_add_to_dict(x.pop('project_id'), i, X_project_ids)
_add_to_dict(x.pop('session_id'), i, X_sessions_ids)
_add_to_dict(x.pop('user_id'), i, X_users_ids)
if label is None:
_Y.append(int(x.pop('train_label')))
else:
_Y.append(label)
x['country'] = CountryValue(x['country']).get_int_val()
x['device_type'] = DeviceValue(x['device_type']).get_int_val()
_X.append(list(x.values()))
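# Hedged usage sketch (the row below is an invented example of the session rows this
# helper expects; with label=None a 'train_label' column would be required instead):
def _example_process_pg_response():
    res = [{'project_id': 7, 'session_id': 42, 'user_id': 'u-1',
            'events_count': 10, 'errors_count': 1, 'duration': 5300,
            'country': 'FR', 'issue_score': 3, 'device_type': 'desktop'}]
    X, Y = [], []
    project_ids, user_ids, session_ids = {}, {}, {}
    _process_pg_response(res, X, Y, project_ids, user_ids, session_ids, label=1)
    # X[0] == [10, 1, 5300, CountryValue('FR').get_int_val(), 3, 1], Y == [1],
    # project_ids == {7: [0]}, user_ids == {'u-1': [0]}, session_ids == {42: [0]}
    return X, Y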


@ -10,12 +10,35 @@ from psycopg2 import pool
logging.basicConfig(level=config("LOGLEVEL", default=logging.INFO))
logging.getLogger('apscheduler').setLevel(config("LOGLEVEL", default=logging.INFO))

conn_str = config('string_connection', default='')
if conn_str == '':
    _PG_CONFIG = {"host": config("pg_host"),
                  "database": config("pg_dbname"),
                  "user": config("pg_user"),
                  "password": config("pg_password"),
                  "port": config("pg_port", cast=int),
                  "application_name": config("APP_NAME", default="PY")}
else:
    import urllib.parse

    conn_str = urllib.parse.unquote(conn_str)
    usr_info, host_info = conn_str.split('@')
    i = usr_info.find('://')
    pg_user, pg_password = usr_info[i + 3:].split(':')
    host_info, pg_dbname = host_info.split('/')
    i = host_info.find(':')
    if i == -1:
        pg_host = host_info
        pg_port = 5432
    else:
        pg_host, pg_port = host_info.split(':')
        pg_port = int(pg_port)
    _PG_CONFIG = {"host": pg_host,
                  "database": pg_dbname,
                  "user": pg_user,
                  "password": pg_password,
                  "port": pg_port,
                  "application_name": config("APP_NAME", default="PY")}

PG_CONFIG = dict(_PG_CONFIG)
if config("PG_TIMEOUT", cast=int, default=0) > 0:
    PG_CONFIG["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int) * 1000}"
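# Hedged standalone sketch of what the string_connection branch above produces for a
# sample DSN (credentials invented; assumes no '@' remains inside the password after unquoting):
def _example_parse_dsn(conn_str='postgresql://openreplay:s3cret@db.internal:5433/mlruns'):
    import urllib.parse
    conn_str = urllib.parse.unquote(conn_str)
    usr_info, host_info = conn_str.split('@')
    pg_user, pg_password = usr_info[usr_info.find('://') + 3:].split(':')
    host_info, pg_dbname = host_info.split('/')
    pg_host, _, port = host_info.partition(':')
    # -> ('openreplay', 's3cret', 'db.internal', 5433, 'mlruns')
    return pg_user, pg_password, pg_host, int(port) if port else 5432, pg_dbname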
@ -87,9 +110,10 @@ class PostgresClient:
    long_query = False
    unlimited_query = False

    def __init__(self, long_query=False, unlimited_query=False, use_pool=True):
        self.long_query = long_query
        self.unlimited_query = unlimited_query
        self.use_pool = use_pool
        if unlimited_query:
            long_config = dict(_PG_CONFIG)
            long_config["application_name"] += "-UNLIMITED"
@ -100,7 +124,7 @@ class PostgresClient:
            long_config["options"] = f"-c statement_timeout=" \
                                     f"{config('pg_long_timeout', cast=int, default=5 * 60) * 1000}"
            self.connection = psycopg2.connect(**long_config)
        elif not use_pool or not config('PG_POOL', cast=bool, default=True):
            single_config = dict(_PG_CONFIG)
            single_config["application_name"] += "-NOPOOL"
            single_config["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int, default=30) * 1000}"
@ -111,6 +135,8 @@ class PostgresClient:
    def __enter__(self):
        if self.cursor is None:
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            self.cursor.cursor_execute = self.cursor.execute
            self.cursor.execute = self.__execute
            self.cursor.recreate = self.recreate_cursor
        return self.cursor
@ -118,11 +144,12 @@ class PostgresClient:
        try:
            self.connection.commit()
            self.cursor.close()
            if not self.use_pool or self.long_query or self.unlimited_query:
                self.connection.close()
        except Exception as error:
            logging.error("Error while committing/closing PG-connection", error)
            if str(error) == "connection already closed" \
                    and self.use_pool \
                    and not self.long_query \
                    and not self.unlimited_query \
                    and config('PG_POOL', cast=bool, default=True):
@ -132,10 +159,22 @@ class PostgresClient:
            raise error
        finally:
            if config('PG_POOL', cast=bool, default=True) \
                    and self.use_pool \
                    and not self.long_query \
                    and not self.unlimited_query:
                postgreSQL_pool.putconn(self.connection)

    def __execute(self, query, vars=None):
        try:
            result = self.cursor.cursor_execute(query=query, vars=vars)
        except psycopg2.Error as error:
            logging.error(f"!!! Error of type:{type(error)} while executing query:")
            logging.error(query)
            logging.info("starting rollback to allow future execution")
            self.connection.rollback()
            raise error
        return result

    def recreate_cursor(self, rollback=False):
        if rollback:
            try:


@ -48,6 +48,9 @@ Below is the list of dependencies used in OpenReplay software. Licenses may chan
| pandas | BSD3 | Python |
| numpy | BSD3 | Python |
| scikit-learn | BSD3 | Python |
| apache-airflow | Apache2 | Python |
| airflow-code-editor | Apache2 | Python |
| mlflow | Apache2 | Python |
| sqlalchemy | MIT | Python |
| pandas-redshift | MIT | Python |
| confluent-kafka | Apache2 | Python |