From d3a611a6766d9955a51165d12e96637bb4f2e7b9 Mon Sep 17 00:00:00 2001
From: MauricioGarciaS <47052044+MauricioGarciaS@users.noreply.github.com>
Date: Wed, 29 Nov 2023 16:41:33 +0100
Subject: [PATCH] Changed pg pool to normal pg connection

---
 ee/connectors/consumer_pool.py   |  10 +-
 ee/connectors/utils/pg_client.py | 209 ++++++-------------------------
 ee/connectors/utils/worker.py    |  32 ++---
 3 files changed, 57 insertions(+), 194 deletions(-)

diff --git a/ee/connectors/consumer_pool.py b/ee/connectors/consumer_pool.py
index 5a2dbf34e..f33dbbc37 100644
--- a/ee/connectors/consumer_pool.py
+++ b/ee/connectors/consumer_pool.py
@@ -1,7 +1,8 @@
 from decouple import config, Csv
-import asyncio
+import signal
+# import asyncio
 from db.api import DBConnection
-from utils import pg_client
+# from utils import pg_client
 from utils.worker import WorkerPool
 
 
@@ -10,7 +11,7 @@ def main():
     database_api = DBConnection(DATABASE)
     allowed_projects = config('PROJECT_IDS', default=None, cast=Csv(int))
 
-    w_pool = WorkerPool(n_workers=60,
+    w_pool = WorkerPool(n_workers=config('OR_EE_CONNECTOR_WORKER_COUNT', cast=int, default=60),
                         project_filter=allowed_projects)
     try:
         w_pool.load_checkpoint(database_api)
@@ -24,6 +25,7 @@ def main():
 
 
 if __name__ == '__main__':
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
     main()
+    raise Exception('Script terminated')
 
diff --git a/ee/connectors/utils/pg_client.py b/ee/connectors/utils/pg_client.py
index 1bfad6d36..55c5d2172 100644
--- a/ee/connectors/utils/pg_client.py
+++ b/ee/connectors/utils/pg_client.py
@@ -1,23 +1,26 @@
 import logging
 import time
-from threading import Semaphore
-
-import psycopg2
-import psycopg2.extras
+from contextlib import contextmanager
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker, Session
+from decouple import Choices
 from decouple import config
-from psycopg2 import pool
 
 logging.basicConfig(level=config("LOGLEVEL", default=logging.INFO))
 logging.getLogger('apscheduler').setLevel(config("LOGLEVEL", default=logging.INFO))
 
+sslmode = config('DB_SSLMODE',
+                 cast=Choices(['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
+                 default='allow'
+)
+
 conn_str = config('string_connection', default='')
 if conn_str == '':
-    _PG_CONFIG = {"host": config("pg_host"),
-                  "database": config("pg_dbname"),
-                  "user": config("pg_user"),
-                  "password": config("pg_password"),
-                  "port": config("pg_port", cast=int),
-                  "application_name": config("APP_NAME", default="PY")}
+    pg_host = config("pg_host")
+    pg_dbname = config("pg_dbname")
+    pg_user = config("pg_user")
+    pg_password = config("pg_password")
+    pg_port = config("pg_port", cast=int)
 else:
     import urllib.parse
     conn_str = urllib.parse.unquote(conn_str)
@@ -32,174 +35,28 @@ else:
     else:
         pg_host, pg_port = host_info.split(':')
         pg_port = int(pg_port)
-    _PG_CONFIG = {"host": pg_host,
-                  "database": pg_dbname,
-                  "user": pg_user,
-                  "password": pg_password,
-                  "port": pg_port,
-                  "application_name": config("APP_NAME", default="PY")}
-
-PG_CONFIG = dict(_PG_CONFIG)
-if config("PG_TIMEOUT", cast=int, default=0) > 0:
-    PG_CONFIG["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int) * 1000}"
-
-
-class ORThreadedConnectionPool(psycopg2.pool.ThreadedConnectionPool):
-    def __init__(self, minconn, maxconn, *args, **kwargs):
-        self._semaphore = Semaphore(maxconn)
-        super().__init__(minconn, maxconn, *args, **kwargs)
-
-    def getconn(self, *args, **kwargs):
-        self._semaphore.acquire()
-        try:
-            return super().getconn(*args, **kwargs)
-        except psycopg2.pool.PoolError as e:
-            if str(e) == "connection pool is closed":
-                make_pool()
-            raise e
-
-    def putconn(self, *args, **kwargs):
-        try:
-            super().putconn(*args, **kwargs)
-            self._semaphore.release()
-        except psycopg2.pool.PoolError as e:
-            if str(e) == "trying to put unkeyed connection":
-                print("!!! trying to put unkeyed connection")
-                print(f"env-PG_POOL:{config('PG_POOL', default=None)}")
-                return
-            raise e
-
-
-postgreSQL_pool: ORThreadedConnectionPool = None
-
-RETRY_MAX = config("PG_RETRY_MAX", cast=int, default=50)
-RETRY_INTERVAL = config("PG_RETRY_INTERVAL", cast=int, default=2)
-RETRY = 0
-
-
-def make_pool():
-    if not config('PG_POOL', cast=bool, default=True):
-        return
-    global postgreSQL_pool
-    global RETRY
-    if postgreSQL_pool is not None:
-        try:
-            postgreSQL_pool.closeall()
-        except (Exception, psycopg2.DatabaseError) as error:
-            logging.error("Error while closing all connexions to PostgreSQL", error)
-    try:
-        postgreSQL_pool = ORThreadedConnectionPool(config("PG_MINCONN", cast=int, default=20),
-                                                   config("PG_MAXCONN", cast=int, default=80),
-                                                   **PG_CONFIG)
-        if (postgreSQL_pool):
-            logging.info("Connection pool created successfully")
-    except (Exception, psycopg2.DatabaseError) as error:
-        logging.error("Error while connecting to PostgreSQL", error)
-        if RETRY < RETRY_MAX:
-            RETRY += 1
-            logging.info(f"waiting for {RETRY_INTERVAL}s before retry n°{RETRY}")
-            time.sleep(RETRY_INTERVAL)
-            make_pool()
-        else:
-            raise error
+conn_str = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
 
 
 class PostgresClient:
-    connection = None
-    cursor = None
-    long_query = False
-    unlimited_query = False
+    CONNECTION_STRING: str = conn_str
+    _sessions = sessionmaker()
 
-    def __init__(self, long_query=False, unlimited_query=False, use_pool=True):
-        self.long_query = long_query
-        self.unlimited_query = unlimited_query
-        self.use_pool = use_pool
-        if unlimited_query:
-            long_config = dict(_PG_CONFIG)
-            long_config["application_name"] += "-UNLIMITED"
-            self.connection = psycopg2.connect(**long_config)
-        elif long_query:
-            long_config = dict(_PG_CONFIG)
-            long_config["application_name"] += "-LONG"
-            long_config["options"] = f"-c statement_timeout=" \
-                                     f"{config('pg_long_timeout', cast=int, default=5 * 60) * 1000}"
-            self.connection = psycopg2.connect(**long_config)
-        elif not use_pool or not config('PG_POOL', cast=bool, default=True):
-            single_config = dict(_PG_CONFIG)
-            single_config["application_name"] += "-NOPOOL"
-            single_config["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int, default=30) * 1000}"
-            self.connection = psycopg2.connect(**single_config)
-        else:
-            self.connection = postgreSQL_pool.getconn()
+    def __init__(self):
+        self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
 
-    def __enter__(self):
-        if self.cursor is None:
-            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
-            self.cursor.cursor_execute = self.cursor.execute
-            self.cursor.execute = self.__execute
-            self.cursor.recreate = self.recreate_cursor
-        return self.cursor
+    @contextmanager
+    def get_live_session(self) -> Session:
+        """
+        Yield a live session bound to a dedicated connection.
+        Changes committed by the caller are persisted to the database.
+        """
+        connection = self.engine.connect()
+        my_session = type(self)._sessions(bind=connection)
+        try:
+            yield my_session
+        finally:
+            # Release the session and its connection even if the caller raises
+            my_session.close()
+            connection.close()
 
-    def __exit__(self, *args):
-        try:
-            self.connection.commit()
-            self.cursor.close()
-            if not self.use_pool or self.long_query or self.unlimited_query:
-                self.connection.close()
-        except Exception as error:
-            logging.error("Error while committing/closing PG-connection", error)
-            if str(error) == "connection already closed" \
-                    and self.use_pool \
-                    and not self.long_query \
-                    and not self.unlimited_query \
-                    and config('PG_POOL', cast=bool, default=True):
-                logging.info("Recreating the connexion pool")
-                make_pool()
-            else:
-                raise error
-        finally:
-            if config('PG_POOL', cast=bool, default=True) \
-                    and self.use_pool \
-                    and not self.long_query \
-                    and not self.unlimited_query:
-                postgreSQL_pool.putconn(self.connection)
-
-    def __execute(self, query, vars=None):
-        try:
-            result = self.cursor.cursor_execute(query=query, vars=vars)
-        except psycopg2.Error as error:
-            logging.error(f"!!! Error of type:{type(error)} while executing query:")
-            logging.error(query)
-            logging.info("starting rollback to allow future execution")
-            self.connection.rollback()
-            raise error
-        return result
-
-    def recreate_cursor(self, rollback=False):
-        if rollback:
-            try:
-                self.connection.rollback()
-            except Exception as error:
-                logging.error("Error while rollbacking connection for recreation", error)
-        try:
-            self.cursor.close()
-        except Exception as error:
-            logging.error("Error while closing cursor for recreation", error)
-        self.cursor = None
-        return self.__enter__()
-
-
-async def init():
-    logging.info(f">PG_POOL:{config('PG_POOL', default=None)}")
-    if config('PG_POOL', cast=bool, default=True):
-        make_pool()
-
-
-async def terminate():
-    global postgreSQL_pool
-    if postgreSQL_pool is not None:
-        try:
-            postgreSQL_pool.closeall()
-            logging.info("Closed all connexions to PostgreSQL")
-        except (Exception, psycopg2.DatabaseError) as error:
-            logging.error("Error while closing all connexions to PostgreSQL", error)
diff --git a/ee/connectors/utils/worker.py b/ee/connectors/utils/worker.py
index f3135e0fe..f945c708f 100644
--- a/ee/connectors/utils/worker.py
+++ b/ee/connectors/utils/worker.py
@@ -151,7 +151,7 @@ class ProjectFilter:
 def read_from_kafka(pipe: Connection, params: dict):
     global UPLOAD_RATE, max_kafka_read
     # try:
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
     kafka_consumer = init_consumer()
     project_filter = params['project_filter']
     capture_messages = list()
@@ -207,7 +207,7 @@ def read_from_kafka(pipe: Connection, params: dict):
     print('[WORKER INFO] Closing consumer')
     close_consumer(kafka_consumer)
     print('[WORKER INFO] Closing pg connection')
-    asyncio.run(pg_client.terminate())
+    # asyncio.run(pg_client.terminate())
     print('[WORKER INFO] Successfully closed reader task')
     # except Exception as e:
     #     print('[WARN]', repr(e))
@@ -223,12 +223,12 @@ def into_batch(batch: list[Event | DetailedEvent], session_id: int, n: Session):
 
 def project_from_session(sessionId: int):
     """Search projectId of requested sessionId in PG table sessions"""
-    with pg_client.PostgresClient() as conn:
-        conn.execute(
-            conn.mogrify("SELECT project_id FROM sessions WHERE session_id=%(sessionId)s LIMIT 1",
-                         {'sessionId': sessionId})
-        )
-        res = conn.fetchone()
+    with pg_client.PostgresClient().get_live_session() as conn:
+        cur = conn.execute(
+            "SELECT project_id FROM sessions WHERE session_id=:sessionId LIMIT 1",
+            {'sessionId': sessionId}
+        )
+        res = cur.fetchone()
     if res is None:
         print(f'[WORKER WARN] sessionid {sessionId} not found in sessions table')
         return None
     else:
@@ -241,13 +241,13 @@ def project_from_sessions(sessionIds: list[int]):
     while sessionIds:
         sessIds = sessionIds[-1000:]
         try:
-            with pg_client.PostgresClient() as conn:
-                conn.execute(
+            with pg_client.PostgresClient().get_live_session() as conn:
+                cur = conn.execute(
                     "SELECT session_id, project_id FROM sessions WHERE session_id IN ({sessionIds})".format(
                         sessionIds=','.join([str(sessId) for sessId in sessIds])
                     )
                 )
-                res = conn.fetchall()
+                res = cur.fetchall()
         except Exception as e:
             print('[WORKER project_from_sessions]', repr(e))
             raise e
@@ -320,16 +320,16 @@ def fix_missing_redshift():
         return
     # logging.info(f'[FILL INFO] {len(res)} length response')
     sessionids = list(map(lambda k: str(k), res['sessionid']))
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
     try:
-        with pg_client.PostgresClient() as conn:
-            conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
+        with pg_client.PostgresClient().get_live_session() as conn:
+            cur = conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
                 session_id_list=','.join(sessionids))
             )
-            pg_res = conn.fetchall()
+            pg_res = cur.fetchall()
     except Exception as e:
         #logging.error(f'[ERROR] Error while selecting from pg: {repr(e)}')
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
         return
     logging.info(f'response from pg, length {len(pg_res)}')
     df = pd.DataFrame(pg_res)
@@ -350,7 +350,7 @@ def fix_missing_redshift():
     if len(all_ids) == 0:
         logging.info('[FILL INFO] No ids obtained')
         database_api.close()
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
         return
     # logging.info(f'[FILL INFO] {base_query}')
     try:
@@ -359,11 +359,11 @@ def fix_missing_redshift():
         logging.error(f'[ERROR] Error while executing query. {repr(e)}')
         logging.error(f'[ERROR INFO] query: {base_query}')
         database_api.close()
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
         return
     logging.info(f'[FILL-INFO] {time() - t} - for {len(sessionids)} elements')
     database_api.close()
-    asyncio.run(pg_client.terminate())
+    # asyncio.run(pg_client.terminate())
     return
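--
Reviewer note: after this patch, callers get a plain SQLAlchemy session from
get_live_session() instead of a pooled psycopg2 cursor, and rows must be
fetched from the result of execute() rather than from the connection object.
A minimal usage sketch, assuming SQLAlchemy 1.x semantics (Session.execute()
accepting a raw SQL string with :named parameters); lookup_project and the
session id below are illustrative, not part of the patch:

    from utils import pg_client

    def lookup_project(session_id: int):
        # PostgresClient() builds a fresh engine per instance, so hot paths
        # should create one client up front and reuse it across queries.
        client = pg_client.PostgresClient()
        with client.get_live_session() as conn:
            cur = conn.execute(
                "SELECT project_id FROM sessions WHERE session_id=:sid LIMIT 1",
                {'sid': session_id}
            )
            row = cur.fetchone()
        return row[0] if row is not None else None

    print(lookup_project(42))

Since get_live_session() only closes the session on exit, writes need an
explicit conn.commit() inside the with block before the session goes away.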