diff --git a/ee/recommendation/.gitignore b/ee/recommendation/.gitignore new file mode 100644 index 000000000..a4224b0bf --- /dev/null +++ b/ee/recommendation/.gitignore @@ -0,0 +1,170 @@ +### Example user template template +### Example user template + +# IntelliJ project files +.idea +*.iml +out +gen +### Python template +# Byte-compiled / optimized / DLL files +./**/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
+
diff --git a/ee/recommendation/Dockerfile b/ee/recommendation/Dockerfile
index 992bcf89a..d17e4e4ed 100644
--- a/ee/recommendation/Dockerfile
+++ b/ee/recommendation/Dockerfile
@@ -1,14 +1,14 @@
-FROM apache/airflow:2.4.3
-COPY requirements.txt .
+FROM python:3.10-slim-buster
 
-USER root
 RUN apt-get update \
-    && apt-get install -y \
-    vim \
-    && apt-get install gcc libc-dev g++ -y \
-    && apt-get install -y pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl
+    && apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl \
+    && apt-get clean
 
+WORKDIR /app
+COPY requirements_base.txt .
+
+RUN pip install --no-cache-dir -r requirements_base.txt
+
+COPY core core
+COPY utils utils
 
-USER airflow
-RUN pip install --upgrade pip
-RUN pip install -r requirements.txt
diff --git a/ee/recommendation/README.md b/ee/recommendation/README.md
new file mode 100644
index 000000000..60a397a1b
--- /dev/null
+++ b/ee/recommendation/README.md
@@ -0,0 +1,78 @@
+# Recommendations
+
+## Index
+1. [Build image](#build-image)
+   1. [Recommendations service image](#recommendations-service-image)
+   2. [Trainer service image](#trainer-service-image)
+2. [Trainer](#trainer-service)
+   1. [Env params](#trainer-env-params)
+3. [Recommendations](#recommendation-service)
+   1. [Env params](#recommendation-env-params)
+
+## Build image
+To build both the recommendation image and the trainer image, first create a base image by running the following command:
+```bash
+docker build -t recommendations_base .
+```
+This adds the files from `core` and `utils`, which are shared between `ml_service` and `ml_trainer`, and installs the common dependencies.
+
+### Recommendations service image
+Inside `ml_service`, run docker build to create the recommendation service image:
+```bash
+cd ml_service/
+docker build -t recommendations .
+cd ../
+```
+### Trainer service image
+Inside `ml_trainer`, run docker build to create the trainer service image:
+```bash
+cd ml_trainer/
+docker build -t trainer .
+cd ../
+```
+## Trainer service
+The trainer is an orchestration service in charge of training models and saving them to S3.
+Training runs are scheduled as Directed Acyclic Graphs (DAGs) in [Airflow](https://airflow.apache.org),
+and [MLflow](https://mlflow.org) tracks each training run and keeps a model registry backed by S3.
+### Trainer env params
+```bash
+ pg_host=
+ pg_port=
+ pg_user=
+ pg_password=
+ pg_dbname=
+ pg_host_ml=
+ pg_port_ml=
+ pg_user_ml=
+ pg_password_ml=
+ pg_dbname_ml='mlruns'
+ PG_POOL='true'
+ MODELS_S3_BUCKET= #'s3://path/to/bucket'
+ pg_user_airflow=
+ pg_password_airflow=
+ pg_dbname_airflow='airflow'
+ pg_host_airflow=
+ pg_port_airflow=
+ AIRFLOW_HOME=/app/airflow
+ airflow_secret_key=
+ airflow_admin_password=
+ crons_train='0 0 * * *'
+```
+## Recommendation service
+The recommendation service is a [FastAPI](https://fastapi.tiangolo.com) server that uses MLflow to read models from S3
+and serve them. It also collects user feedback and stores it in Postgres for retraining purposes.
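+
+A minimal smoke test once the container is running (the endpoints and bearer-token auth follow `ml_service/api.py`, and the port comes from `entrypoint.sh`; the user and project ids below are placeholders):
+```bash
+# health check
+curl http://localhost:7001/
+
+# ordered session recommendations for a user within a project
+curl -H "Authorization: Bearer $API_AUTH_KEY" \
+     http://localhost:7001/recommendations/<user_id>/<project_id>
+```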
+### Recommendation env params +```bash + pg_host= + pg_port= + pg_user= + pg_password= + pg_dbname= + pg_host_ml= + pg_port_ml= + pg_user_ml= + pg_password_ml= + pg_dbname_ml='mlruns' + PG_POOL='true' + API_AUTH_KEY= +``` \ No newline at end of file diff --git a/ee/recommendation/clean.sh b/ee/recommendation/clean.sh deleted file mode 100644 index 857c8d63d..000000000 --- a/ee/recommendation/clean.sh +++ /dev/null @@ -1 +0,0 @@ -docker-compose down --volumes --rmi all diff --git a/ee/recommendation/core/recommendation_model.py b/ee/recommendation/core/recommendation_model.py new file mode 100644 index 000000000..823fd5093 --- /dev/null +++ b/ee/recommendation/core/recommendation_model.py @@ -0,0 +1,128 @@ +import mlflow.pyfunc +import random +import numpy as np +from sklearn import metrics +from sklearn.svm import SVC +from sklearn.feature_selection import SequentialFeatureSelector as sfs +from sklearn.preprocessing import normalize +from sklearn.decomposition import PCA +from sklearn.neighbors import KNeighborsClassifier as knc + + +def select_features(X, y): + """ + Dimensional reduction of X using k-nearest neighbors and sequential feature selector. + Final dimension set to three features. + Params: + X: Array which will be reduced in dimension (batch_size, n_features). + y: Array of labels (batch_size,). + Output: function that reduces dimension of array. + """ + knn = knc(n_neighbors=3) + selector = sfs(knn, n_features_to_select=3) + X_transformed = selector.fit_transform(X, y) + + def transform(input): + return selector.transform(input) + return transform, X_transformed + + +def sort_database(X, y): + """ + Random shuffle of training values with its respective labels. + Params: + X: Array of features. + y: Array of labels. + Output: Tuple (X_rand_sorted, y_rand_sorted). + """ + sort_list = list(range(len(y))) + random.shuffle(sort_list) + return X[sort_list], y[sort_list] + + +def preprocess(X): + """ + Preprocessing of features (no dimensional reduction) using principal component analysis. + Params: + X: Array of features. + Output: Tuple (processed array of features function that reduces dimension of array). + """ + _, n = X.shape + pca = PCA(n_components=n) + x = pca.fit_transform(normalize(X)) + + def transform(input): + return pca.transform(normalize(input)) + + return x, transform + + +class SVM_recommendation(mlflow.pyfunc.PythonModel): + + def __init__(self, test=False, **params): + f"""{SVC.__doc__}""" + params['probability'] = True + self.svm = SVC(**params) + self.transforms = [] + self.score = 0 + self.confusion_matrix = None + if test: + knn = knc(n_neighbors=3) + self.transform = [PCA(n_components=3), sfs(knn, n_features_to_select=2)] + + def fit(self, X, y): + """ + Train preprocess function, feature selection and Support Vector Machine model + Params: + X: Array of features. + y: Array of labels. 
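+        Example (illustrative only; X_train is a (n_samples, n_features) array, y_train holds 0/1 labels):
+            model = SVM_recommendation(kernel='rbf')
+            model.fit(X_train, y_train)
+            probabilities = model.predict(X_test)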
+ """ + assert X.shape[0] == y.shape[0], 'X and y must have same length' + assert len(X.shape) == 2, 'X must be a two dimension vector' + X, t1 = preprocess(X) + t2, X = select_features(X, y) + self.transforms = [t1, t2] + self.svm.fit(X, y) + pred = self.svm.predict(X) + z = y + 2 * pred + n = len(z) + false_pos = np.count_nonzero(z == 1) / n + false_neg = np.count_nonzero(z == 2) / n + true_pos = np.count_nonzero(z == 3) / n + true_neg = 1 - false_neg - false_pos - true_pos + self.confusion_matrix = np.array([[true_neg, false_pos], [false_neg, true_pos]]) + self.score = true_pos + true_neg + + + def predict(self, x): + """ + Transform and prediction of input features and sorting of each by probability + Params: + X: Array of features. + Output: prediction probability for True (1). + """ + for t in self.transforms: + x = t(x) + return self.svm.predict_proba(x)[:, 1] + + def recommendation_order(self, x): + """ + Transform and prediction of input features and sorting of each by probability + Params: + X: Array of features. + Output: Tuple (sorted_features, predictions). + """ + for t in self.transforms: + x = t(x) + pred = self.svm.predict_proba(x) + return sorted(range(len(pred)), key=lambda k: pred[k][1], reverse=True), pred + + def plots(self): + """ + Returns the plots in a dict format. + { + 'confusion_matrix': confusion matrix figure, + } + """ + display = metrics.ConfusionMatrixDisplay(confusion_matrix=self.confusion_matrix, display_labels=[False, True]) + return {'confusion_matrix': display.plot().figure_} diff --git a/ee/recommendation/core/user_features.py b/ee/recommendation/core/user_features.py new file mode 100644 index 000000000..72fbfdd24 --- /dev/null +++ b/ee/recommendation/core/user_features.py @@ -0,0 +1,137 @@ +from utils.pg_client import PostgresClient +from decouple import config +from utils.df_utils import _process_pg_response +import numpy as np + + +def get_training_database(projectId, max_timestamp=None, favorites=False): + """ + Gets training database using projectId, max_timestamp [optional] and favorites (if true adds favorites) + Params: + projectId: project id of all sessions to be selected. + max_timestamp: max timestamp that a not seen session can have in order to be considered not interesting. + favorites: True to use favorite sessions as interesting sessions reference. + Output: Tuple (Set of features, set of labels, dict of indexes of each project_id, session_id, user_id in the set) + """ + args = {"projectId": projectId, "max_timestamp": max_timestamp, "limit": 20} + with PostgresClient() as conn: + x1 = signals_features(conn, **args) + if favorites: + x2 = user_favorite_sessions(args['projectId'], conn) + if max_timestamp is not None: + x3 = user_not_seen_sessions(args['projectId'], args['limit'], conn) + + X_project_ids = dict() + X_users_ids = dict() + X_sessions_ids = dict() + + _X = list() + _Y = list() + _process_pg_response(x1, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=None) + if favorites: + _process_pg_response(x2, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=1) + if max_timestamp: + _process_pg_response(x3, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=0) + return np.array(_X), np.array(_Y), \ + {'project_id': X_project_ids, + 'user_id': X_users_ids, + 'session_id': X_sessions_ids} + + +def signals_features(conn, **kwargs): + """ + Selects features from frontend_signals table and mark as interesting given the following conditions: + * If number of events is greater than events_threshold (default=10). 
(env value) + * If session has been replayed more than once. + """ + assert 'projectId' in kwargs.keys(), 'projectId should be provided in kwargs' + projectId = kwargs['projectId'] + events_threshold = config('events_threshold', default=10, cast=int) + query = conn.mogrify("""SELECT T.project_id, + T.session_id, + T.user_id, + T2.viewer_id, + T.events_count, + T.errors_count, + T.duration, + T.country, + T.issue_score, + T.device_type, + T2.interesting as train_label + FROM (SELECT project_id, + user_id as viewer_id, + session_id, + count(CASE WHEN source = 'replay' THEN 1 END) > 1 OR COUNT(1) > %(events_threshold)s as interesting + FROM frontend_signals + WHERE project_id = %(projectId)s + AND session_id is not null + GROUP BY project_id, viewer_id, session_id) as T2 + INNER JOIN (SELECT project_id, + session_id, + user_id as viewer_id, + user_id, + events_count, + errors_count, + duration, + user_country as country, + issue_score, + user_device_type as device_type + FROM sessions + WHERE project_id = %(projectId)s + AND duration IS NOT NULL) as T + USING (session_id);""", + {"projectId": projectId, "events_threshold": events_threshold}) + conn.execute(query) + res = conn.fetchall() + return res + + +def user_favorite_sessions(projectId, conn): + """ + Selects features from user_favorite_sessions table. + """ + query = """SELECT project_id, + session_id, + T1.user_id, + events_count, + errors_count, + duration, + user_country as country, + issue_score, + user_device_type as device_type, + T2.user_id AS viewer_id + FROM sessions AS T1 + INNER JOIN user_favorite_sessions as T2 + USING (session_id) + WHERE project_id = %(projectId)s;""" + conn.execute( + conn.mogrify(query, {"projectId": projectId}) + ) + res = conn.fetchall() + return res + + +def user_not_seen_sessions(projectId, limit, conn): + """ + Selects features from user_viewed_sessions table. 
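+    These are sessions of the project with no entry in user_viewed_sessions; get_training_database
+    labels them as not interesting (label 0), and each one is cross-joined with every user of the
+    project's tenant (see the TODO below).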
+ """ + # TODO: fetch un-viewed sessions alone, and the users list alone, then cross join them in python + # and ignore deleted users (WHERE users.deleted_at ISNULL) + query = """SELECT project_id, session_id, user_id, viewer_id, events_count, errors_count, duration, user_country as country, issue_score, user_device_type as device_type +FROM ( + (SELECT sessions.* + FROM sessions LEFT JOIN user_viewed_sessions USING(session_id) + WHERE project_id = %(projectId)s + AND duration IS NOT NULL + AND user_viewed_sessions.session_id ISNULL + LIMIT %(limit)s) AS T1 + LEFT JOIN + (SELECT user_id as viewer_id + FROM users + WHERE tenant_id = (SELECT tenant_id FROM projects WHERE project_id = %(projectId)s)) AS T2 ON true + )""" + conn.execute( + conn.mogrify(query, {"projectId": projectId, "limit": limit}) + ) + res = conn.fetchall() + return res diff --git a/ee/recommendation/dags/training_dag.py b/ee/recommendation/dags/training_dag.py deleted file mode 100644 index ff340f772..000000000 --- a/ee/recommendation/dags/training_dag.py +++ /dev/null @@ -1,46 +0,0 @@ -from datetime import datetime, timedelta -from textwrap import dedent - -import pendulum - -from airflow import DAG -from airflow.operators.bash import BashOperator -from airflow.operators.python import PythonOperator -import os -_work_dir = os.getcwd() - -def my_function(): - l = os.listdir('scripts') - print(l) - return l - -dag = DAG( - "first_test", - default_args={ - "depends_on_past": True, - "retries": 1, - "retry_delay": timedelta(minutes=3), - }, - start_date=pendulum.datetime(2015, 12, 1, tz="UTC"), - description="My first test", - schedule="@daily", - catchup=False, -) - - -#assigning the task for our dag to do -with dag: - first_world = PythonOperator( - task_id='FirstTest', - python_callable=my_function, - ) - hello_world = BashOperator( - task_id='OneTest', - bash_command=f'python {_work_dir}/scripts/processing.py --batch_size 500', - # provide_context=True - ) - this_world = BashOperator( - task_id='ThisTest', - bash_command=f'python {_work_dir}/scripts/task.py --mode train --kernel linear', - ) - first_world >> hello_world >> this_world diff --git a/ee/recommendation/docker-compose.yaml b/ee/recommendation/docker-compose.yaml deleted file mode 100644 index d7d068551..000000000 --- a/ee/recommendation/docker-compose.yaml +++ /dev/null @@ -1,285 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -# Basic Airflow cluster configuration for CeleryExecutor with Redis and PostgreSQL. -# -# WARNING: This configuration is for local development. Do not use it in a production deployment. -# -# This configuration supports basic configuration using environment variables or an .env file -# The following variables are supported: -# -# AIRFLOW_IMAGE_NAME - Docker image name used to run Airflow. 
-# Default: apache/airflow:2.4.3 -# AIRFLOW_UID - User ID in Airflow containers -# Default: 50000 -# Those configurations are useful mostly in case of standalone testing/running Airflow in test/try-out mode -# -# _AIRFLOW_WWW_USER_USERNAME - Username for the administrator account (if requested). -# Default: airflow -# _AIRFLOW_WWW_USER_PASSWORD - Password for the administrator account (if requested). -# Default: airflow -# _PIP_ADDITIONAL_REQUIREMENTS - Additional PIP requirements to add when starting all containers. -# Default: '' -# -# Feel free to modify this file to suit your needs. ---- -version: '3' -x-airflow-common: - &airflow-common - # In order to add custom dependencies or upgrade provider packages you can use your extended image. - # Comment the image line, place your Dockerfile in the directory where you placed the docker-compose.yaml - # and uncomment the "build" line below, Then run `docker-compose build` to build the images. - # image: ${AIRFLOW_IMAGE_NAME:-apache/airflow:2.4.3} - build: . - environment: - &airflow-common-env - AIRFLOW__CORE__EXECUTOR: CeleryExecutor - AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow - # For backward compatibility, with Airflow <2.3 - AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow - AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres/airflow - AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0 - AIRFLOW__CORE__FERNET_KEY: '' - AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: 'true' - AIRFLOW__CORE__LOAD_EXAMPLES: 'false' - AIRFLOW__API__AUTH_BACKENDS: 'airflow.api.auth.backend.basic_auth' - _PIP_ADDITIONAL_REQUIREMENTS: 'argcomplete' - AIRFLOW__CODE_EDITOR__ENABLED: 'true' - AIRFLOW__CODE_EDITOR__GIT_ENABLED: 'false' - AIRFLOW__CODE_EDITOR__STRING_NORMALIZATION: 'true' - AIRFLOW__CODE_EDITOR__MOUNT: '/opt/airflow/dags' - pg_user: "${pg_user}" - pg_password: "${pg_password}" - pg_dbname: "${pg_dbname}" - pg_host: "${pg_host}" - pg_port: "${pg_port}" - PG_TIMEOUT: "${PG_TIMEOUT}" - PG_POOL: "${PG_POOL}" - volumes: - - ./dags:/opt/airflow/dags - - ./logs:/opt/airflow/logs - - ./plugins:/opt/airflow/plugins - - ./scripts:/opt/airflow/scripts - - ./cache:/opt/airflow/cache - user: "${AIRFLOW_UID:-50000}:0" - depends_on: - &airflow-common-depends-on - redis: - condition: service_healthy - postgres: - condition: service_healthy - -services: - postgres: - image: postgres:13 - environment: - POSTGRES_USER: airflow - POSTGRES_PASSWORD: airflow - POSTGRES_DB: airflow - volumes: - - postgres-db-volume:/var/lib/postgresql/data - healthcheck: - test: ["CMD", "pg_isready", "-U", "airflow"] - interval: 5s - retries: 5 - restart: always - - redis: - image: redis:latest - expose: - - 6379 - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 5s - timeout: 30s - retries: 50 - restart: always - - airflow-webserver: - <<: *airflow-common - command: webserver - ports: - - 8080:8080 - healthcheck: - test: ["CMD", "curl", "--fail", "http://localhost:8080/health"] - interval: 10s - timeout: 10s - retries: 5 - restart: always - depends_on: - <<: *airflow-common-depends-on - airflow-init: - condition: service_completed_successfully - - airflow-scheduler: - <<: *airflow-common - command: scheduler - healthcheck: - test: ["CMD-SHELL", 'airflow jobs check --job-type SchedulerJob --hostname "$${HOSTNAME}"'] - interval: 10s - timeout: 10s - retries: 5 - restart: always - depends_on: - <<: *airflow-common-depends-on - airflow-init: - condition: 
service_completed_successfully - - airflow-worker: - <<: *airflow-common - command: celery worker - healthcheck: - test: - - "CMD-SHELL" - - 'celery --app airflow.executors.celery_executor.app inspect ping -d "celery@$${HOSTNAME}"' - interval: 10s - timeout: 10s - retries: 5 - environment: - <<: *airflow-common-env - # Required to handle warm shutdown of the celery workers properly - # See https://airflow.apache.org/docs/docker-stack/entrypoint.html#signal-propagation - DUMB_INIT_SETSID: "0" - restart: always - depends_on: - <<: *airflow-common-depends-on - airflow-init: - condition: service_completed_successfully - - airflow-triggerer: - <<: *airflow-common - command: triggerer - healthcheck: - test: ["CMD-SHELL", 'airflow jobs check --job-type TriggererJob --hostname "$${HOSTNAME}"'] - interval: 10s - timeout: 10s - retries: 5 - restart: always - depends_on: - <<: *airflow-common-depends-on - airflow-init: - condition: service_completed_successfully - - airflow-init: - <<: *airflow-common - entrypoint: /bin/bash - # yamllint disable rule:line-length - command: - - -c - - | - function ver() { - printf "%04d%04d%04d%04d" $${1//./ } - } - register-python-argcomplete airflow >> ~/.bashrc - airflow_version=$$(AIRFLOW__LOGGING__LOGGING_LEVEL=INFO && gosu airflow airflow version) - airflow_version_comparable=$$(ver $${airflow_version}) - min_airflow_version=2.2.0 - min_airflow_version_comparable=$$(ver $${min_airflow_version}) - if [[ -z "${AIRFLOW_UID}" ]]; then - echo - echo -e "\033[1;33mWARNING!!!: AIRFLOW_UID not set!\e[0m" - echo "If you are on Linux, you SHOULD follow the instructions below to set " - echo "AIRFLOW_UID environment variable, otherwise files will be owned by root." - echo "For other operating systems you can get rid of the warning with manually created .env file:" - echo " See: https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#setting-the-right-airflow-user" - echo - fi - one_meg=1048576 - mem_available=$$(($$(getconf _PHYS_PAGES) * $$(getconf PAGE_SIZE) / one_meg)) - cpus_available=$$(grep -cE 'cpu[0-9]+' /proc/stat) - disk_available=$$(df / | tail -1 | awk '{print $$4}') - warning_resources="false" - if (( mem_available < 4000 )) ; then - echo - echo -e "\033[1;33mWARNING!!!: Not enough memory available for Docker.\e[0m" - echo "At least 4GB of memory required. You have $$(numfmt --to iec $$((mem_available * one_meg)))" - echo - warning_resources="true" - fi - if (( cpus_available < 2 )); then - echo - echo -e "\033[1;33mWARNING!!!: Not enough CPUS available for Docker.\e[0m" - echo "At least 2 CPUs recommended. You have $${cpus_available}" - echo - warning_resources="true" - fi - if (( disk_available < one_meg * 10 )); then - echo - echo -e "\033[1;33mWARNING!!!: Not enough Disk space available for Docker.\e[0m" - echo "At least 10 GBs recommended. 
You have $$(numfmt --to iec $$((disk_available * 1024 )))" - echo - warning_resources="true" - fi - if [[ $${warning_resources} == "true" ]]; then - echo - echo -e "\033[1;33mWARNING!!!: You have not enough resources to run Airflow (see above)!\e[0m" - echo "Please follow the instructions to increase amount of resources available:" - echo " https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#before-you-begin" - echo - fi - mkdir -p /sources/logs /sources/dags /sources/plugins - chown -R "${AIRFLOW_UID}:0" /sources/{logs,dags,plugins} - exec /entrypoint airflow version - # yamllint enable rule:line-length - environment: - <<: *airflow-common-env - _AIRFLOW_DB_UPGRADE: 'true' - _AIRFLOW_WWW_USER_CREATE: 'true' - _AIRFLOW_WWW_USER_USERNAME: ${_AIRFLOW_WWW_USER_USERNAME:-airflow} - _AIRFLOW_WWW_USER_PASSWORD: ${_AIRFLOW_WWW_USER_PASSWORD:-airflow} - _PIP_ADDITIONAL_REQUIREMENTS: '' - user: "0:0" - volumes: - - .:/sources - - airflow-cli: - <<: *airflow-common - profiles: - - debug - environment: - <<: *airflow-common-env - CONNECTION_CHECK_MAX_COUNT: "0" - # Workaround for entrypoint issue. See: https://github.com/apache/airflow/issues/16252 - command: - - bash - - -c - - airflow - - # You can enable flower by adding "--profile flower" option e.g. docker-compose --profile flower up - # or by explicitly targeted on the command line e.g. docker-compose up flower. - # See: https://docs.docker.com/compose/profiles/ - flower: - <<: *airflow-common - command: celery flower - profiles: - - flower - ports: - - 5555:5555 - healthcheck: - test: ["CMD", "curl", "--fail", "http://localhost:5555/"] - interval: 10s - timeout: 10s - retries: 5 - restart: always - depends_on: - <<: *airflow-common-depends-on - airflow-init: - condition: service_completed_successfully - -volumes: - postgres-db-volume: diff --git a/ee/recommendation/ml_service/Dockerfile b/ee/recommendation/ml_service/Dockerfile new file mode 100644 index 000000000..f2e811fd7 --- /dev/null +++ b/ee/recommendation/ml_service/Dockerfile @@ -0,0 +1,17 @@ +FROM recommendations_base + +RUN apt-get update \ + && apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl \ + && apt-get clean + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY auth auth +COPY crons crons +COPY core core +COPY entrypoint.sh . +COPY api.py . 
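+# entrypoint.sh builds MLFLOW_TRACKING_URI from the pg_*_ml env values and starts uvicorn on port 7001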
+ +EXPOSE 7001 +ENTRYPOINT ./entrypoint.sh diff --git a/ee/recommendation/ml_service/api.py b/ee/recommendation/ml_service/api.py new file mode 100644 index 000000000..f247f86f0 --- /dev/null +++ b/ee/recommendation/ml_service/api.py @@ -0,0 +1,83 @@ +from fastapi import FastAPI, Depends +from contextlib import asynccontextmanager +from utils import pg_client +from core.model_handler import recommendation_model +from utils.declarations import FeedbackRecommendation +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from crons.base_crons import cron_jobs +from auth.auth_key import api_key_auth +from core import feedback +from fastapi.middleware.cors import CORSMiddleware + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await pg_client.init() + await feedback.init() + await recommendation_model.update() + app.schedule.start() + for job in cron_jobs: + app.schedule.add_job(id=job['func'].__name__, **job) + yield + app.schedule.shutdown(wait=False) + await feedback.terminate() + await pg_client.terminate() + +app = FastAPI(lifespan=lifespan) +app.schedule = AsyncIOScheduler() + +origins = [ + "*" +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# @app.on_event('startup') +# async def startup(): +# await pg_client.init() +# await feedback.init() +# await recommendation_model.update() +# app.schedule.start() +# for job in cron_jobs: +# app.schedule.add_job(id=job['func'].__name__, **job) +# +# +# @app.on_event('shutdown') +# async def shutdown(): +# app.schedule.shutdown(wait=False) +# await feedback.terminate() +# await pg_client.terminate() + + +@app.get('/recommendations/{user_id}/{project_id}', dependencies=[Depends(api_key_auth)]) +async def get_recommended_sessions(user_id: int, project_id: int): + recommendations = recommendation_model.get_recommendations(user_id, project_id) + return {'userId': user_id, + 'projectId': project_id, + 'recommendations': recommendations + } + + +@app.get('/recommendations/{projectId}/{viewerId}/{sessionId}', dependencies=[Depends(api_key_auth)]) +async def already_gave_feedback(projectId: int, viewerId: int, sessionId: int): + return feedback.has_feedback((viewerId, sessionId, projectId)) + + +@app.post('/recommendations/feedback', dependencies=[Depends(api_key_auth)]) +async def get_feedback(data: FeedbackRecommendation): + try: + feedback.global_queue.put(tuple(data.dict().values())) + except Exception as e: + return {'error': e} + return {'success': 1} + + +@app.get('/') +async def health(): + return {'status': 200} diff --git a/ee/recommendation/ml_service/auth/auth_key.py b/ee/recommendation/ml_service/auth/auth_key.py new file mode 100644 index 000000000..dd9438ac4 --- /dev/null +++ b/ee/recommendation/ml_service/auth/auth_key.py @@ -0,0 +1,33 @@ +from fastapi.security import OAuth2PasswordBearer +from fastapi import HTTPException, Depends, status +from decouple import config + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +class AuthHandler: + def __init__(self): + """ + Authorization method using an API key. 
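+        The key list is seeded from the API_AUTH_KEY env value and checked by api_key_auth;
+        clients send the key as a bearer token, i.e. an "Authorization: Bearer <API_AUTH_KEY>" header.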
+ """ + self.__api_keys = [config("API_AUTH_KEY")] + + def __contains__(self, api_key): + return api_key in self.__api_keys + + def add_key(self, key): + """Adds new key for authentication.""" + self.__api_keys.append(key) + + +auth_method = AuthHandler() + + +def api_key_auth(api_key: str = Depends(oauth2_scheme)): + """Method to verify auth.""" + global auth_method + if api_key not in auth_method: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Forbidden" + ) \ No newline at end of file diff --git a/ee/recommendation/ml_service/build_dev.sh b/ee/recommendation/ml_service/build_dev.sh new file mode 100755 index 000000000..c11741677 --- /dev/null +++ b/ee/recommendation/ml_service/build_dev.sh @@ -0,0 +1,2 @@ +cp ../../api/chalicelib/utils/ch_client.py utils +cp ../../../api/chalicelib/utils/pg_client.py utils diff --git a/ee/recommendation/ml_service/core/feedback.py b/ee/recommendation/ml_service/core/feedback.py new file mode 100644 index 000000000..64b44548c --- /dev/null +++ b/ee/recommendation/ml_service/core/feedback.py @@ -0,0 +1,132 @@ +import json +import queue +import logging +from decouple import config +from time import time + +from mlflow.store.db.utils import create_sqlalchemy_engine +from sqlalchemy.orm import sessionmaker, session +from sqlalchemy import text +from contextlib import contextmanager + +global_queue = None + + +class ConnectionHandler: + _sessions = sessionmaker() + def __init__(self, uri): + """Connects into mlflow database.""" + self.engine = create_sqlalchemy_engine(uri) + + @contextmanager + def get_live_session(self) -> session: + """ + This is a session that can be committed. + Changes will be reflected in the database. + """ + # Automatic transaction and connection handling in session + connection = self.engine.connect() + my_session = type(self)._sessions(bind=connection) + + yield my_session + + my_session.close() + connection.close() + + +class EventQueue: + def __init__(self, queue_max_length=50): + """Saves all recommendations until queue_max_length (default 50) is reached + or max_retention_time surpassed (env value, default 1 hour).""" + self.events = queue.Queue() + self.events.maxsize = queue_max_length + host = config('pg_host_ml') + port = config('pg_port_ml') + user = config('pg_user_ml') + dbname = config('pg_dbname_ml') + password = config('pg_password_ml') + + tracking_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}" + self.connection_handler = ConnectionHandler(tracking_uri) + self.last_flush = time() + self.max_retention_time = config('max_retention_time', default=60*60) + self.feedback_short_mem = list() + + def flush(self, conn): + """Insert recommendations into table recommendation_feedback from mlflow database.""" + events = list() + params = dict() + i = 0 + insertion_time = time() + while not self.events.empty(): + user_id, session_id, project_id, payload = self.events.get() + params[f'user_id_{i}'] = user_id + params[f'session_id_{i}'] = session_id + params[f'project_id_{i}'] = project_id + params[f'payload_{i}'] = json.dumps(payload) + events.append( + f"(%(user_id_{i})s, %(session_id_{i})s, %(project_id_{i})s, %(payload_{i})s::jsonb, {insertion_time})") + i += 1 + self.last_flush = time() + self.feedback_short_mem = list() + if i == 0: + return 0 + cur = conn.connection().connection.cursor() + query = cur.mogrify(f"""INSERT INTO recommendation_feedback (user_id, session_id, project_id, payload, insertion_time) VALUES {' , '.join(events)};""", params) + 
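+        # the batched multi-row INSERT composed above with psycopg2's mogrify is executed
+        # through the SQLAlchemy session below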
conn.execute(text(query.decode("utf-8"))) + conn.commit() + return 1 + + def force_flush(self): + """Force method flush.""" + if not self.events.empty(): + try: + with self.connection_handler.get_live_session() as conn: + self.flush(conn) + except Exception as e: + print(f'Error: {e}') + + def put(self, element): + """Adds recommendation into the queue.""" + current_time = time() + if self.events.full() or current_time - self.last_flush > self.max_retention_time: + try: + with self.connection_handler.get_live_session() as conn: + self.flush(conn) + except Exception as e: + print(f'Error: {e}') + self.events.put(element) + self.feedback_short_mem.append(element[:3]) + self.events.task_done() + + def already_has_feedback(self, element): + """"This method verifies if a feedback is already send for the current user-project-sessionId.""" + if element[:3] in self.feedback_short_mem: + return True + else: + with self.connection_handler.get_live_session() as conn: + cur = conn.connection().connection.cursor() + query = cur.mogrify("SELECT * FROM recommendation_feedback WHERE user_id=%(user_id)s AND session_id=%(session_id)s AND project_id=%(project_id)s LIMIT 1", + {'user_id': element[0], 'session_id': element[1], 'project_id': element[2]}) + cur_result = conn.execute(text(query.decode('utf-8'))) + res = cur_result.fetchall() + return len(res) == 1 + + +def has_feedback(data): + global global_queue + assert global_queue is not None, 'Global queue is not yet initialized' + return global_queue.already_has_feedback(data) + + +async def init(): + global global_queue + global_queue = EventQueue() + print("> queue initialized") + + +async def terminate(): + global global_queue + if global_queue is not None: + global_queue.force_flush() + print('> queue fulshed') diff --git a/ee/recommendation/ml_service/core/model_handler.py b/ee/recommendation/ml_service/core/model_handler.py new file mode 100644 index 000000000..45c2470f7 --- /dev/null +++ b/ee/recommendation/ml_service/core/model_handler.py @@ -0,0 +1,164 @@ +import mlflow +import hashlib +import numpy as np +from decouple import config +from utils import pg_client +from utils.df_utils import _process_pg_response +from time import time + +host = config('pg_host_ml') +port = config('pg_port_ml') +user = config('pg_user_ml') +dbname = config('pg_dbname_ml') +password = config('pg_password_ml') + +tracking_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}" +mlflow.set_tracking_uri(tracking_uri) +batch_download_size = config('batch_download_size', default=10, cast=int) + + +def get_latest_uri(projectId, tenantId): + client = mlflow.MlflowClient() + hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest() + model_name = f'{hashed}-RecModel' + model_versions = client.search_model_versions(f"name='{model_name}'") + latest = -1 + for i in range(len(model_versions)): + _v = model_versions[i].version + if _v < latest: + continue + else: + latest = _v + return f"runs:/{model_name}/{latest}" + + +def get_tenant(projectId): + with pg_client.PostgresClient() as conn: + conn.execute( + conn.mogrify("SELECT tenant_id FROM projects WHERE project_id=%(projectId)s", {'projectId': projectId}) + ) + res = conn.fetchone() + return res['tenant_id'] + + +class ServedModel: + def __init__(self): + """Handler of mlflow model.""" + self.model = None + + def load_model(self, model_name, model_version=1): + """Load model from mlflow given the model version. 
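+        The model is fetched through MLflow's registry URI scheme 'models:/<name>/<version>'.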
+ Params: + model_name: model name in mlflow repository + model_version: version of model to be downloaded""" + self.model = mlflow.pyfunc.load_model(f'models:/{model_name}/{model_version}') + + def predict(self, X) -> np.ndarray: + """Make prediction for batch X.""" + assert self.model is not None, 'Model has to be loaded before predicting. See load_model.__doc__' + return self.model.predict(X) + + def _sort_by_recommendation(self, sessions, sessions_features) -> np.ndarray: + """Make prediction for sessions_features and sort them by relevance.""" + pred = self.predict(sessions_features) + threshold = config('threshold_prediction', default=0.6, cast=float) + over_threshold = pred > threshold + pred = pred[over_threshold] + if len(pred) == 0: + return np.array([]) + sorted_idx = np.argsort(pred)[::-1] + return sessions[over_threshold][sorted_idx] + + def get_recommendations(self, userId, projectId): + """Gets recommendations for userId for a given projectId. + Selects last unseen_selection_limit non seen sessions (env value, default 100) + and sort them by pertinence using ML model""" + limit = config('unseen_selection_limit', default=100, cast=int) + oldest_limit = 1000*(time() - config('unseen_max_days_ago_selection', default=30, cast=int)*60*60*24) + with pg_client.PostgresClient() as conn: + query = conn.mogrify( + """SELECT project_id, session_id, user_id, %(userId)s as viewer_id, events_count, errors_count, duration, user_country as country, issue_score, user_device_type as device_type + FROM sessions + WHERE project_id = %(projectId)s AND session_id NOT IN (SELECT session_id FROM user_viewed_sessions WHERE user_id = %(userId)s) AND duration IS NOT NULL AND start_ts > %(oldest_limit)s LIMIT %(limit)s""", + {'userId': userId, 'projectId': projectId, 'limit': limit, 'oldest_limit': oldest_limit} + ) + conn.execute(query) + res = conn.fetchall() + _X = list() + _Y = list() + X_project_ids = dict() + X_users_ids = dict() + X_sessions_ids = dict() + _process_pg_response(res, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=0) + + return self._sort_by_recommendation(np.array(list(X_sessions_ids.keys())), _X).tolist() + + +class Recommendations: + def __init__(self): + """Handler for multiple models. + Properties: + * names [dict]: names of current available models and its versions (model name as key). + * models [dict]: ServedModels objects (model name as key). + * to_download [list]: list of model name and version to be downloaded from mlflow server (in s3). 
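+        Registered model names follow the trainer's convention, sha256('<projectId>-<tenantId>') + '-RecModel',
+        which is how get_recommendations maps a project to its model.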
+ """ + self.names = dict() + self.models = dict() + self.to_download = list() + + async def update(self): + """Fill to_download list with new models or new version for saved models.""" + r_models = mlflow.search_registered_models() + new_names = {m.name: max(m.latest_versions).version for m in r_models} + for name, version in new_names.items(): + if (name, version) in self.names.items(): + continue + self.to_download.append((name, version)) + # self.download_model(name, version) + self.names = new_names + + async def download_next(self): + """Pop element from to_download, download and add it into models.""" + download_loop_number = 0 + if self.to_download: + while download_loop_number < batch_download_size: + try: + name, version = self.to_download.pop(0) + s_model = ServedModel() + s_model.load_model(name, version) + self.models[name] = s_model + download_loop_number += 1 + except IndexError: + break + except Exception as e: + print('[Error] Found exception') + print(repr(e)) + break + + def download_model(self, name, version): + model = ServedModel() + model.load_model(name, version) + self.models[name] = model + + def info(self): + """Show current loaded models.""" + print('Current models inside:') + for model_name, model in self.models.items(): + print('Name:', model_name) + print(model.model) + + def get_recommendations(self, userId, projectId, n_recommendations=5): + """Gets recommendation for userId given the projectId. + This method selects the corresponding model and gets recommended sessions ordered by relevance.""" + tenantId = get_tenant(projectId) + hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest() + model_name = f'{hashed}-RecModel' + n_recommendations = config('number_of_recommendations', default=5, cast=int) + try: + model = self.models[model_name] + except KeyError: + return [] + return model.get_recommendations(userId, projectId)[:n_recommendations] + + +recommendation_model = Recommendations() diff --git a/ee/recommendation/ml_service/crons/base_crons.py b/ee/recommendation/ml_service/crons/base_crons.py new file mode 100644 index 000000000..7ffc1c007 --- /dev/null +++ b/ee/recommendation/ml_service/crons/base_crons.py @@ -0,0 +1,18 @@ +from apscheduler.triggers.cron import CronTrigger +from apscheduler.triggers.interval import IntervalTrigger +from core.model_handler import recommendation_model + + +async def update_model(): + """Update list of models to download.""" + await recommendation_model.update() + + +async def download_model(): + """Download next model in list.""" + await recommendation_model.download_next() + +cron_jobs = [ + {"func": update_model, "trigger": CronTrigger(hour=0), "misfire_grace_time": 60, "max_instances": 1}, + {"func": download_model, "trigger": IntervalTrigger(seconds=10), "misfire_grace_time": 60, "max_instances": 1}, +] diff --git a/ee/recommendation/ml_service/entrypoint.sh b/ee/recommendation/ml_service/entrypoint.sh new file mode 100755 index 000000000..c2e623d4a --- /dev/null +++ b/ee/recommendation/ml_service/entrypoint.sh @@ -0,0 +1,2 @@ +export MLFLOW_TRACKING_URI=postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml} +uvicorn api:app --host 0.0.0.0 --port 7001 --proxy-headers diff --git a/ee/recommendation/ml_service/requirements.txt b/ee/recommendation/ml_service/requirements.txt new file mode 100644 index 000000000..0eb819bf7 --- /dev/null +++ b/ee/recommendation/ml_service/requirements.txt @@ -0,0 +1,4 @@ +fastapi==0.95.2 +apscheduler==3.10.1 
+uvicorn==0.22.0 +SQLAlchemy==2.0.15 diff --git a/ee/recommendation/ml_service/run.sh b/ee/recommendation/ml_service/run.sh new file mode 100755 index 000000000..fbf47ec36 --- /dev/null +++ b/ee/recommendation/ml_service/run.sh @@ -0,0 +1 @@ +docker run -e AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} -e AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} -e AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} -p 7001:7001 -t recommendation_service diff --git a/ee/recommendation/ml_service/sql/init.sql b/ee/recommendation/ml_service/sql/init.sql new file mode 100644 index 000000000..0a05f925d --- /dev/null +++ b/ee/recommendation/ml_service/sql/init.sql @@ -0,0 +1,20 @@ +DO +$do$ +BEGIN + IF EXISTS (SELECT FROM pg_database WHERE datname = 'mlruns') THEN + RAISE NOTICE 'Database already exists'; -- optional + ELSE + PERFORM dblink_exec('dbname=' || current_database() -- current db + , 'CREATE DATABASE mlruns'); + END IF; +END +$do$; + +CREATE TABLE IF NOT EXISTS mlruns.public.recommendation_feedback +( + user_id BIGINT, + session_id BIGINT, + project_id BIGINT, + payload jsonb, + insertion_time BIGINT +); diff --git a/ee/recommendation/ml_service/test.py b/ee/recommendation/ml_service/test.py new file mode 100644 index 000000000..baade146c --- /dev/null +++ b/ee/recommendation/ml_service/test.py @@ -0,0 +1,16 @@ +from core.model_handler import Recommendations +from utils import pg_client +import asyncio + + +async def main(): + await pg_client.init() + R = Recommendations() + R.to_download = [('****************************************************************-RecModel', 1)] + await R.download_next() + L = R.get_recommendations(000000000, 000000000) + print(L) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/ee/recommendation/ml_trainer/Dockerfile b/ee/recommendation/ml_trainer/Dockerfile new file mode 100644 index 000000000..2adfd1333 --- /dev/null +++ b/ee/recommendation/ml_trainer/Dockerfile @@ -0,0 +1,18 @@ +FROM recommendations_base + + +RUN apt-get update \ + && apt-get install -y gcc libc-dev g++ pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl git\ + && apt-get clean + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY airflow airflow +COPY entrypoint.sh . +COPY mlflow_server.sh . +COPY main.py . + +EXPOSE 8080 +EXPOSE 5000 +ENTRYPOINT ./entrypoint.sh diff --git a/ee/recommendation/ml_trainer/airflow/airflow.cfg b/ee/recommendation/ml_trainer/airflow/airflow.cfg new file mode 100644 index 000000000..abdf26bab --- /dev/null +++ b/ee/recommendation/ml_trainer/airflow/airflow.cfg @@ -0,0 +1,1238 @@ +[core] +# The folder where your airflow pipelines live, most likely a +# subfolder in a code repository. This path must be absolute. +dags_folder = /app/airflow/dags + +# Hostname by providing a path to a callable, which will resolve the hostname. +# The format is "package.function". +# +# For example, default value "airflow.utils.net.getfqdn" means that result from patched +# version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254. +# +# No argument should be required in the function specified. +# If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address`` +hostname_callable = airflow.utils.net.getfqdn + +# Default timezone in case supplied date times are naive +# can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam) +default_timezone = utc + +# The executor class that airflow should use. 
Choices include +# ``SequentialExecutor``, ``LocalExecutor``, ``CeleryExecutor``, ``DaskExecutor``, +# ``KubernetesExecutor``, ``CeleryKubernetesExecutor`` or the +# full import path to the class when using a custom executor. +executor = SequentialExecutor + +# This defines the maximum number of task instances that can run concurrently per scheduler in +# Airflow, regardless of the worker count. Generally this value, multiplied by the number of +# schedulers in your cluster, is the maximum number of task instances with the running +# state in the metadata database. +parallelism = 32 + +# The maximum number of task instances allowed to run concurrently in each DAG. To calculate +# the number of tasks that is running concurrently for a DAG, add up the number of running +# tasks for all DAG runs of the DAG. This is configurable at the DAG level with ``max_active_tasks``, +# which is defaulted as ``max_active_tasks_per_dag``. +# +# An example scenario when this would be useful is when you want to stop a new dag with an early +# start date from stealing all the executor slots in a cluster. +max_active_tasks_per_dag = 16 + +# Are DAGs paused by default at creation +dags_are_paused_at_creation = False + +# The maximum number of active DAG runs per DAG. The scheduler will not create more DAG runs +# if it reaches the limit. This is configurable at the DAG level with ``max_active_runs``, +# which is defaulted as ``max_active_runs_per_dag``. +max_active_runs_per_dag = 16 + +# The name of the method used in order to start Python processes via the multiprocessing module. +# This corresponds directly with the options available in the Python docs: +# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method. +# Must be one of the values returned by: +# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods. +# Example: mp_start_method = fork +# mp_start_method = + +# Whether to load the DAG examples that ship with Airflow. It's good to +# get started, but you probably want to set this to ``False`` in a production +# environment +load_examples = False + +# Path to the folder containing Airflow plugins +plugins_folder = /app/airflow/plugins + +# Should tasks be executed via forking of the parent process ("False", +# the speedier option) or by spawning a new python process ("True" slow, +# but means plugin changes picked up by tasks straight away) +execute_tasks_new_python_interpreter = False + +# Secret key to save connection passwords in the db +fernet_key = + +# Whether to disable pickling dags +donot_pickle = True + +# How long before timing out a python file import +dagbag_import_timeout = 30.0 + +# Should a traceback be shown in the UI for dagbag import errors, +# instead of just the exception message +dagbag_import_error_tracebacks = True + +# If tracebacks are shown, how many entries from the traceback should be shown +dagbag_import_error_traceback_depth = 2 + +# How long before timing out a DagFileProcessor, which processes a dag file +dag_file_processor_timeout = 50 + +# The class to use for running task instances in a subprocess. +# Choices include StandardTaskRunner, CgroupTaskRunner or the full import path to the class +# when using a custom task runner. 
+task_runner = StandardTaskRunner + +# If set, tasks without a ``run_as_user`` argument will be run with this user +# Can be used to de-elevate a sudo user running Airflow when executing tasks +default_impersonation = + +# What security module to use (for example kerberos) +security = + +# Turn unit test mode on (overwrites many configuration options with test +# values at runtime) +unit_test_mode = False + +# Whether to enable pickling for xcom (note that this is insecure and allows for +# RCE exploits). +enable_xcom_pickling = False + +# What classes can be imported during deserialization. This is a multi line value. +# The individual items will be parsed as regexp. Python built-in classes (like dict) +# are always allowed +allowed_deserialization_classes = airflow\..* + +# When a task is killed forcefully, this is the amount of time in seconds that +# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED +killed_task_cleanup_time = 60 + +# Whether to override params with dag_run.conf. If you pass some key-value pairs +# through ``airflow dags backfill -c`` or +# ``airflow dags trigger -c``, the key-value pairs will override the existing ones in params. +dag_run_conf_overrides_params = True + +# When discovering DAGs, ignore any files that don't contain the strings ``DAG`` and ``airflow``. +dag_discovery_safe_mode = True + +# The pattern syntax used in the ".airflowignore" files in the DAG directories. Valid values are +# ``regexp`` or ``glob``. +dag_ignore_file_syntax = regexp + +# The number of retries each task is going to have by default. Can be overridden at dag or task level. +default_task_retries = 0 + +# The number of seconds each task is going to wait by default between retries. Can be overridden at +# dag or task level. +default_task_retry_delay = 300 + +# The weighting method used for the effective total priority weight of the task +default_task_weight_rule = downstream + +# The default task execution_timeout value for the operators. Expected an integer value to +# be passed into timedelta as seconds. If not specified, then the value is considered as None, +# meaning that the operators are never timed out by default. +default_task_execution_timeout = + +# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate. +min_serialized_dag_update_interval = 30 + +# If True, serialized DAGs are compressed before writing to DB. +# Note: this will disable the DAG dependencies view +compress_serialized_dags = False + +# Fetching serialized DAG can not be faster than a minimum interval to reduce database +# read rate. This config controls when your DAGs are updated in the Webserver +min_serialized_dag_fetch_interval = 10 + +# Maximum number of Rendered Task Instance Fields (Template Fields) per task to store +# in the Database. +# All the template_fields for each of Task Instance are stored in the Database. +# Keeping this number small may cause an error when you try to view ``Rendered`` tab in +# TaskInstance view for older tasks. +max_num_rendered_ti_fields_per_task = 30 + +# On each dagrun check against defined SLAs +check_slas = True + +# Path to custom XCom class that will be used to store and resolve operators results +# Example: xcom_backend = path.to.CustomXCom +xcom_backend = airflow.models.xcom.BaseXCom + +# By default Airflow plugins are lazily-loaded (only loaded when required). Set it to ``False``, +# if you want to load plugins whenever 'airflow' is invoked via cli or loaded from module. 
+lazy_load_plugins = True + +# By default Airflow providers are lazily-discovered (discovery and imports happen only when required). +# Set it to False, if you want to discover providers whenever 'airflow' is invoked via cli or +# loaded from module. +lazy_discover_providers = True + +# Hide sensitive Variables or Connection extra json keys from UI and task logs when set to True +# +# (Connection passwords are always hidden in logs) +hide_sensitive_var_conn_fields = True + +# A comma-separated list of extra sensitive keywords to look for in variables names or connection's +# extra JSON. +sensitive_var_conn_names = + +# Task Slot counts for ``default_pool``. This setting would not have any effect in an existing +# deployment where the ``default_pool`` is already created. For existing deployments, users can +# change the number of slots using Webserver, API or the CLI +default_pool_task_slot_count = 128 + +# The maximum list/dict length an XCom can push to trigger task mapping. If the pushed list/dict has a +# length exceeding this value, the task pushing the XCom will be failed automatically to prevent the +# mapped tasks from clogging the scheduler. +max_map_length = 1024 + +# The default umask to use for process when run in daemon mode (scheduler, worker, etc.) +# +# This controls the file-creation mode mask which determines the initial value of file permission bits +# for newly created files. +# +# This value is treated as an octal-integer. +daemon_umask = 0o077 + +# Class to use as dataset manager. +# Example: dataset_manager_class = airflow.datasets.manager.DatasetManager +# dataset_manager_class = + +# Kwargs to supply to dataset manager. +# Example: dataset_manager_kwargs = {"some_param": "some_value"} +# dataset_manager_kwargs = + +[database] +# The SqlAlchemy connection string to the metadata database. +# SqlAlchemy supports many different database engines. +# More information here: +# http://airflow.apache.org/docs/apache-airflow/stable/howto/set-up-database.html#database-uri +sql_alchemy_conn = postgresql+psycopg2://{{pg_user_airflow}}:{{pg_password_airflow}}@{{pg_host_airflow}}:{{pg_port_airflow}}/{{pg_dbname_airflow}} + +# Extra engine specific keyword args passed to SQLAlchemy's create_engine, as a JSON-encoded value +# Example: sql_alchemy_engine_args = {"arg1": True} +# sql_alchemy_engine_args = + +# The encoding for the databases +sql_engine_encoding = utf-8 + +# Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns +# in case they have different encoding. +# By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb`` +# the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed +# the maximum size of allowed index when collation is set to ``utf8mb4`` variant +# (see https://github.com/apache/airflow/pull/17603#issuecomment-901121618). +# sql_engine_collation_for_ids = + +# If SqlAlchemy should pool database connections. +sql_alchemy_pool_enabled = True + +# The SqlAlchemy pool size is the maximum number of database connections +# in the pool. 0 indicates no limit. +sql_alchemy_pool_size = 5 + +# The maximum overflow size of the pool. +# When the number of checked-out connections reaches the size set in pool_size, +# additional connections will be returned up to this limit. +# When those additional connections are returned to the pool, they are disconnected and discarded. 
+# It follows then that the total number of simultaneous connections the pool will allow +# is pool_size + max_overflow, +# and the total number of "sleeping" connections the pool will allow is pool_size. +# max_overflow can be set to ``-1`` to indicate no overflow limit; +# no limit will be placed on the total number of concurrent connections. Defaults to ``10``. +sql_alchemy_max_overflow = 10 + +# The SqlAlchemy pool recycle is the number of seconds a connection +# can be idle in the pool before it is invalidated. This config does +# not apply to sqlite. If the number of DB connections is ever exceeded, +# a lower config value will allow the system to recover faster. +sql_alchemy_pool_recycle = 1800 + +# Check connection at the start of each connection pool checkout. +# Typically, this is a simple statement like "SELECT 1". +# More information here: +# https://docs.sqlalchemy.org/en/14/core/pooling.html#disconnect-handling-pessimistic +sql_alchemy_pool_pre_ping = True + +# The schema to use for the metadata database. +# SqlAlchemy supports databases with the concept of multiple schemas. +sql_alchemy_schema = + +# Import path for connect args in SqlAlchemy. Defaults to an empty dict. +# This is useful when you want to configure db engine args that SqlAlchemy won't parse +# in connection string. +# See https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args +# sql_alchemy_connect_args = + +# Whether to load the default connections that ship with Airflow. It's good to +# get started, but you probably want to set this to ``False`` in a production +# environment +load_default_connections = True + +# Number of times the code should be retried in case of DB Operational Errors. +# Not all transactions will be retried as it can cause undesired state. +# Currently it is only used in ``DagFileProcessor.process_file`` to retry ``dagbag.sync_to_db``. +max_db_retries = 3 + +[logging] +# The folder where airflow should store its log files. +# This path must be absolute. +# There are a few existing configurations that assume this is set to the default. +# If you choose to override this you may need to update the dag_processor_manager_log_location and +# dag_processor_manager_log_location settings as well. +base_log_folder = /app/airflow/logs + +# Airflow can store logs remotely in AWS S3, Google Cloud Storage or Elastic Search. +# Set this to True if you want to enable remote logging. +remote_logging = False + +# Users must supply an Airflow connection id that provides access to the storage +# location. Depending on your remote logging service, this may only be used for +# reading logs, not writing them. +remote_log_conn_id = + +# Path to Google Credential JSON file. If omitted, authorization based on `the Application Default +# Credentials +# `__ will +# be used. +google_key_path = + +# Storage bucket URL for remote logging +# S3 buckets should start with "s3://" +# Cloudwatch log groups should start with "cloudwatch://" +# GCS buckets should start with "gs://" +# WASB buckets should start with "wasb" just to help Airflow select correct handler +# Stackdriver logs should start with "stackdriver://" +remote_base_log_folder = + +# Use server-side encryption for logs stored in S3 +encrypt_s3_logs = False + +# Logging level. +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +logging_level = INFO + +# Logging level for celery. 
If not set, it uses the value of logging_level +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +celery_logging_level = + +# Logging level for Flask-appbuilder UI. +# +# Supported values: ``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``. +fab_logging_level = WARNING + +# Logging class +# Specify the class that will specify the logging configuration +# This class has to be on the python classpath +# Example: logging_config_class = my.path.default_local_settings.LOGGING_CONFIG +logging_config_class = + +# Flag to enable/disable Colored logs in Console +# Colour the logs when the controlling terminal is a TTY. +colored_console_log = True + +# Log format for when Colored logs is enabled +colored_log_format = [%%(blue)s%%(asctime)s%%(reset)s] {%%(blue)s%%(filename)s:%%(reset)s%%(lineno)d} %%(log_color)s%%(levelname)s%%(reset)s - %%(log_color)s%%(message)s%%(reset)s +colored_formatter_class = airflow.utils.log.colored_log.CustomTTYColoredFormatter + +# Format of Log line +log_format = [%%(asctime)s] {%%(filename)s:%%(lineno)d} %%(levelname)s - %%(message)s +simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s + +# Where to send dag parser logs. If "file", logs are sent to log files defined by child_process_log_directory. +dag_processor_log_target = file + +# Format of Dag Processor Log line +dag_processor_log_format = [%%(asctime)s] [SOURCE:DAG_PROCESSOR] {%%(filename)s:%%(lineno)d} %%(levelname)s - %%(message)s +log_formatter_class = airflow.utils.log.timezone_aware.TimezoneAware + +# Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter +# Example: task_log_prefix_template = {ti.dag_id}-{ti.task_id}-{execution_date}-{try_number} +task_log_prefix_template = + +# Formatting for how airflow generates file names/paths for each task run. +log_filename_template = dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/{%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}attempt={{ try_number }}.log + +# Formatting for how airflow generates file names for log +log_processor_filename_template = {{ filename }}.log + +# Full path of dag_processor_manager logfile. +dag_processor_manager_log_location = /app/airflow/logs/dag_processor_manager/dag_processor_manager.log + +# Name of handler to read task instance logs. +# Defaults to use ``task`` handler. +task_log_reader = task + +# A comma\-separated list of third-party logger names that will be configured to print messages to +# consoles\. +# Example: extra_logger_names = connexion,sqlalchemy +extra_logger_names = + +# When you start an airflow worker, airflow starts a tiny web server +# subprocess to serve the workers local log files to the airflow main +# web server, who then builds pages and sends them to users. This defines +# the port on which the logs are served. It needs to be unused, and open +# visible from the main web server to connect into the workers. +worker_log_server_port = 8793 + +[metrics] + +# StatsD (https://github.com/etsy/statsd) integration settings. +# Enables sending metrics to StatsD. 
+statsd_on = False +statsd_host = localhost +statsd_port = 8125 +statsd_prefix = airflow + +# If you want to avoid sending all the available metrics to StatsD, +# you can configure an allow list of prefixes (comma separated) to send only the metrics that +# start with the elements of the list (e.g: "scheduler,executor,dagrun") +statsd_allow_list = + +# A function that validate the StatsD stat name, apply changes to the stat name if necessary and return +# the transformed stat name. +# +# The function should have the following signature: +# def func_name(stat_name: str) -> str: +stat_name_handler = + +# To enable datadog integration to send airflow metrics. +statsd_datadog_enabled = False + +# List of datadog tags attached to all metrics(e.g: key1:value1,key2:value2) +statsd_datadog_tags = + +# If you want to utilise your own custom StatsD client set the relevant +# module path below. +# Note: The module path must exist on your PYTHONPATH for Airflow to pick it up +# statsd_custom_client_path = + +[secrets] +# Full class name of secrets backend to enable (will precede env vars and metastore in search path) +# Example: backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend +backend = + +# The backend_kwargs param is loaded into a dictionary and passed to __init__ of secrets backend class. +# See documentation for the secrets backend you are using. JSON is expected. +# Example for AWS Systems Manager ParameterStore: +# ``{"connections_prefix": "/airflow/connections", "profile_name": "default"}`` +backend_kwargs = + +[cli] +# In what way should the cli access the API. The LocalClient will use the +# database directly, while the json_client will use the api running on the +# webserver +api_client = airflow.api.client.local_client + +# If you set web_server_url_prefix, do NOT forget to append it here, ex: +# ``endpoint_url = http://localhost:8080/myroot`` +# So api will look like: ``http://localhost:8080/myroot/api/experimental/...`` +endpoint_url = http://localhost:8080 + +[debug] +# Used only with ``DebugExecutor``. If set to ``True`` DAG will fail with first +# failed task. Helpful for debugging purposes. +fail_fast = False + +[api] +# Enables the deprecated experimental API. Please note that these APIs do not have access control. +# The authenticated user has full access. +# +# .. warning:: +# +# This `Experimental REST API `__ is +# deprecated since version 2.0. Please consider using +# `the Stable REST API `__. +# For more information on migration, see +# `RELEASE_NOTES.rst `_ +enable_experimental_api = False + +# Comma separated list of auth backends to authenticate users of the API. See +# https://airflow.apache.org/docs/apache-airflow/stable/security/api.html for possible values. +# ("airflow.api.auth.backend.default" allows all requests for historic reasons) +auth_backends = airflow.api.auth.backend.session + +# Used to set the maximum page limit for API requests +maximum_page_limit = 100 + +# Used to set the default page limit when limit is zero. A default limit +# of 100 is set on OpenApi spec. However, this particular default limit +# only work when limit is set equal to zero(0) from API requests. +# If no limit is supplied, the OpenApi spec default is used. +fallback_page_limit = 100 + +# The intended audience for JWT token credentials used for authorization. This value must match on the client and server sides. If empty, audience will not be tested. 
+# Example: google_oauth2_audience = project-id-random-value.apps.googleusercontent.com +google_oauth2_audience = + +# Path to Google Cloud Service Account key file (JSON). If omitted, authorization based on +# `the Application Default Credentials +# `__ will +# be used. +# Example: google_key_path = /files/service-account-json +google_key_path = + +# Used in response to a preflight request to indicate which HTTP +# headers can be used when making the actual request. This header is +# the server side response to the browser's +# Access-Control-Request-Headers header. +access_control_allow_headers = + +# Specifies the method or methods allowed when accessing the resource. +access_control_allow_methods = + +# Indicates whether the response can be shared with requesting code from the given origins. +# Separate URLs with space. +access_control_allow_origins = + +[lineage] +# what lineage backend to use +backend = + +[atlas] +sasl_enabled = False +host = +port = 21000 +username = +password = + +[operators] +# The default owner assigned to each new operator, unless +# provided explicitly or passed via ``default_args`` +default_owner = airflow +default_cpus = 1 +default_ram = 512 +default_disk = 512 +default_gpus = 0 + +# Default queue that tasks get assigned to and that worker listen on. +default_queue = default + +# Is allowed to pass additional/unused arguments (args, kwargs) to the BaseOperator operator. +# If set to False, an exception will be thrown, otherwise only the console message will be displayed. +allow_illegal_arguments = False + +[hive] +# Default mapreduce queue for HiveOperator tasks +default_hive_mapred_queue = + +# Template for mapred_job_name in HiveOperator, supports the following named parameters +# hostname, dag_id, task_id, execution_date +# mapred_job_name_template = + +[webserver] +# The base url of your website as airflow cannot guess what domain or +# cname you are using. This is used in automated emails that +# airflow sends to point links to the right web server +base_url = http://localhost:8080 + +# Default timezone to display all dates in the UI, can be UTC, system, or +# any IANA timezone string (e.g. Europe/Amsterdam). If left empty the +# default value of core/default_timezone will be used +# Example: default_ui_timezone = America/New_York +default_ui_timezone = UTC + +# The ip specified when starting the web server +web_server_host = 0.0.0.0 + +# The port on which to run the web server +web_server_port = 8080 + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. +web_server_ssl_cert = + +# Paths to the SSL certificate and key for the web server. When both are +# provided SSL will be enabled. This does not change the web server port. +web_server_ssl_key = + +# The type of backend used to store web session data, can be 'database' or 'securecookie' +# Example: session_backend = securecookie +session_backend = database + +# Number of seconds the webserver waits before killing gunicorn master that doesn't respond +web_server_master_timeout = 120 + +# Number of seconds the gunicorn webserver waits before timing out on a worker +web_server_worker_timeout = 120 + +# Number of workers to refresh at a time. When set to 0, worker refresh is +# disabled. When nonzero, airflow periodically refreshes webserver workers by +# bringing up new ones and killing old ones. +worker_refresh_batch_size = 1 + +# Number of seconds to wait before refreshing a batch of workers. 
+worker_refresh_interval = 6000 + +# If set to True, Airflow will track files in plugins_folder directory. When it detects changes, +# then reload the gunicorn. +reload_on_plugin_change = False + +# Secret key used to run your flask app. It should be as random as possible. However, when running +# more than 1 instances of webserver, make sure all of them use the same ``secret_key`` otherwise +# one of them will error with "CSRF session token is missing". +# The webserver key is also used to authorize requests to Celery workers when logs are retrieved. +# The token generated using the secret key has a short expiry time though - make sure that time on +# ALL the machines that you run airflow components on is synchronized (for example using ntpd) +# otherwise you might get "forbidden" errors when the logs are accessed. +secret_key = {{airflow_secret_key}} + +# Number of workers to run the Gunicorn web server +workers = 4 + +# The worker class gunicorn should use. Choices include +# sync (default), eventlet, gevent. Note when using gevent you might also want to set the +# "_AIRFLOW_PATCH_GEVENT" environment variable to "1" to make sure gevent patching is done as +# early as possible. +worker_class = sync + +# Log files for the gunicorn webserver. '-' means log to stderr. +access_logfile = - + +# Log files for the gunicorn webserver. '-' means log to stderr. +error_logfile = - + +# Access log format for gunicorn webserver. +# default format is %%(h)s %%(l)s %%(u)s %%(t)s "%%(r)s" %%(s)s %%(b)s "%%(f)s" "%%(a)s" +# documentation - https://docs.gunicorn.org/en/stable/settings.html#access-log-format +access_logformat = + +# Expose the configuration file in the web server. Set to "non-sensitive-only" to show all values +# except those that have security implications. "True" shows all values. "False" hides the +# configuration completely. +expose_config = False + +# Expose hostname in the web server +expose_hostname = True + +# Expose stacktrace in the web server +expose_stacktrace = False + +# Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times`` +dag_default_view = grid + +# Default DAG orientation. Valid values are: +# ``LR`` (Left->Right), ``TB`` (Top->Bottom), ``RL`` (Right->Left), ``BT`` (Bottom->Top) +dag_orientation = LR + +# The amount of time (in secs) webserver will wait for initial handshake +# while fetching logs from other worker machine +log_fetch_timeout_sec = 5 + +# Time interval (in secs) to wait before next log fetching. +log_fetch_delay_sec = 2 + +# Distance away from page bottom to enable auto tailing. +log_auto_tailing_offset = 30 + +# Animation speed for auto tailing log display. +log_animation_speed = 1000 + +# By default, the webserver shows paused DAGs. Flip this to hide paused +# DAGs by default +hide_paused_dags_by_default = False + +# Consistent page size across all listing views in the UI +page_size = 100 + +# Define the color of navigation bar +navbar_color = #fff + +# Default dagrun to show in UI +default_dag_run_display_number = 25 + +# Enable werkzeug ``ProxyFix`` middleware for reverse proxy +enable_proxy_fix = False + +# Number of values to trust for ``X-Forwarded-For``. 
+# More info: https://werkzeug.palletsprojects.com/en/0.16.x/middleware/proxy_fix/ +proxy_fix_x_for = 1 + +# Number of values to trust for ``X-Forwarded-Proto`` +proxy_fix_x_proto = 1 + +# Number of values to trust for ``X-Forwarded-Host`` +proxy_fix_x_host = 1 + +# Number of values to trust for ``X-Forwarded-Port`` +proxy_fix_x_port = 1 + +# Number of values to trust for ``X-Forwarded-Prefix`` +proxy_fix_x_prefix = 1 + +# Set secure flag on session cookie +cookie_secure = False + +# Set samesite policy on session cookie +cookie_samesite = Lax + +# Default setting for wrap toggle on DAG code and TI log views. +default_wrap = False + +# Allow the UI to be rendered in a frame +x_frame_enabled = True + +# Send anonymous user activity to your analytics tool +# choose from google_analytics, segment, or metarouter +# analytics_tool = + +# Unique ID of your account in the analytics tool +# analytics_id = + +# 'Recent Tasks' stats will show for old DagRuns if set +show_recent_stats_for_completed_runs = True + +# Update FAB permissions and sync security manager roles +# on webserver startup +update_fab_perms = True + +# The UI cookie lifetime in minutes. User will be logged out from UI after +# ``session_lifetime_minutes`` of non-activity +session_lifetime_minutes = 43200 + +# Sets a custom page title for the DAGs overview page and site title for all pages +# instance_name = + +# Whether the custom page title for the DAGs overview page contains any Markup language +instance_name_has_markup = False + +# How frequently, in seconds, the DAG data will auto-refresh in graph or grid view +# when auto-refresh is turned on +auto_refresh_interval = 3 + +# Boolean for displaying warning for publicly viewable deployment +warn_deployment_exposure = True + +# Comma separated string of view events to exclude from dag audit view. +# All other events will be added minus the ones passed here. +# The audit logs in the db will not be affected by this parameter. +audit_view_excluded_events = gantt,landing_times,tries,duration,calendar,graph,grid,tree,tree_data + +# Comma separated string of view events to include in dag audit view. +# If passed, only these events will populate the dag audit view. +# The audit logs in the db will not be affected by this parameter. +# Example: audit_view_included_events = dagrun_cleared,failed +# audit_view_included_events = + +[email] + +# Configuration email backend and whether to +# send email alerts on retry or failure +# Email backend to use +email_backend = airflow.utils.email.send_email_smtp + +# Email connection to use +email_conn_id = smtp_default + +# Whether email alerts should be sent when a task is retried +default_email_on_retry = True + +# Whether email alerts should be sent when a task failed +default_email_on_failure = True + +# File that will be used as the template for Email subject (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# Example: subject_template = /path/to/my_subject_template_file +# subject_template = + +# File that will be used as the template for Email content (which will be rendered using Jinja2). +# If not set, Airflow uses a base template. +# Example: html_content_template = /path/to/my_html_content_template_file +# html_content_template = + +# Email address that will be used as sender address. 
+# It can either be raw email or the complete address in a format ``Sender Name `` +# Example: from_email = Airflow +# from_email = + +[smtp] + +# If you want airflow to send emails on retries, failure, and you want to use +# the airflow.utils.email.send_email_smtp function, you have to configure an +# smtp server here +smtp_host = localhost +smtp_starttls = True +smtp_ssl = False +# Example: smtp_user = airflow +# smtp_user = +# Example: smtp_password = airflow +# smtp_password = +smtp_port = 25 +smtp_mail_from = airflow@example.com +smtp_timeout = 30 +smtp_retry_limit = 5 + +[sentry] + +# Sentry (https://docs.sentry.io) integration. Here you can supply +# additional configuration options based on the Python platform. See: +# https://docs.sentry.io/error-reporting/configuration/?platform=python. +# Unsupported options: ``integrations``, ``in_app_include``, ``in_app_exclude``, +# ``ignore_errors``, ``before_breadcrumb``, ``transport``. +# Enable error reporting to Sentry +sentry_on = false +sentry_dsn = + +# Dotted path to a before_send function that the sentry SDK should be configured to use. +# before_send = + +[local_kubernetes_executor] + +# This section only applies if you are using the ``LocalKubernetesExecutor`` in +# ``[core]`` section above +# Define when to send a task to ``KubernetesExecutor`` when using ``LocalKubernetesExecutor``. +# When the queue of a task is the value of ``kubernetes_queue`` (default ``kubernetes``), +# the task is executed via ``KubernetesExecutor``, +# otherwise via ``LocalExecutor`` +kubernetes_queue = kubernetes + +[celery_kubernetes_executor] + +# This section only applies if you are using the ``CeleryKubernetesExecutor`` in +# ``[core]`` section above +# Define when to send a task to ``KubernetesExecutor`` when using ``CeleryKubernetesExecutor``. +# When the queue of a task is the value of ``kubernetes_queue`` (default ``kubernetes``), +# the task is executed via ``KubernetesExecutor``, +# otherwise via ``CeleryExecutor`` +kubernetes_queue = kubernetes + +[celery] + +# This section only applies if you are using the CeleryExecutor in +# ``[core]`` section above +# The app name that will be used by celery +celery_app_name = airflow.executors.celery_executor + +# The concurrency that will be used when starting workers with the +# ``airflow celery worker`` command. This defines the number of task instances that +# a worker will take, so size up your workers based on the resources on +# your worker box and the nature of your tasks +worker_concurrency = 16 + +# The maximum and minimum concurrency that will be used when starting workers with the +# ``airflow celery worker`` command (always keep minimum processes, but grow +# to maximum if necessary). Note the value should be max_concurrency,min_concurrency +# Pick these numbers based on resources on worker box and the nature of the task. +# If autoscale option is available, worker_concurrency will be ignored. +# http://docs.celeryproject.org/en/latest/reference/celery.bin.worker.html#cmdoption-celery-worker-autoscale +# Example: worker_autoscale = 16,12 +# worker_autoscale = + +# Used to increase the number of tasks that a worker prefetches which can improve performance. +# The number of processes multiplied by worker_prefetch_multiplier is the number of tasks +# that are prefetched by a worker. 
A value greater than 1 can result in tasks being unnecessarily +# blocked if there are multiple workers and one worker prefetches tasks that sit behind long +# running tasks while another worker has unutilized processes that are unable to process the already +# claimed blocked tasks. +# https://docs.celeryproject.org/en/stable/userguide/optimizing.html#prefetch-limits +worker_prefetch_multiplier = 1 + +# Specify if remote control of the workers is enabled. +# When using Amazon SQS as the broker, Celery creates lots of ``.*reply-celery-pidbox`` queues. You can +# prevent this by setting this to false. However, with this disabled Flower won't work. +worker_enable_remote_control = true + +# The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally +# a sqlalchemy database. Refer to the Celery documentation for more information. +broker_url = redis://redis:6379/0 + +# The Celery result_backend. When a job finishes, it needs to update the +# metadata of the job. Therefore it will post a message on a message bus, +# or insert it into a database (depending of the backend) +# This status is used by the scheduler to update the state of the task +# The use of a database is highly recommended +# When not specified, sql_alchemy_conn with a db+ scheme prefix will be used +# http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings +# Example: result_backend = db+postgresql://postgres:airflow@postgres/airflow +# result_backend = + +# Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start +# it ``airflow celery flower``. This defines the IP that Celery Flower runs on +flower_host = 0.0.0.0 + +# The root URL for Flower +# Example: flower_url_prefix = /flower +flower_url_prefix = + +# This defines the port that Celery Flower runs on +flower_port = 5555 + +# Securing Flower with Basic Authentication +# Accepts user:password pairs separated by a comma +# Example: flower_basic_auth = user1:password1,user2:password2 +flower_basic_auth = + +# How many processes CeleryExecutor uses to sync task state. +# 0 means to use max(1, number of cores - 1) processes. +sync_parallelism = 0 + +# Import path for celery configuration options +celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG +ssl_active = False +ssl_key = +ssl_cert = +ssl_cacert = + +# Celery Pool implementation. +# Choices include: ``prefork`` (default), ``eventlet``, ``gevent`` or ``solo``. +# See: +# https://docs.celeryproject.org/en/latest/userguide/workers.html#concurrency +# https://docs.celeryproject.org/en/latest/userguide/concurrency/eventlet.html +pool = prefork + +# The number of seconds to wait before timing out ``send_task_to_executor`` or +# ``fetch_celery_task_state`` operations. +operation_timeout = 1.0 + +# Celery task will report its status as 'started' when the task is executed by a worker. +# This is used in Airflow to keep track of the running tasks and if a Scheduler is restarted +# or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob. +task_track_started = True + +# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled, +# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but +# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting +# also applies to adopted tasks. 
+task_adoption_timeout = 600 + +# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically +# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified. +# When set to 0, automatic clearing of stalled tasks is disabled. +stalled_task_timeout = 0 + +# The Maximum number of retries for publishing task messages to the broker when failing +# due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed. +task_publish_max_retries = 3 + +# Worker initialisation check to validate Metadata Database connection +worker_precheck = False + +[celery_broker_transport_options] + +# This section is for specifying options which can be passed to the +# underlying celery broker transport. See: +# http://docs.celeryproject.org/en/latest/userguide/configuration.html#std:setting-broker_transport_options +# The visibility timeout defines the number of seconds to wait for the worker +# to acknowledge the task before the message is redelivered to another worker. +# Make sure to increase the visibility timeout to match the time of the longest +# ETA you're planning to use. +# visibility_timeout is only supported for Redis and SQS celery brokers. +# See: +# http://docs.celeryproject.org/en/master/userguide/configuration.html#std:setting-broker_transport_options +# Example: visibility_timeout = 21600 +# visibility_timeout = + +[dask] + +# This section only applies if you are using the DaskExecutor in +# [core] section above +# The IP address and port of the Dask cluster's scheduler. +cluster_address = 127.0.0.1:8786 + +# TLS/ SSL settings to access a secured Dask scheduler. +tls_ca = +tls_cert = +tls_key = + +[scheduler] +# Task instances listen for external kill signal (when you clear tasks +# from the CLI or the UI), this defines the frequency at which they should +# listen (in seconds). +job_heartbeat_sec = 5 + +# The scheduler constantly tries to trigger new tasks (look at the +# scheduler section in the docs for more information). This defines +# how often the scheduler should run (in seconds). +scheduler_heartbeat_sec = 5 + +# The number of times to try to schedule each DAG file +# -1 indicates unlimited number +num_runs = -1 + +# Controls how long the scheduler will sleep between loops, but if there was nothing to do +# in the loop. i.e. if it scheduled something then it will start the next loop +# iteration straight away. +scheduler_idle_sleep_time = 1 + +# Number of seconds after which a DAG file is parsed. The DAG file is parsed every +# ``min_file_process_interval`` number of seconds. Updates to DAGs are reflected after +# this interval. Keeping this number low will increase CPU usage. +min_file_process_interval = 30 + +# How often (in seconds) to check for stale DAGs (DAGs which are no longer present in +# the expected files) which should be deactivated, as well as datasets that are no longer +# referenced and should be marked as orphaned. +parsing_cleanup_interval = 60 + +# How often (in seconds) to scan the DAGs directory for new files. Default to 5 minutes. +dag_dir_list_interval = 300 + +# How often should stats be printed to the logs. Setting to 0 will disable printing stats +print_stats_interval = 30 + +# How often (in seconds) should pool usage stats be sent to StatsD (if statsd_on is enabled) +pool_metrics_interval = 5.0 + +# If the last scheduler heartbeat happened more than scheduler_health_check_threshold +# ago (in seconds), scheduler is considered unhealthy. 
+# This is used by the health check in the "/health" endpoint +scheduler_health_check_threshold = 30 + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check if this is set to True +enable_health_check = False + +# When you start a scheduler, airflow starts a tiny web server +# subprocess to serve a health check on this port +scheduler_health_check_server_port = 8974 + +# How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs +orphaned_tasks_check_interval = 300.0 +child_process_log_directory = /app/airflow/logs/scheduler + +# Local task jobs periodically heartbeat to the DB. If the job has +# not heartbeat in this many seconds, the scheduler will mark the +# associated task instance as failed and will re-schedule the task. +scheduler_zombie_task_threshold = 300 + +# How often (in seconds) should the scheduler check for zombie tasks. +zombie_detection_interval = 10.0 + +# Turn off scheduler catchup by setting this to ``False``. +# Default behavior is unchanged and +# Command Line Backfills still work, but the scheduler +# will not do scheduler catchup if this is ``False``, +# however it can be set on a per DAG basis in the +# DAG definition (catchup) +catchup_by_default = True + +# Setting this to True will make first task instance of a task +# ignore depends_on_past setting. A task instance will be considered +# as the first task instance of a task when there is no task instance +# in the DB with an execution_date earlier than it., i.e. no manual marking +# success will be needed for a newly added task to be scheduled. +ignore_first_depends_on_past_by_default = True + +# This changes the batch size of queries in the scheduling main loop. +# If this is too high, SQL query performance may be impacted by +# complexity of query predicate, and/or excessive locking. +# Additionally, you may hit the maximum allowable query length for your db. +# Set this to 0 for no limit (not advised) +max_tis_per_query = 512 + +# Should the scheduler issue ``SELECT ... FOR UPDATE`` in relevant queries. +# If this is set to False then you should not run more than a single +# scheduler at once +use_row_level_locking = True + +# Max number of DAGs to create DagRuns for per scheduler loop. +max_dagruns_to_create_per_loop = 10 + +# How many DagRuns should a scheduler examine (and lock) when scheduling +# and queuing tasks. +max_dagruns_per_loop_to_schedule = 20 + +# Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the +# same DAG. Leaving this on will mean tasks in the same DAG execute quicker, but might starve out other +# dags in some circumstances +schedule_after_task_execution = True + +# The scheduler can run multiple processes in parallel to parse dags. +# This defines how many processes will run. +parsing_processes = 2 + +# One of ``modified_time``, ``random_seeded_by_host`` and ``alphabetical``. +# The scheduler will list and sort the dag files to decide the parsing order. +# +# * ``modified_time``: Sort by modified time of the files. This is useful on large scale to parse the +# recently modified DAGs first. +# * ``random_seeded_by_host``: Sort randomly across multiple Schedulers but with same order on the +# same host. This is useful when running with Scheduler in HA mode where each scheduler can +# parse different DAG files. 
+# * ``alphabetical``: Sort by filename +file_parsing_sort_mode = modified_time + +# Whether the dag processor is running as a standalone process or it is a subprocess of a scheduler +# job. +standalone_dag_processor = False + +# Only applicable if `[scheduler]standalone_dag_processor` is true and callbacks are stored +# in database. Contains maximum number of callbacks that are fetched during a single loop. +max_callbacks_per_loop = 20 + +# Only applicable if `[scheduler]standalone_dag_processor` is true. +# Time in seconds after which dags, which were not updated by Dag Processor are deactivated. +dag_stale_not_seen_duration = 600 + +# Turn off scheduler use of cron intervals by setting this to False. +# DAGs submitted manually in the web UI or with trigger_dag will still run. +use_job_schedule = True + +# Allow externally triggered DagRuns for Execution Dates in the future +# Only has effect if schedule_interval is set to None in DAG +allow_trigger_in_future = False + +# How often to check for expired trigger requests that have not run yet. +trigger_timeout_check_interval = 15 + +[triggerer] +# How many triggers a single Triggerer will run at once, by default. +default_capacity = 1000 + +[kerberos] +ccache = /tmp/airflow_krb5_ccache + +# gets augmented with fqdn +principal = airflow +reinit_frequency = 3600 +kinit_path = kinit +keytab = airflow.keytab + +# Allow to disable ticket forwardability. +forwardable = True + +# Allow to remove source IP from token, useful when using token behind NATted Docker host. +include_ip = True + +[elasticsearch] +# Elasticsearch host +host = + +# Format of the log_id, which is used to query for a given tasks logs +log_id_template = {dag_id}-{task_id}-{run_id}-{map_index}-{try_number} + +# Used to mark the end of a log stream for a task +end_of_log_mark = end_of_log + +# Qualified URL for an elasticsearch frontend (like Kibana) with a template argument for log_id +# Code will construct log_id using the log_id template from the argument above. +# NOTE: scheme will default to https if one is not provided +# Example: frontend = http://localhost:5601/app/kibana#/discover?_a=(columns:!(message),query:(language:kuery,query:'log_id: "{log_id}"'),sort:!(log.offset,asc)) +frontend = + +# Write the task logs to the stdout of the worker, rather than the default files +write_stdout = False + +# Instead of the default log formatter, write the log lines as JSON +json_format = False + +# Log fields to also attach to the json output, if enabled +json_fields = asctime, filename, lineno, levelname, message + +# The field where host name is stored (normally either `host` or `host.name`) +host_field = host + +# The field where offset is stored (normally either `offset` or `log.offset`) +offset_field = offset + +[elasticsearch_configs] +use_ssl = False +verify_certs = True + +[kubernetes_executor] +# Path to the YAML pod file that forms the basis for KubernetesExecutor workers. +pod_template_file = + +# The repository of the Kubernetes Image for the Worker to Run +worker_container_repository = + +# The tag of the Kubernetes Image for the Worker to Run +worker_container_tag = + +# The Kubernetes namespace where airflow workers should be created. Defaults to ``default`` +namespace = default + +# If True, all worker pods will be deleted upon termination +delete_worker_pods = True + +# If False (and delete_worker_pods is True), +# failed worker pods will not be deleted so users can investigate them. 
+# This only prevents removal of worker pods where the worker itself failed, +# not when the task it ran failed. +delete_worker_pods_on_failure = False + +# Number of Kubernetes Worker Pod creation calls per scheduler loop. +# Note that the current default of "1" will only launch a single pod +# per-heartbeat. It is HIGHLY recommended that users increase this +# number to match the tolerance of their kubernetes cluster for +# better performance. +worker_pods_creation_batch_size = 1 + +# Allows users to launch pods in multiple namespaces. +# Will require creating a cluster-role for the scheduler +multi_namespace_mode = False + +# Use the service account kubernetes gives to pods to connect to kubernetes cluster. +# It's intended for clients that expect to be running inside a pod running on kubernetes. +# It will raise an exception if called from a process not running in a kubernetes environment. +in_cluster = True + +# When running with in_cluster=False change the default cluster_context or config_file +# options to Kubernetes client. Leave blank these to use default behaviour like ``kubectl`` has. +# cluster_context = + +# Path to the kubernetes configfile to be used when ``in_cluster`` is set to False +# config_file = + +# Keyword parameters to pass while calling a kubernetes client core_v1_api methods +# from Kubernetes Executor provided as a single line formatted JSON dictionary string. +# List of supported params are similar for all core_v1_apis, hence a single config +# variable for all apis. See: +# https://raw.githubusercontent.com/kubernetes-client/python/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/api/core_v1_api.py +kube_client_request_args = + +# Optional keyword arguments to pass to the ``delete_namespaced_pod`` kubernetes client +# ``core_v1_api`` method when using the Kubernetes Executor. +# This should be an object and can contain any of the options listed in the ``v1DeleteOptions`` +# class defined here: +# https://github.com/kubernetes-client/python/blob/41f11a09995efcd0142e25946adc7591431bfb2f/kubernetes/client/models/v1_delete_options.py#L19 +# Example: delete_option_kwargs = {"grace_period_seconds": 10} +delete_option_kwargs = + +# Enables TCP keepalive mechanism. This prevents Kubernetes API requests to hang indefinitely +# when idle connection is time-outed on services like cloud load balancers or firewalls. +enable_tcp_keepalive = True + +# When the `enable_tcp_keepalive` option is enabled, TCP probes a connection that has +# been idle for `tcp_keep_idle` seconds. +tcp_keep_idle = 120 + +# When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond +# to a keepalive probe, TCP retransmits the probe after `tcp_keep_intvl` seconds. +tcp_keep_intvl = 30 + +# When the `enable_tcp_keepalive` option is enabled, if Kubernetes API does not respond +# to a keepalive probe, TCP retransmits the probe `tcp_keep_cnt number` of times before +# a connection is considered to be broken. +tcp_keep_cnt = 6 + +# Set this to false to skip verifying SSL certificate of Kubernetes python client. 
+verify_ssl = True + +# How long in seconds a worker can be in Pending before it is considered a failure +worker_pods_pending_timeout = 300 + +# How often in seconds to check if Pending workers have exceeded their timeouts +worker_pods_pending_timeout_check_interval = 120 + +# How often in seconds to check for task instances stuck in "queued" status without a pod +worker_pods_queued_check_interval = 60 + +# How many pending pods to check for timeout violations in each check interval. +# You may want this higher if you have a very large cluster and/or use ``multi_namespace_mode``. +worker_pods_pending_timeout_batch_size = 100 + +[sensors] +# Sensor default timeout, 7 days by default (7 * 24 * 60 * 60). +default_timeout = 604800 diff --git a/ee/recommendation/ml_trainer/airflow/dags/training_dag.py b/ee/recommendation/ml_trainer/airflow/dags/training_dag.py new file mode 100644 index 000000000..ea81f21cf --- /dev/null +++ b/ee/recommendation/ml_trainer/airflow/dags/training_dag.py @@ -0,0 +1,128 @@ +import asyncio +import hashlib +import mlflow +import os +import pendulum +import sys +from airflow import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator, ShortCircuitOperator +from datetime import timedelta +from decouple import config +from textwrap import dedent +_work_dir = os.getcwd() +sys.path.insert(1, _work_dir) +from utils import pg_client + + +client = mlflow.MlflowClient() +models = [model.name for model in client.search_registered_models()] + + +def split_training(ti): + global models + projects = ti.xcom_pull(key='project_data').split(' ') + tenants = ti.xcom_pull(key='tenant_data').split(' ') + new_projects = list() + old_projects = list() + new_tenants = list() + old_tenants = list() + for i in range(len(projects)): + hashed = hashlib.sha256(bytes(f'{projects[i]}-{tenants[i]}'.encode('utf-8'))).hexdigest() + _model_name = f'{hashed}-RecModel' + if _model_name in models: + old_projects.append(projects[i]) + old_tenants.append(tenants[i]) + else: + new_projects.append(projects[i]) + new_tenants.append(tenants[i]) + ti.xcom_push(key='new_project_data', value=' '.join(new_projects)) + ti.xcom_push(key='new_tenant_data', value=' '.join(new_tenants)) + ti.xcom_push(key='old_project_data', value=' '.join(old_projects)) + ti.xcom_push(key='old_tenant_data', value=' '.join(old_tenants)) + + +def continue_new(ti): + L = ti.xcom_pull(key='new_project_data') + return len(L) > 0 + + +def continue_old(ti): + L = ti.xcom_pull(key='old_project_data') + return len(L) > 0 + + +def select_from_db(ti): + os.environ['PG_POOL'] = 'true' + asyncio.run(pg_client.init()) + with pg_client.PostgresClient() as conn: + conn.execute("""SELECT tenant_id, project_id as project_id + FROM ((SELECT project_id + FROM frontend_signals + GROUP BY project_id + HAVING count(1) > 10) AS T1 + INNER JOIN projects AS T2 USING (project_id));""") + res = conn.fetchall() + projects = list() + tenants = list() + for e in res: + projects.append(str(e['project_id'])) + tenants.append(str(e['tenant_id'])) + asyncio.run(pg_client.terminate()) + ti.xcom_push(key='project_data', value=' '.join(projects)) + ti.xcom_push(key='tenant_data', value=' '.join(tenants)) + + +dag = DAG( + "first_test", + default_args={ + "retries": 1, + "retry_delay": timedelta(minutes=3), + }, + start_date=pendulum.datetime(2015, 12, 1, tz="UTC"), + description="My first test", + schedule=config('crons_train', default='@daily'), + catchup=False, +) + +# assigning the task for our dag to do +with dag: 
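+    # Task flow: Select_Valid_Projects queries projects with enough frontend signals,
+    # Split_Create_and_Retrain separates projects that already have a registered model
+    # from new ones, and the two ShortCircuitOperators gate the Create_Models /
+    # Retrain_Models BashOperators defined below.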
+ split = PythonOperator( + task_id='Split_Create_and_Retrain', + provide_context=True, + python_callable=split_training, + do_xcom_push=True + ) + + select_vp = PythonOperator( + task_id='Select_Valid_Projects', + provide_context=True, + python_callable=select_from_db, + do_xcom_push=True + ) + + dag_split1 = ShortCircuitOperator( + task_id='Create_Condition', + python_callable=continue_new, + ) + + dag_split2 = ShortCircuitOperator( + task_id='Retrain_Condition', + python_callable=continue_old, + ) + + new_models = BashOperator( + task_id='Create_Models', + bash_command=f"python {_work_dir}/main.py " + "--projects {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='new_project_data')}} " + + "--tenants {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='new_tenant_data')}}", + ) + + old_models = BashOperator( + task_id='Retrain_Models', + bash_command=f"python {_work_dir}/main.py " + "--projects {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='old_project_data')}} " + + "--tenants {{task_instance.xcom_pull(task_ids='Split_Create_and_Retrain', key='old_tenant_data')}}", + ) + + select_vp >> split >> [dag_split1, dag_split2] + dag_split1 >> new_models + dag_split2 >> old_models diff --git a/ee/recommendation/ml_trainer/build_dev.sh b/ee/recommendation/ml_trainer/build_dev.sh new file mode 100755 index 000000000..c11741677 --- /dev/null +++ b/ee/recommendation/ml_trainer/build_dev.sh @@ -0,0 +1,2 @@ +cp ../../api/chalicelib/utils/ch_client.py utils +cp ../../../api/chalicelib/utils/pg_client.py utils diff --git a/ee/recommendation/ml_trainer/entrypoint.sh b/ee/recommendation/ml_trainer/entrypoint.sh new file mode 100755 index 000000000..68e119897 --- /dev/null +++ b/ee/recommendation/ml_trainer/entrypoint.sh @@ -0,0 +1,20 @@ +# Values setup +find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_user_airflow}}/${pg_user_airflow}/g" {} \; +find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_password_airflow}}/${pg_password_airflow}/g" {} \; +find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_host_airflow}}/${pg_host_airflow}/g" {} \; +find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_port_airflow}}/${pg_port_airflow}/g" {} \; +find airflow/ -type f -name "*.cfg" -exec sed -i "s/{{pg_dbname_airflow}}/${pg_dbname_airflow}/g" {} \; +find airflow/ -type f -name "*.cfg" -exec sed -i "s#{{airflow_secret_key}}#${airflow_secret_key}#g" {} \; +export MLFLOW_TRACKING_URI=postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml} +git init airflow/dags +# Airflow setup +airflow db init +airflow users create \ + --username admin \ + --firstname admin \ + --lastname admin \ + --role Admin \ + --email admin@admin.admin \ + -p ${airflow_admin_password} +# Run services +airflow webserver --port 8080 & airflow scheduler & ./mlflow_server.sh diff --git a/ee/recommendation/ml_trainer/main.py b/ee/recommendation/ml_trainer/main.py new file mode 100644 index 000000000..53ae07ea7 --- /dev/null +++ b/ee/recommendation/ml_trainer/main.py @@ -0,0 +1,118 @@ +import mlflow +import hashlib +import argparse +import numpy as np +from decouple import config +from datetime import datetime +from core.user_features import get_training_database +from core.recommendation_model import SVM_recommendation, sort_database + +mlflow.set_tracking_uri(config('MLFLOW_TRACKING_URI')) + + +def handle_database(x_train, y_train): + """ + Verifies if database is well-balanced. If not and if possible fixes it. 
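+    A set is considered balanced when positives make up between 40% and 60% of the
+    samples; otherwise the majority class is truncated. Returns (None, None) when
+    there are fewer than 13 samples or fewer than 7 examples of the minority class.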
+    """
+    total = len(y_train)
+    if total < 13:
+        return None, None
+    train_balance = y_train.sum() / total
+    if train_balance < 0.4:
+        positives = y_train[y_train == 1]
+        n_positives = len(positives)
+        x_positive = x_train[y_train == 1]
+        if n_positives < 7:
+            return None, None
+        else:
+            n_negatives_expected = min(int(n_positives/0.4), total-y_train.sum())
+            negatives = y_train[y_train == 0][:n_negatives_expected]
+            x_negative = x_train[y_train == 0][:n_negatives_expected]
+            # keep labels aligned with the concatenated features: positives first, then negatives
+            return np.concatenate((x_positive, x_negative), axis=0), np.concatenate((positives, negatives), axis=0)
+    elif train_balance > 0.6:
+        negatives = y_train[y_train == 0]
+        n_negatives = len(negatives)
+        x_negative = x_train[y_train == 0]
+        if n_negatives < 7:
+            return None, None
+        else:
+            n_positives_expected = min(int(n_negatives / 0.4), y_train.sum())
+            positives = y_train[y_train == 1][:n_positives_expected]
+            x_positive = x_train[y_train == 1][:n_positives_expected]
+            # keep labels aligned with the concatenated features: positives first, then negatives
+            return np.concatenate((x_positive, x_negative), axis=0), np.concatenate((positives, negatives), axis=0)
+    else:
+        return x_train, y_train
+
+
+def main(experiment_name, projectId, tenantId):
+    """
+    Main training routine, using MLflow for tracking and S3 for model storage.
+    Params:
+        experiment_name: MLflow experiment name.
+        projectId: project id of the sessions.
+        tenantId: tenant of the project id (used mainly as salt for hashing).
+    """
+    hashed = hashlib.sha256(bytes(f'{projectId}-{tenantId}'.encode('utf-8'))).hexdigest()
+    x_, y_, d = get_training_database(projectId, max_timestamp=1680248412284, favorites=True)
+
+    x, y = handle_database(x_, y_)
+    if x is None:
+        print(f'[INFO] Project {projectId}: Not enough data to train model - {y_.sum()}/{len(y_)-y_.sum()}')
+        return
+    x, y = sort_database(x, y)
+
+    _experiment = mlflow.get_experiment_by_name(experiment_name)
+    if _experiment is None:
+        artifact_uri = config('MODELS_S3_BUCKET', default='./mlruns')
+        mlflow.create_experiment(experiment_name, artifact_uri)
+    mlflow.set_experiment(experiment_name)
+    with mlflow.start_run(run_name=f'{hashed}-{datetime.now().strftime("%Y-%m-%d_%H:%M")}'):
+        reg_model_name = f"{hashed}-RecModel"
+        best_meta = {'score': 0, 'model': None, 'name': 'NoName'}
+        for kernel in ['linear', 'poly', 'rbf', 'sigmoid']:
+            with mlflow.start_run(run_name=f'sub_run_with_{kernel}', nested=True):
+                print("--")
+                model = SVM_recommendation(kernel=kernel, test=True)
+                model.fit(x, y)
+                mlflow.sklearn.log_model(model, "sk_learn",
+                                         serialization_format="cloudpickle")
+                mlflow.log_param("kernel", kernel)
+                mlflow.log_metric("score", model.score)
+                for _name, displ in model.plots().items():
+                    # TODO: close figures after logging to avoid holding them in memory
+                    mlflow.log_figure(displ, f'{_name}.png')
+                if model.score > best_meta['score']:
+                    best_meta['score'] = model.score
+                    best_meta['model'] = model
+                    best_meta['name'] = kernel
+        mlflow.log_metric("score", best_meta['score'])
+        mlflow.log_param("name", best_meta['name'])
+        mlflow.sklearn.log_model(best_meta['model'], "sk_learn",
+                                 serialization_format="cloudpickle",
+                                 registered_model_name=reg_model_name,
+                                 )
+
+
+if __name__ == '__main__':
+    import asyncio
+    import os
+    os.environ['PG_POOL'] = 'true'
+    from utils import pg_client
+
+    asyncio.run(pg_client.init())
+    parser = argparse.ArgumentParser(
+        prog='Recommendation Trainer',
+        description='Trains models that predict which sessions are likely to be most interesting for users to replay',
+    )
+    parser.add_argument('--projects', type=int, nargs='+')
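+    # --tenants pairs positionally with --projects: tenants[i] is the tenant of projects[i],
+    # in the same order as produced by the training DAG's Split_Create_and_Retrain task.
+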
parser.add_argument('--tenants', type=int, nargs='+') + + args = parser.parse_args() + + projects = args.projects + tenants = args.tenants + + for i in range(len(projects)): + print(f'Processing project {projects[i]}...') + main(experiment_name='s3-recommendations', projectId=projects[i], tenantId=tenants[i]) + asyncio.run(pg_client.terminate()) diff --git a/ee/recommendation/ml_trainer/mlflow_server.sh b/ee/recommendation/ml_trainer/mlflow_server.sh new file mode 100755 index 000000000..63bbd72e2 --- /dev/null +++ b/ee/recommendation/ml_trainer/mlflow_server.sh @@ -0,0 +1 @@ +mlflow server --backend-store-uri postgresql+psycopg2://${pg_user_ml}:${pg_password_ml}@${pg_host_ml}:${pg_port_ml}/${pg_dbname_ml} --default-artifact-root ${MODELS_S3_BUCKET} --host 0.0.0.0 --port 5000 diff --git a/ee/recommendation/ml_trainer/requirements.txt b/ee/recommendation/ml_trainer/requirements.txt new file mode 100644 index 000000000..4b6ae2175 --- /dev/null +++ b/ee/recommendation/ml_trainer/requirements.txt @@ -0,0 +1,3 @@ +argcomplete==3.0.8 +apache-airflow==2.6.1 +airflow-code-editor==7.2.1 diff --git a/ee/recommendation/ml_trainer/sql/init.sql b/ee/recommendation/ml_trainer/sql/init.sql new file mode 100644 index 000000000..0a88c0ea2 --- /dev/null +++ b/ee/recommendation/ml_trainer/sql/init.sql @@ -0,0 +1,11 @@ +DO +$do$ +BEGIN + IF EXISTS (SELECT FROM pg_database WHERE datname = 'airflow') THEN + RAISE NOTICE 'Database already exists'; -- optional + ELSE + PERFORM dblink_exec('dbname=' || current_database() -- current db + , 'CREATE DATABASE airflow'); + END IF; +END +$do$; diff --git a/ee/recommendation/requirements.txt b/ee/recommendation/requirements.txt deleted file mode 100644 index 7f0d26c2e..000000000 --- a/ee/recommendation/requirements.txt +++ /dev/null @@ -1,22 +0,0 @@ -requests==2.28.1 -urllib3==1.26.12 -pyjwt==2.5.0 -psycopg2-binary==2.9.3 - -numpy -threadpoolctl==3.1.0 -joblib==1.2.0 -scipy -scikit-learn -mlflow - -airflow-code-editor - -pydantic[email]==1.10.2 - -clickhouse-driver==0.2.4 -python3-saml==1.14.0 -python-multipart==0.0.5 -python-decouple - -argcomplete diff --git a/ee/recommendation/requirements_base.txt b/ee/recommendation/requirements_base.txt new file mode 100644 index 000000000..48c72ee26 --- /dev/null +++ b/ee/recommendation/requirements_base.txt @@ -0,0 +1,19 @@ +requests==2.28.2 +urllib3==1.26.12 +pyjwt==2.6.0 +SQLAlchemy==2.0.10 +alembic==1.11.1 +psycopg2-binary==2.9.5 + +joblib==1.2.0 +scipy==1.10.1 +scikit-learn==1.2.2 +mlflow==2.3 + +clickhouse-driver==0.2.5 +python3-saml==1.14.0 +python-multipart==0.0.5 +python-decouple==3.8 +pydantic==1.10.8 + +boto3==1.26.100 diff --git a/ee/recommendation/run.sh b/ee/recommendation/run.sh deleted file mode 100644 index 0a703bca4..000000000 --- a/ee/recommendation/run.sh +++ /dev/null @@ -1,11 +0,0 @@ -echo 'Setting up required modules..' -mkdir scripts -mkdir plugins -mkdir logs -mkdir scripts/utils -cp ../../api/chalicelib/utils/pg_client.py scripts/utils -cp ../api/chalicelib/utils/ch_client.py scripts/utils -echo 'Building containers...' -docker-compose up airflow-init -echo 'Running containers...' 
-docker-compose up diff --git a/ee/recommendation/scripts/core/features.py b/ee/recommendation/scripts/core/features.py deleted file mode 100644 index c2e21535e..000000000 --- a/ee/recommendation/scripts/core/features.py +++ /dev/null @@ -1,161 +0,0 @@ -from utils.ch_client import ClickHouseClient -from utils.pg_client import PostgresClient - -def get_features_clickhouse(**kwargs): - """Gets features from ClickHouse database""" - if 'limit' in kwargs: - limit = kwargs['limit'] - else: - limit = 500 - query = f"""SELECT session_id, project_id, user_id, events_count, errors_count, duration, country, issue_score, device_type, rage, jsexception, badrequest FROM ( - SELECT session_id, project_id, user_id, events_count, errors_count, duration, toInt8(user_country) as country, issue_score, toInt8(user_device_type) as device_type FROM experimental.sessions WHERE user_id IS NOT NULL) as T1 -INNER JOIN (SELECT session_id, project_id, sum(issue_type = 'click_rage') as rage, sum(issue_type = 'js_exception') as jsexception, sum(issue_type = 'bad_request') as badrequest FROM experimental.events WHERE event_type = 'ISSUE' AND session_id > 0 GROUP BY session_id, project_id LIMIT {limit}) as T2 -ON T1.session_id = T2.session_id AND T1.project_id = T2.project_id;""" - with ClickHouseClient() as conn: - res = conn.execute(query) - return res - - -def get_features_postgres(**kwargs): - with PostgresClient() as conn: - funnels = query_funnels(conn, **kwargs) - metrics = query_metrics(conn, **kwargs) - filters = query_with_filters(conn, **kwargs) - #clean_filters(funnels) - #clean_filters(filters) - return clean_filters_split(funnels, isfunnel=True), metrics, clean_filters_split(filters) - - - -def query_funnels(conn, **kwargs): - """Gets Funnels (PG database)""" - # If public.funnel is empty - funnels_query = f"""SELECT project_id, user_id, filter FROM (SELECT project_id, user_id, metric_id FROM public.metrics WHERE metric_type='funnel' - ) as T1 LEFT JOIN (SELECT filter, metric_id FROM public.metric_series) as T2 ON T1.metric_id = T2.metric_id""" - # Else - # funnels_query = "SELECT project_id, user_id, filter FROM public.funnels" - - conn.execute(funnels_query) - res = conn.fetchall() - return res - - -def query_metrics(conn, **kwargs): - """Gets Metrics (PG_database)""" - metrics_query = """SELECT metric_type, metric_of, metric_value, metric_format FROM public.metrics""" - conn.execute(metrics_query) - res = conn.fetchall() - return res - - -def query_with_filters(conn, **kwargs): - """Gets Metrics with filters (PG database)""" - filters_query = """SELECT T1.metric_id as metric_id, project_id, name, metric_type, metric_of, filter FROM ( - SELECT metric_id, project_id, name, metric_type, metric_of FROM metrics) as T1 INNER JOIN - (SELECT metric_id, filter FROM metric_series WHERE filter != '{}') as T2 ON T1.metric_id = T2.metric_id""" - conn.execute(filters_query) - res = conn.fetchall() - return res - - -def transform_funnel(project_id, user_id, data): - res = list() - for k in range(len(data)): - _tmp = data[k] - if _tmp['project_id'] != project_id or _tmp['user_id'] != user_id: - continue - else: - _tmp = _tmp['filter']['events'] - res.append(_tmp) - return res - - -def transform_with_filter(data, *kwargs): - res = list() - for k in range(len(data)): - _tmp = data[k] - jump = False - for _key in kwargs.keys(): - if data[_key] != kwargs[_key]: - jump = True - break - if jump: - continue - _type = data['metric_type'] - if _type == 'funnel': - res.append(['funnel', _tmp['filter']['events']]) - elif _type == 
'timeseries': - res.append(['timeseries', _tmp['filter']['filters'], _tmp['filter']['events']]) - elif _type == 'table': - res.append(['table', _tmp['metric_of'], _tmp['filter']['events']]) - return res - - -def transform(element): - key_ = element.pop('user_id') - secondary_key_ = element.pop('session_id') - context_ = element.pop('project_id') - features_ = element - del element - return {(key_, context_): {secondary_key_: list(features_.values())}} - - -def get_by_project(data, project_id): - head_ = [list(d.keys())[0][1] for d in data] - index_ = [k for k in range(len(head_)) if head_[k] == project_id] - return [data[k] for k in index_] - - -def get_by_user(data, user_id): - head_ = [list(d.keys())[0][0] for d in data] - index_ = [k for k in range(len(head_)) if head_[k] == user_id] - return [data[k] for k in index_] - - -def clean_filters(data): - for j in range(len(data)): - _filter = data[j]['filter'] - _tmp = list() - for i in range(len(_filter['filters'])): - if 'value' in _filter['filters'][i].keys(): - _tmp.append({'type': _filter['filters'][i]['type'], - 'value': _filter['filters'][i]['value'], - 'operator': _filter['filters'][i]['operator']}) - data[j]['filter'] = _tmp - - -def clean_filters_split(data, isfunnel=False): - _data = list() - for j in range(len(data)): - _filter = data[j]['filter'] - _tmp = list() - for i in range(len(_filter['filters'])): - if 'value' in _filter['filters'][i].keys(): - _type = _filter['filters'][i]['type'] - _value = _filter['filters'][i]['value'] - if isinstance(_value, str): - _value = [_value] - _operator = _filter['filters'][i]['operator'] - if isfunnel: - _data.append({'project_id': data[j]['project_id'], 'user_id': data[j]['user_id'], - 'type': _type, - 'value': _value, - 'operator': _operator - }) - else: - _data.append({'metric_id': data[j]['metric_id'], 'project_id': data[j]['project_id'], - 'name': data[j]['name'], 'metric_type': data[j]['metric_type'], - 'metric_of': data[j]['metric_of'], - 'type': _type, - 'value': _value, - 'operator': _operator - }) - return _data - -def test(): - print('One test') - -if __name__ == '__main__': - print('Just a test') - #data = get_features_clickhouse() - #print('Data length:', len(data)) diff --git a/ee/recommendation/scripts/core/recommendation_model.py b/ee/recommendation/scripts/core/recommendation_model.py deleted file mode 100644 index 9dae948a7..000000000 --- a/ee/recommendation/scripts/core/recommendation_model.py +++ /dev/null @@ -1,15 +0,0 @@ -from sklearn.svm import SVC - -class SVM_recommendation(): - def __init__(**params): - f"""{SVC.__doc__}""" - self.svm = SVC(params) - - def fit(self, X1=None, X2=None): - assert X1 is not None or X2 is not None, 'X1 or X2 must be given' - self.svm.fit(X1) - self.svm.fit(X2) - - - def predict(self, X): - return self.svm.predict(X) diff --git a/ee/recommendation/scripts/model_registry.py b/ee/recommendation/scripts/model_registry.py deleted file mode 100644 index 80d6dbde6..000000000 --- a/ee/recommendation/scripts/model_registry.py +++ /dev/null @@ -1,60 +0,0 @@ -import mlflow -## -import numpy as np -import pickle - -from sklearn import datasets, linear_model -from sklearn.metrics import mean_squared_error, r2_score - -# source: https://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html - -# Load the diabetes dataset -diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True) - -# Use only one feature -diabetes_X = diabetes_X[:, np.newaxis, 2] - -# Split the data into training/testing sets -diabetes_X_train = diabetes_X[:-20] 
-diabetes_X_test = diabetes_X[-20:] - -# Split the targets into training/testing sets -diabetes_y_train = diabetes_y[:-20] -diabetes_y_test = diabetes_y[-20:] - - -def print_predictions(m, y_pred): - - # The coefficients - print('Coefficients: \n', m.coef_) - # The mean squared error - print('Mean squared error: %.2f' - % mean_squared_error(diabetes_y_test, y_pred)) - # The coefficient of determination: 1 is perfect prediction - print('Coefficient of determination: %.2f' - % r2_score(diabetes_y_test, y_pred)) - -# Create linear regression object -lr_model = linear_model.LinearRegression() - -# Train the model using the training sets -lr_model.fit(diabetes_X_train, diabetes_y_train) - -# Make predictions using the testing set -diabetes_y_pred = lr_model.predict(diabetes_X_test) -print_predictions(lr_model, diabetes_y_pred) - -# save the model in the native sklearn format -filename = 'lr_model.pkl' -pickle.dump(lr_model, open(filename, 'wb')) -## -# load the model into memory -loaded_model = pickle.load(open(filename, 'rb')) - -# log and register the model using MLflow scikit-learn API -mlflow.set_tracking_uri("postgresql+psycopg2://airflow:airflow@postgres/mlruns") -reg_model_name = "SklearnLinearRegression" -print("--") -mlflow.sklearn.log_model(loaded_model, "sk_learn", - serialization_format="cloudpickle", - registered_model_name=reg_model_name) diff --git a/ee/recommendation/scripts/processing.py b/ee/recommendation/scripts/processing.py deleted file mode 100644 index 8f3631655..000000000 --- a/ee/recommendation/scripts/processing.py +++ /dev/null @@ -1,42 +0,0 @@ -import time -import argparse -from core import features -from utils import pg_client -import multiprocessing as mp -from decouple import config -import asyncio -import pandas - - -def features_ch(q): - q.put(features.get_features_clickhouse()) - -def features_pg(q): - q.put(features.get_features_postgres()) - -def get_features(): - #mp.set_start_method('spawn') - #q = mp.Queue() - #p1 = mp.Process(target=features_ch, args=(q,)) - #p1.start() - pg_features = features.get_features_postgres() - ch_features = []#p1.join() - return [pg_features, ch_features] - - -parser = argparse.ArgumentParser(description='Gets and process data from Postgres and ClickHouse.') -parser.add_argument('--batch_size', type=int, required=True, help='--batch_size max size of columns per file to be saved in opt/airflow/cache') - -args = parser.parse_args() - -if __name__ == '__main__': - asyncio.run(pg_client.init()) - print(args) - t1 = time.time() - data = get_features() - #print(data) - cache_dir = config("data_dir", default=f"/opt/airflow/cache") - for d in data[0]: - pandas.DataFrame(d).to_csv(f'{cache_dir}/tmp-{hash(time.time())}', sep=',') - t2 = time.time() - print(f'DONE! 
information retrieved in {t2-t1: .2f} seconds') diff --git a/ee/recommendation/scripts/task.py b/ee/recommendation/scripts/task.py deleted file mode 100644 index b427fa1c5..000000000 --- a/ee/recommendation/scripts/task.py +++ /dev/null @@ -1,41 +0,0 @@ -import time -import argparse -from decouple import config -from core import recommendation_model - -import pandas -import json -import os - - -def transform_dict_string(s_dicts): - data = list() - for s_dict in s_dicts: - data.append(json.loads(s_dict.replace("'", '"').replace('None','null').replace('False','false'))) - return data - -def process_file(file_name): - return pandas.read_csv(file_name, sep=",") - - -def read_batches(): - base_dir = config('dir_path', default='/opt/airflow/cache') - files = os.listdir(base_dir) - for file in files: - yield process_file(f'{base_dir}/{file}') - - -parser = argparse.ArgumentParser(description='Handle machine learning inputs.') -parser.add_argument('--mode', choices=['train', 'test'], required=True, help='--mode sets the model in train or test mode') -parser.add_argument('--kernel', default='linear', help='--kernel set the kernel to be used for SVM') - -args = parser.parse_args() - -if __name__ == '__main__': - print(args) - t1 = time.time() - buff = read_batches() - for b in buff: - print(b.head()) - t2 = time.time() - print(f'DONE! information retrieved in {t2-t1: .2f} seconds') diff --git a/ee/recommendation/signals.sql b/ee/recommendation/signals.sql deleted file mode 100644 index 5500969ed..000000000 --- a/ee/recommendation/signals.sql +++ /dev/null @@ -1,11 +0,0 @@ -CREATE TABLE IF NOT EXISTS frontend_signals -( - project_id bigint NOT NULL, - user_id text NOT NULL, - timestamp bigint NOT NULL, - action text NOT NULL, - source text NOT NULL, - category text NOT NULL, - data json -); -CREATE INDEX IF NOT EXISTS frontend_signals_user_id_idx ON frontend_signals (user_id); diff --git a/ee/recommendation/scripts/utils/ch_client.py b/ee/recommendation/utils/ch_client.py similarity index 100% rename from ee/recommendation/scripts/utils/ch_client.py rename to ee/recommendation/utils/ch_client.py diff --git a/ee/recommendation/utils/declarations.py b/ee/recommendation/utils/declarations.py new file mode 100644 index 000000000..d09da7e8c --- /dev/null +++ b/ee/recommendation/utils/declarations.py @@ -0,0 +1,59 @@ +from pydantic import BaseModel, Field + + +class FeedbackRecommendation(BaseModel): + viewerId: int = Field(...) + sessionId: int = Field(...) + projectId: int = Field(...) 
+ payload: dict = Field(default=dict()) + + +class DeviceValue: + device_types = ['other', 'desktop', 'mobile'] + + def __init__(self, device_type): + if isinstance(device_type, str): + try: + self.id = self.device_types.index(device_type) + except ValueError: + self.id = 0 + self.name = device_type + elif isinstance(device_type, int): + self.id = device_type + self.name = self.device_types[device_type] + + def __repr__(self): + return str(self.id) + + def __str__(self): + return self.name + + def get_int_val(self): + return self.id + + def get_str_val(self): + return self.name + + +class CountryValue: + countries = ['UN', 'RW', 'SO', 'YE', 'IQ', 'SA', 'IR', 'CY', 'TZ', 'SY', 'AM', 'KE', 'CD', 'DJ', 'UG', 'CF', 'SC', 'JO', 'LB', 'KW', 'OM', 'QA', 'BH', 'AE', 'IL', 'TR', 'ET', 'ER', 'EG', 'SD', 'GR', 'BI', 'EE', 'LV', 'AZ', 'LT', 'SJ', 'GE', 'MD', 'BY', 'FI', 'AX', 'UA', 'MK', 'HU', 'BG', 'AL', 'PL', 'RO', 'XK', 'ZW', 'ZM', 'KM', 'MW', 'LS', 'BW', 'MU', 'SZ', 'RE', 'ZA', 'YT', 'MZ', 'MG', 'AF', 'PK', 'BD', 'TM', 'TJ', 'LK', 'BT', 'IN', 'MV', 'IO', 'NP', 'MM', 'UZ', 'KZ', 'KG', 'TF', 'HM', 'CC', 'PW', 'VN', 'TH', 'ID', 'LA', 'TW', 'PH', 'MY', 'CN', 'HK', 'BN', 'MO', 'KH', 'KR', 'JP', 'KP', 'SG', 'CK', 'TL', 'RU', 'MN', 'AU', 'CX', 'MH', 'FM', 'PG', 'SB', 'TV', 'NR', 'VU', 'NC', 'NF', 'NZ', 'FJ', 'LY', 'CM', 'SN', 'CG', 'PT', 'LR', 'CI', 'GH', 'GQ', 'NG', 'BF', 'TG', 'GW', 'MR', 'BJ', 'GA', 'SL', 'ST', 'GI', 'GM', 'GN', 'TD', 'NE', 'ML', 'EH', 'TN', 'ES', 'MA', 'MT', 'DZ', 'FO', 'DK', 'IS', 'GB', 'CH', 'SE', 'NL', 'AT', 'BE', 'DE', 'LU', 'IE', 'MC', 'FR', 'AD', 'LI', 'JE', 'IM', 'GG', 'SK', 'CZ', 'NO', 'VA', 'SM', 'IT', 'SI', 'ME', 'HR', 'BA', 'AO', 'NA', 'SH', 'BV', 'BB', 'CV', 'GY', 'GF', 'SR', 'PM', 'GL', 'PY', 'UY', 'BR', 'FK', 'GS', 'JM', 'DO', 'CU', 'MQ', 'BS', 'BM', 'AI', 'TT', 'KN', 'DM', 'AG', 'LC', 'TC', 'AW', 'VG', 'VC', 'MS', 'MF', 'BL', 'GP', 'GD', 'KY', 'BZ', 'SV', 'GT', 'HN', 'NI', 'CR', 'VE', 'EC', 'CO', 'PA', 'HT', 'AR', 'CL', 'BO', 'PE', 'MX', 'PF', 'PN', 'KI', 'TK', 'TO', 'WF', 'WS', 'NU', 'MP', 'GU', 'PR', 'VI', 'UM', 'AS', 'CA', 'US', 'PS', 'RS', 'AQ', 'SX', 'CW', 'BQ', 'SS', 'BU', 'VD', 'YD', 'DD'] + + def __init__(self, country): + if isinstance(country, str): + self.id = -128+self.countries.index(country) + self.name = country + elif isinstance(country, int): + self.id = country + self.name = self.countries[128+country] + + def __repr__(self): + return str(self.id) + + def __str__(self): + return self.name + + def get_int_val(self): + return self.id + + def get_str_val(self): + return self.name diff --git a/ee/recommendation/utils/df_utils.py b/ee/recommendation/utils/df_utils.py new file mode 100644 index 000000000..5eafe129b --- /dev/null +++ b/ee/recommendation/utils/df_utils.py @@ -0,0 +1,24 @@ +from utils.declarations import CountryValue, DeviceValue + + +def _add_to_dict(element, index, dictionary): + if element not in dictionary.keys(): + dictionary[element] = [index] + else: + dictionary[element].append(index) + + +def _process_pg_response(res, _X, _Y, X_project_ids, X_users_ids, X_sessions_ids, label=None): + for i in range(len(res)): + x = res[i] + _add_to_dict(x.pop('project_id'), i, X_project_ids) + _add_to_dict(x.pop('session_id'), i, X_sessions_ids) + _add_to_dict(x.pop('user_id'), i, X_users_ids) + if label is None: + _Y.append(int(x.pop('train_label'))) + else: + _Y.append(label) + + x['country'] = CountryValue(x['country']).get_int_val() + x['device_type'] = DeviceValue(x['device_type']).get_int_val() + _X.append(list(x.values())) diff --git 
a/ee/recommendation/scripts/utils/pg_client.py b/ee/recommendation/utils/pg_client.py similarity index 78% rename from ee/recommendation/scripts/utils/pg_client.py rename to ee/recommendation/utils/pg_client.py index 69a5b5a8b..1bfad6d36 100644 --- a/ee/recommendation/scripts/utils/pg_client.py +++ b/ee/recommendation/utils/pg_client.py @@ -10,12 +10,35 @@ from psycopg2 import pool logging.basicConfig(level=config("LOGLEVEL", default=logging.INFO)) logging.getLogger('apscheduler').setLevel(config("LOGLEVEL", default=logging.INFO)) -_PG_CONFIG = {"host": config("pg_host"), +conn_str = config('string_connection', default='') +if conn_str == '': + _PG_CONFIG = {"host": config("pg_host"), "database": config("pg_dbname"), "user": config("pg_user"), "password": config("pg_password"), "port": config("pg_port", cast=int), "application_name": config("APP_NAME", default="PY")} +else: + import urllib.parse + conn_str = urllib.parse.unquote(conn_str) + usr_info, host_info = conn_str.split('@') + i = usr_info.find('://') + pg_user, pg_password = usr_info[i+3:].split(':') + host_info, pg_dbname = host_info.split('/') + i = host_info.find(':') + if i == -1: + pg_host = host_info + pg_port = 5432 + else: + pg_host, pg_port = host_info.split(':') + pg_port = int(pg_port) + _PG_CONFIG = {"host": pg_host, + "database": pg_dbname, + "user": pg_user, + "password": pg_password, + "port": pg_port, + "application_name": config("APP_NAME", default="PY")} + PG_CONFIG = dict(_PG_CONFIG) if config("PG_TIMEOUT", cast=int, default=0) > 0: PG_CONFIG["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int) * 1000}" @@ -87,9 +110,10 @@ class PostgresClient: long_query = False unlimited_query = False - def __init__(self, long_query=False, unlimited_query=False): + def __init__(self, long_query=False, unlimited_query=False, use_pool=True): self.long_query = long_query self.unlimited_query = unlimited_query + self.use_pool = use_pool if unlimited_query: long_config = dict(_PG_CONFIG) long_config["application_name"] += "-UNLIMITED" @@ -100,7 +124,7 @@ class PostgresClient: long_config["options"] = f"-c statement_timeout=" \ f"{config('pg_long_timeout', cast=int, default=5 * 60) * 1000}" self.connection = psycopg2.connect(**long_config) - elif not config('PG_POOL', cast=bool, default=True): + elif not use_pool or not config('PG_POOL', cast=bool, default=True): single_config = dict(_PG_CONFIG) single_config["application_name"] += "-NOPOOL" single_config["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int, default=30) * 1000}" @@ -111,6 +135,8 @@ class PostgresClient: def __enter__(self): if self.cursor is None: self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) + self.cursor.cursor_execute = self.cursor.execute + self.cursor.execute = self.__execute self.cursor.recreate = self.recreate_cursor return self.cursor @@ -118,11 +144,12 @@ class PostgresClient: try: self.connection.commit() self.cursor.close() - if self.long_query or self.unlimited_query: + if not self.use_pool or self.long_query or self.unlimited_query: self.connection.close() except Exception as error: logging.error("Error while committing/closing PG-connection", error) if str(error) == "connection already closed" \ + and self.use_pool \ and not self.long_query \ and not self.unlimited_query \ and config('PG_POOL', cast=bool, default=True): @@ -132,10 +159,22 @@ class PostgresClient: raise error finally: if config('PG_POOL', cast=bool, default=True) \ + and self.use_pool \ and not self.long_query \ and not 
self.unlimited_query: postgreSQL_pool.putconn(self.connection) + def __execute(self, query, vars=None): + try: + result = self.cursor.cursor_execute(query=query, vars=vars) + except psycopg2.Error as error: + logging.error(f"!!! Error of type:{type(error)} while executing query:") + logging.error(query) + logging.info("starting rollback to allow future execution") + self.connection.rollback() + raise error + return result + def recreate_cursor(self, rollback=False): if rollback: try: diff --git a/third-party.md b/third-party.md index 2775702f6..5c373781f 100644 --- a/third-party.md +++ b/third-party.md @@ -48,6 +48,9 @@ Below is the list of dependencies used in OpenReplay software. Licenses may chan | pandas | BSD3 | Python | | numpy | BSD3 | Python | | scikit-learn | BSD3 | Python | +| apache-airflow | Apache2 | Python | +| airflow-code-editor | Apache2 | Python | +| mlflow | Apache2 | Python | | sqlalchemy | MIT | Python | | pandas-redshift | MIT | Python | | confluent-kafka | Apache2 | Python |
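
The new `utils/declarations.py` added above defines the feedback request schema plus small encoders that map categorical session attributes to the integer codes the models consume. A minimal usage sketch; the values below are illustrative, not taken from the diff:
```python
from utils.declarations import CountryValue, DeviceValue, FeedbackRecommendation

# Device types map to their index in DeviceValue.device_types;
# unknown strings fall back to id 0 ('other') while keeping the original name.
print(DeviceValue('mobile').get_int_val())   # 2
print(DeviceValue('tablet').get_int_val())   # 0

# Countries map to a signed 8-bit code (list index minus 128);
# an integer code round-trips back to the country string.
code = CountryValue('FR').get_int_val()
print(CountryValue(code).get_str_val())      # 'FR'

# Feedback record, presumably validated by the FastAPI feedback endpoint
# (field values here are made up).
feedback = FeedbackRecommendation(viewerId=42, sessionId=7015842391123,
                                  projectId=1, payload={'useful': False})
print(feedback)
```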
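
`utils/df_utils.py` turns rows returned by the feature queries into a feature matrix, a label vector, and per-project/user/session index maps. A sketch with hypothetical rows; the exact column set is an assumption, only the popped keys and the `country`/`device_type` encoding are fixed by the code:
```python
from utils.df_utils import _process_pg_response

# Hypothetical rows shaped like the session-feature query output;
# note that _process_pg_response mutates them in place (pops keys).
rows = [
    {'project_id': 1, 'session_id': 101, 'user_id': 'u1', 'train_label': 1,
     'events_count': 12, 'errors_count': 0, 'duration': 5400,
     'country': 'FR', 'issue_score': 3, 'device_type': 'desktop'},
    {'project_id': 1, 'session_id': 102, 'user_id': 'u2', 'train_label': 0,
     'events_count': 4, 'errors_count': 2, 'duration': 900,
     'country': 'US', 'issue_score': 0, 'device_type': 'mobile'},
]

X, Y = [], []
project_idx, user_idx, session_idx = {}, {}, {}
_process_pg_response(rows, X, Y, project_idx, user_idx, session_idx)

print(Y)            # [1, 0] -> 'train_label' popped from each row
print(project_idx)  # {1: [0, 1]} -> row indices grouped by project_id
print(X[0])         # [12, 0, 5400, <FR code>, 3, 1] -> country/device_type replaced by int codes
```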
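
The reworked `utils/pg_client.py` can now build its connection settings from a single `string_connection` variable of the form `scheme://user:password@host[:port]/dbname`, with the port defaulting to 5432. The snippet below re-implements just that parsing step to show the resulting config; the DSN is made up, and the simple `split('@')` / `split(':')` approach assumes neither character appears inside the password:
```python
import urllib.parse

def parse_string_connection(conn_str: str) -> dict:
    """Standalone re-implementation of the DSN parsing added to utils/pg_client.py (illustration only)."""
    conn_str = urllib.parse.unquote(conn_str)
    usr_info, host_info = conn_str.split('@')
    i = usr_info.find('://')
    pg_user, pg_password = usr_info[i + 3:].split(':')
    host_info, pg_dbname = host_info.split('/')
    if ':' in host_info:
        pg_host, port = host_info.split(':')
        pg_port = int(port)
    else:
        pg_host, pg_port = host_info, 5432
    return {"host": pg_host, "database": pg_dbname, "user": pg_user,
            "password": pg_password, "port": pg_port}

print(parse_string_connection('postgresql://ml_user:s3cret@db.internal:5433/mlruns'))
# {'host': 'db.internal', 'database': 'mlruns', 'user': 'ml_user', 'password': 's3cret', 'port': 5433}
```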
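
With the new `use_pool` flag and the wrapped `execute()`, a one-off script can open a dedicated connection that is closed on exit, and a failed statement rolls back the transaction so the connection stays usable. A usage sketch, assuming the `pg_*` (or `string_connection`) environment variables are set; the query itself is hypothetical:
```python
from utils.pg_client import PostgresClient  # reads pg_* / string_connection config at import time

# use_pool=False opens a dedicated connection and closes it on __exit__, bypassing the shared pool.
with PostgresClient(use_pool=False) as cur:
    cur.execute("SELECT metric_type, metric_of FROM public.metrics LIMIT 5;")  # hypothetical query
    rows = cur.fetchall()  # RealDictCursor -> list of dict-like rows
    print(len(rows))

# execute() inside the context manager is the wrapped version added in this change: on a psycopg2
# error it logs the failing query, rolls back, and re-raises, so later statements are not stuck
# in an aborted transaction.
```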