fix(connector-redshift): Changed PG pool (#1821)

* Added exception handling in the pool

* Fixed an issue with the message codec

* Changed the PG pool to a plain PG connection

* fix(redshift-connector): Fixed closing the connection on exception
Author: MauricioGarciaS, 2024-01-17 10:33:21 +01:00, committed by GitHub
parent 7dac657885
commit 5938fd95de
4 changed files with 465 additions and 596 deletions
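In short: utils/pg_client.py drops its psycopg2 ThreadedConnectionPool, and each PostgresClient now builds its own SQLAlchemy engine and hands out short-lived sessions. A minimal caller-side sketch of the change, assuming SQLAlchemy 1.x (where Session.execute() still accepts a plain SQL string, which the diff relies on); only PostgresClient and get_live_session come from this commit, the query is illustrative:

    # Before: borrow a pooled psycopg2 cursor.
    # with pg_client.PostgresClient() as conn:
    #     conn.execute("SELECT 1")
    #     row = conn.fetchone()

    # After: open a per-use SQLAlchemy session; execute() returns a result object.
    with pg_client.PostgresClient().get_live_session() as conn:
        cur = conn.execute("SELECT 1")
        row = cur.fetchone()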


@@ -1,7 +1,8 @@
 from decouple import config, Csv
-import asyncio
+import signal
+# import asyncio
 from db.api import DBConnection
-from utils import pg_client
+# from utils import pg_client
 from utils.worker import WorkerPool
@@ -10,7 +11,7 @@ def main():
     database_api = DBConnection(DATABASE)
     allowed_projects = config('PROJECT_IDS', default=None, cast=Csv(int))
-    w_pool = WorkerPool(n_workers=60,
+    w_pool = WorkerPool(n_workers=config('OR_EE_CONNECTOR_WORKER_COUNT', cast=int, default=60),
                         project_filter=allowed_projects)
     try:
         w_pool.load_checkpoint(database_api)
@@ -24,6 +25,7 @@ def main():
 if __name__ == '__main__':
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
     main()
+    raise Exception('Script terminated')
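The worker count is no longer hard-coded to 60: it is read through python-decouple, so it can be overridden from the environment or a .env file. A minimal sketch of the new setting, mirroring the call in the hunk above:

    from decouple import config

    # OR_EE_CONNECTOR_WORKER_COUNT is the setting introduced by this commit; 60 remains the default.
    n_workers = config('OR_EE_CONNECTOR_WORKER_COUNT', cast=int, default=60)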

File diff suppressed because it is too large


@@ -1,23 +1,30 @@
 import logging
 import time
-from threading import Semaphore
-
-import psycopg2
-import psycopg2.extras
+from sqlalchemy import create_engine
+from sqlalchemy import MetaData
+from sqlalchemy.orm import sessionmaker, session
+from contextlib import contextmanager
+import logging
+from decouple import config as _config
+from decouple import Choices
+from contextlib import contextmanager
 from decouple import config
-from psycopg2 import pool
 
 logging.basicConfig(level=config("LOGLEVEL", default=logging.INFO))
 logging.getLogger('apscheduler').setLevel(config("LOGLEVEL", default=logging.INFO))
 
+sslmode = _config('DB_SSLMODE',
+                  cast=Choices(['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
+                  default='allow'
+                  )
+
 conn_str = config('string_connection', default='')
 if conn_str == '':
-    _PG_CONFIG = {"host": config("pg_host"),
-                  "database": config("pg_dbname"),
-                  "user": config("pg_user"),
-                  "password": config("pg_password"),
-                  "port": config("pg_port", cast=int),
-                  "application_name": config("APP_NAME", default="PY")}
+    pg_host = config("pg_host")
+    pg_dbname = config("pg_dbname")
+    pg_user = config("pg_user")
+    pg_password = config("pg_password")
+    pg_port = config("pg_port", cast=int)
 else:
     import urllib.parse
     conn_str = urllib.parse.unquote(conn_str)
@@ -32,174 +39,28 @@ else:
     else:
         pg_host, pg_port = host_info.split(':')
         pg_port = int(pg_port)
-    _PG_CONFIG = {"host": pg_host,
-                  "database": pg_dbname,
-                  "user": pg_user,
-                  "password": pg_password,
-                  "port": pg_port,
-                  "application_name": config("APP_NAME", default="PY")}
-
-PG_CONFIG = dict(_PG_CONFIG)
-if config("PG_TIMEOUT", cast=int, default=0) > 0:
-    PG_CONFIG["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int) * 1000}"
-
-
-class ORThreadedConnectionPool(psycopg2.pool.ThreadedConnectionPool):
-    def __init__(self, minconn, maxconn, *args, **kwargs):
-        self._semaphore = Semaphore(maxconn)
-        super().__init__(minconn, maxconn, *args, **kwargs)
-
-    def getconn(self, *args, **kwargs):
-        self._semaphore.acquire()
-        try:
-            return super().getconn(*args, **kwargs)
-        except psycopg2.pool.PoolError as e:
-            if str(e) == "connection pool is closed":
-                make_pool()
-            raise e
-
-    def putconn(self, *args, **kwargs):
-        try:
-            super().putconn(*args, **kwargs)
-            self._semaphore.release()
-        except psycopg2.pool.PoolError as e:
-            if str(e) == "trying to put unkeyed connection":
-                print("!!! trying to put unkeyed connection")
-                print(f"env-PG_POOL:{config('PG_POOL', default=None)}")
-                return
-            raise e
-
-
-postgreSQL_pool: ORThreadedConnectionPool = None
-
-RETRY_MAX = config("PG_RETRY_MAX", cast=int, default=50)
-RETRY_INTERVAL = config("PG_RETRY_INTERVAL", cast=int, default=2)
-RETRY = 0
-
-
-def make_pool():
-    if not config('PG_POOL', cast=bool, default=True):
-        return
-    global postgreSQL_pool
-    global RETRY
-    if postgreSQL_pool is not None:
-        try:
-            postgreSQL_pool.closeall()
-        except (Exception, psycopg2.DatabaseError) as error:
-            logging.error("Error while closing all connexions to PostgreSQL", error)
-    try:
-        postgreSQL_pool = ORThreadedConnectionPool(config("PG_MINCONN", cast=int, default=20),
-                                                   config("PG_MAXCONN", cast=int, default=80),
-                                                   **PG_CONFIG)
-        if (postgreSQL_pool):
-            logging.info("Connection pool created successfully")
-    except (Exception, psycopg2.DatabaseError) as error:
-        logging.error("Error while connecting to PostgreSQL", error)
-        if RETRY < RETRY_MAX:
-            RETRY += 1
-            logging.info(f"waiting for {RETRY_INTERVAL}s before retry n°{RETRY}")
-            time.sleep(RETRY_INTERVAL)
-            make_pool()
-        else:
-            raise error
+conn_str = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
 
 
 class PostgresClient:
-    connection = None
-    cursor = None
-    long_query = False
-    unlimited_query = False
+    CONNECTION_STRING: str = conn_str
+    _sessions = sessionmaker()
 
-    def __init__(self, long_query=False, unlimited_query=False, use_pool=True):
-        self.long_query = long_query
-        self.unlimited_query = unlimited_query
-        self.use_pool = use_pool
-        if unlimited_query:
-            long_config = dict(_PG_CONFIG)
-            long_config["application_name"] += "-UNLIMITED"
-            self.connection = psycopg2.connect(**long_config)
-        elif long_query:
-            long_config = dict(_PG_CONFIG)
-            long_config["application_name"] += "-LONG"
-            long_config["options"] = f"-c statement_timeout=" \
-                                     f"{config('pg_long_timeout', cast=int, default=5 * 60) * 1000}"
-            self.connection = psycopg2.connect(**long_config)
-        elif not use_pool or not config('PG_POOL', cast=bool, default=True):
-            single_config = dict(_PG_CONFIG)
-            single_config["application_name"] += "-NOPOOL"
-            single_config["options"] = f"-c statement_timeout={config('PG_TIMEOUT', cast=int, default=30) * 1000}"
-            self.connection = psycopg2.connect(**single_config)
-        else:
-            self.connection = postgreSQL_pool.getconn()
+    def __init__(self):
+        self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
 
-    def __enter__(self):
-        if self.cursor is None:
-            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
-            self.cursor.cursor_execute = self.cursor.execute
-            self.cursor.execute = self.__execute
-            self.cursor.recreate = self.recreate_cursor
-        return self.cursor
-
-    def __exit__(self, *args):
-        try:
-            self.connection.commit()
-            self.cursor.close()
-            if not self.use_pool or self.long_query or self.unlimited_query:
-                self.connection.close()
-        except Exception as error:
-            logging.error("Error while committing/closing PG-connection", error)
-            if str(error) == "connection already closed" \
-                    and self.use_pool \
-                    and not self.long_query \
-                    and not self.unlimited_query \
-                    and config('PG_POOL', cast=bool, default=True):
-                logging.info("Recreating the connexion pool")
-                make_pool()
-            else:
-                raise error
-        finally:
-            if config('PG_POOL', cast=bool, default=True) \
-                    and self.use_pool \
-                    and not self.long_query \
-                    and not self.unlimited_query:
-                postgreSQL_pool.putconn(self.connection)
-
-    def __execute(self, query, vars=None):
-        try:
-            result = self.cursor.cursor_execute(query=query, vars=vars)
-        except psycopg2.Error as error:
-            logging.error(f"!!! Error of type:{type(error)} while executing query:")
-            logging.error(query)
-            logging.info("starting rollback to allow future execution")
-            self.connection.rollback()
-            raise error
-        return result
-
-    def recreate_cursor(self, rollback=False):
-        if rollback:
-            try:
-                self.connection.rollback()
-            except Exception as error:
-                logging.error("Error while rollbacking connection for recreation", error)
-        try:
-            self.cursor.close()
-        except Exception as error:
-            logging.error("Error while closing cursor for recreation", error)
-        self.cursor = None
-        return self.__enter__()
-
-
-async def init():
-    logging.info(f">PG_POOL:{config('PG_POOL', default=None)}")
-    if config('PG_POOL', cast=bool, default=True):
-        make_pool()
-
-
-async def terminate():
-    global postgreSQL_pool
-    if postgreSQL_pool is not None:
-        try:
-            postgreSQL_pool.closeall()
-            logging.info("Closed all connexions to PostgreSQL")
-        except (Exception, psycopg2.DatabaseError) as error:
-            logging.error("Error while closing all connexions to PostgreSQL", error)
+    @contextmanager
+    def get_live_session(self) -> session:
+        """
+        This is a session that can be committed.
+        Changes will be reflected in the database.
+        """
+        # Automatic transaction and connection handling in session
+        connection = self.engine.connect()
+        my_session = type(self)._sessions(bind=connection)
+
+        yield my_session
+
+        my_session.close()
+        connection.close()


@@ -151,7 +151,7 @@ class ProjectFilter:
 def read_from_kafka(pipe: Connection, params: dict):
     global UPLOAD_RATE, max_kafka_read
     # try:
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
     kafka_consumer = init_consumer()
     project_filter = params['project_filter']
     capture_messages = list()
@@ -207,7 +207,7 @@ def read_from_kafka(pipe: Connection, params: dict):
     print('[WORKER INFO] Closing consumer')
     close_consumer(kafka_consumer)
     print('[WORKER INFO] Closing pg connection')
-    asyncio.run(pg_client.terminate())
+    # asyncio.run(pg_client.terminate())
     print('[WORKER INFO] Successfully closed reader task')
     # except Exception as e:
     #     print('[WARN]', repr(e))
@@ -223,12 +223,12 @@ def into_batch(batch: list[Event | DetailedEvent], session_id: int, n: Session):
 def project_from_session(sessionId: int):
     """Search projectId of requested sessionId in PG table sessions"""
-    with pg_client.PostgresClient() as conn:
-        conn.execute(
+    with pg_client.PostgresClient().get_live_session() as conn:
+        cur = conn.execute(
             conn.mogrify("SELECT project_id FROM sessions WHERE session_id=%(sessionId)s LIMIT 1",
                          {'sessionId': sessionId})
         )
-        res = conn.fetchone()
+        res = cur.fetchone()
     if res is None:
         print(f'[WORKER WARN] sessionid {sessionId} not found in sessions table')
         return None
@@ -241,13 +241,13 @@ def project_from_sessions(sessionIds: list[int]):
     while sessionIds:
         sessIds = sessionIds[-1000:]
         try:
-            with pg_client.PostgresClient() as conn:
-                conn.execute(
+            with pg_client.PostgresClient().get_live_session() as conn:
+                cur = conn.execute(
                     "SELECT session_id, project_id FROM sessions WHERE session_id IN ({sessionIds})".format(
                         sessionIds=','.join([str(sessId) for sessId in sessIds])
                     )
                 )
-                res = conn.fetchall()
+                res = cur.fetchall()
        except Exception as e:
            print('[WORKER project_from_sessions]', repr(e))
            raise e
@@ -320,16 +320,16 @@ def fix_missing_redshift():
        return
    # logging.info(f'[FILL INFO] {len(res)} length response')
    sessionids = list(map(lambda k: str(k), res['sessionid']))
-    asyncio.run(pg_client.init())
+    # asyncio.run(pg_client.init())
    try:
-        with pg_client.PostgresClient() as conn:
-            conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
+        with pg_client.PostgresClient().get_live_session() as conn:
+            cur = conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
                session_id_list=','.join(sessionids))
            )
-            pg_res = conn.fetchall()
+            pg_res = cur.fetchall()
    except Exception as e:
        #logging.error(f'[ERROR] Error while selecting from pg: {repr(e)}')
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
        return
    logging.info(f'response from pg, length {len(pg_res)}')
    df = pd.DataFrame(pg_res)
@@ -350,7 +350,7 @@ def fix_missing_redshift():
    if len(all_ids) == 0:
        logging.info('[FILL INFO] No ids obtained')
        database_api.close()
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
        return
    # logging.info(f'[FILL INFO] {base_query}')
    try:
@@ -359,11 +359,11 @@ def fix_missing_redshift():
        logging.error(f'[ERROR] Error while executing query. {repr(e)}')
        logging.error(f'[ERROR INFO] query: {base_query}')
        database_api.close()
-        asyncio.run(pg_client.terminate())
+        # asyncio.run(pg_client.terminate())
        return
    logging.info(f'[FILL-INFO] {time() - t} - for {len(sessionids)} elements')
    database_api.close()
-    asyncio.run(pg_client.terminate())
+    # asyncio.run(pg_client.terminate())
    return
@@ -488,6 +488,12 @@ class WorkerPool:
            except TimeoutError as e:
                print('[WORKER-TimeoutError] Decoding of messages is taking longer than expected')
                raise e
+            except Exception as e:
+                print(f'[Exception] {e}')
+                self.sessions_update_batch = dict()
+                self.sessions_insert_batch = dict()
+                self.events_batch = list()
+                continue
            session_ids, messages = self._pool_response_handler(
                pool_results=results)
            if current_loop_number == 0:
@@ -500,7 +506,7 @@ class WorkerPool:
                main_conn.send('CONTINUE')
        print('[WORKER-INFO] Sending close signal')
        main_conn.send('CLOSE')
-        self.terminate()
+        self.terminate(database_api)
        kafka_reader_process.terminate()
        print('[WORKER-SHUTDOWN] Process terminated')
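The new except-Exception branch keeps the worker loop alive when decoding or batching fails: it logs the error, discards the partially built batches, and continues instead of propagating. A rough sketch of that recovery pattern; everything except the three batch attributes is a hypothetical stand-in for WorkerPool's real loop:

    class WorkerPoolSketch:
        """Illustrative stand-in for WorkerPool's main loop."""

        def __init__(self):
            self.sessions_update_batch = dict()
            self.sessions_insert_batch = dict()
            self.events_batch = list()

        def run(self, decode_next_batch, handle):
            while True:
                try:
                    results = decode_next_batch()  # hypothetical decoding step
                except TimeoutError:
                    raise  # still fatal, as in the diff
                except Exception as e:
                    print(f'[Exception] {e}')
                    # drop what was accumulated this iteration and keep consuming
                    self.sessions_update_batch = dict()
                    self.sessions_insert_batch = dict()
                    self.events_batch = list()
                    continue
                handle(results)  # hypothetical downstream handling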