openreplay/api/chalicelib/utils/pg_client.py

305 lines
11 KiB
Python

import logging
import time
import asyncio
from threading import Semaphore
from typing import Dict, Any, Optional
import psycopg2
import psycopg2.extras
from decouple import config
from psycopg2 import pool
logger = logging.getLogger(__name__)

# Base connection parameters. Dedicated (long / unlimited / no-pool)
# connections copy this dict and add their own options, so it is never
# mutated in place.
_PG_CONFIG = dict(
    host=config("pg_host"),
    database=config("pg_dbname"),
    user=config("pg_user"),
    password=config("pg_password"),
    port=config("pg_port", cast=int),
    application_name=config("APP_NAME", default="PY"),
)

# Parameters handed to the shared pool: the base set plus, when PG_TIMEOUT
# (seconds) is positive, a server-side statement_timeout in milliseconds.
PG_CONFIG = {**_PG_CONFIG}
if config("PG_TIMEOUT", cast=int, default=0) > 0:
    PG_CONFIG["options"] = "-c statement_timeout={}".format(config('PG_TIMEOUT', cast=int) * 1000)
class ORThreadedConnectionPool(psycopg2.pool.ThreadedConnectionPool):
    """ThreadedConnectionPool whose getconn() waits for a free slot.

    A semaphore sized to ``maxconn`` is taken for every connection handed out
    and given back when it is returned, so callers block instead of receiving
    a "pool exhausted" error once all connections are in use.
    """

    def __init__(self, minconn, maxconn, *args, **kwargs):
        self._semaphore = Semaphore(maxconn)
        super().__init__(minconn, maxconn, *args, **kwargs)

    def getconn(self, *args, **kwargs):
        """Borrow a connection, blocking until a slot is available.

        If the pool is closed, the module-wide pool is rebuilt via
        ``make_pool()`` before the PoolError is re-raised, so later callers
        get a working pool.
        """
        self._semaphore.acquire()
        try:
            conn = super().getconn(*args, **kwargs)
        except psycopg2.pool.PoolError as e:
            # No connection was handed out: give the slot back, otherwise each
            # failed call would permanently shrink the pool's capacity.
            self._semaphore.release()
            if str(e) == "connection pool is closed":
                make_pool()
            raise
        except BaseException:
            # Same reasoning for any other failure (e.g. opening a new
            # connection to the server failed).
            self._semaphore.release()
            raise
        return conn

    def putconn(self, *args, **kwargs):
        """Return a connection to the pool and free its slot.

        A connection this pool does not know ("unkeyed" - e.g. one borrowed
        from a pool that make_pool() has since replaced) is logged and
        ignored; no slot was taken from this pool for it, so none is released.
        """
        try:
            super().putconn(*args, **kwargs)
        except psycopg2.pool.PoolError as e:
            if str(e) == "trying to put unkeyed connection":
                logger.warning("!!! trying to put unkeyed connection")
                logger.warning(f"env-PG_POOL:{config('PG_POOL', default=None)}")
                return
            raise
        self._semaphore.release()
# Shared connection pool; None until make_pool() succeeds (or forever when
# PG_POOL is disabled).
postgreSQL_pool: Optional[ORThreadedConnectionPool] = None
# Pool-creation retry policy: up to RETRY_MAX retries, RETRY_INTERVAL seconds apart.
RETRY_MAX = config("PG_RETRY_MAX", cast=int, default=50)
RETRY_INTERVAL = config("PG_RETRY_INTERVAL", cast=int, default=2)
# Retries consumed by the current make_pool() run; reset to 0 on success.
RETRY = 0


def make_pool():
    """(Re)create the module-wide connection pool ``postgreSQL_pool``.

    Does nothing when PG_POOL is disabled. An existing pool is closed first;
    a failure to close it is only logged. Creation is then retried every
    RETRY_INTERVAL seconds until it succeeds. The last error is re-raised once
    RETRY_MAX retries have been spent.

    The retry counter is reset after a successful creation, so every rebuild
    (e.g. after the pool was found closed) gets the full retry budget again.

    Note: this blocks the calling thread while it sleeps between retries.
    """
    global postgreSQL_pool
    global RETRY
    if not config('PG_POOL', cast=bool, default=True):
        return
    if postgreSQL_pool is not None:
        try:
            postgreSQL_pool.closeall()
        except Exception as error:
            logger.error("Error while closing all connexions to PostgreSQL", exc_info=error)
    # Iterative retry: avoids recursion and avoids re-closing the old pool on
    # every attempt.
    while True:
        try:
            postgreSQL_pool = ORThreadedConnectionPool(config("PG_MINCONN", cast=int, default=4),
                                                       config("PG_MAXCONN", cast=int, default=8),
                                                       **PG_CONFIG)
        except Exception as error:
            logger.error("Error while connecting to PostgreSQL", exc_info=error)
            if RETRY >= RETRY_MAX:
                raise
            RETRY += 1
            logger.info(f"Waiting for {RETRY_INTERVAL}s before retry n°{RETRY}")
            time.sleep(RETRY_INTERVAL)
        else:
            RETRY = 0
            logger.info("Connection pool created successfully")
            return
class PostgresClient:
    """Context manager yielding a RealDictCursor on a PostgreSQL connection.

    The connection is chosen at construction time:

    * ``unlimited_query=True``: dedicated connection built from the base
      parameters, with no statement_timeout option added by this class.
    * ``long_query=True``: dedicated connection whose statement_timeout is
      PG_TIMEOUT_LONG seconds (default 300; none when set to 0 or less).
    * ``use_pool=False`` or PG_POOL disabled: dedicated connection whose
      statement_timeout is PG_TIMEOUT seconds (default 30; none when <= 0).
    * otherwise: a connection borrowed from the shared ``postgreSQL_pool``.

    On exit the transaction is committed and the cursor closed. A pooled
    connection is then returned to the pool; a dedicated one is closed.

    Usage::

        with PostgresClient() as cur:
            cur.execute("SELECT 1")
    """
    connection = None
    cursor = None
    long_query = False
    unlimited_query = False
    # True only when `connection` was borrowed from postgreSQL_pool and must
    # be handed back with putconn() instead of being closed.
    _pooled = False

    def __init__(self, long_query=False, unlimited_query=False, use_pool=True):
        self.long_query = long_query
        self.unlimited_query = unlimited_query
        self.use_pool = use_pool
        self._pooled = False
        if unlimited_query:
            long_config = dict(_PG_CONFIG)
            long_config["application_name"] += "-UNLIMITED"
            self.connection = psycopg2.connect(**long_config)
        elif long_query:
            long_config = dict(_PG_CONFIG)
            long_config["application_name"] += "-LONG"
            long_timeout = config('PG_TIMEOUT_LONG', cast=int, default=5 * 60)
            if long_timeout > 0:
                long_config["options"] = f"-c statement_timeout={long_timeout * 1000}"
            else:
                logger.info("Disabled timeout for long query")
            self.connection = psycopg2.connect(**long_config)
        elif not use_pool or not config('PG_POOL', cast=bool, default=True):
            single_config = dict(_PG_CONFIG)
            single_config["application_name"] += "-NOPOOL"
            timeout = config('PG_TIMEOUT', cast=int, default=30)
            if timeout > 0:
                single_config["options"] = f"-c statement_timeout={timeout * 1000}"
            self.connection = psycopg2.connect(**single_config)
        else:
            self.connection = postgreSQL_pool.getconn()
            self._pooled = True

    def __enter__(self):
        """Create the cursor on first entry and return it.

        The cursor's ``execute`` is replaced by ``__execute``; the original
        stays reachable as ``cursor_execute``. ``cursor.recreate(...)`` is
        bound to ``recreate_cursor``.
        """
        if self.cursor is None:
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            self.cursor.cursor_execute = self.cursor.execute
            self.cursor.execute = self.__execute
            self.cursor.recreate = self.recreate_cursor
        return self.cursor

    def __exit__(self, *args):
        """Commit, close the cursor and release the connection.

        The commit happens whether or not the ``with`` body raised. If a
        pooled connection is found already closed, the pool is rebuilt instead
        of raising. Any other commit or close error is logged and re-raised.
        The connection is released (pooled) or closed (dedicated) in every
        case, so a failed commit no longer leaks a dedicated connection.
        """
        try:
            self.connection.commit()
            self.cursor.close()
        except Exception as error:
            logger.error("Error while committing/closing PG-connection", exc_info=error)
            if str(error) == "connection already closed" and self._pooled:
                logger.info("Recreating the connexion pool")
                make_pool()
            else:
                raise
        finally:
            self._release_connection()

    def _release_connection(self):
        """Hand a pooled connection back to the pool; close a dedicated one."""
        if self._pooled:
            postgreSQL_pool.putconn(self.connection)
        else:
            self.connection.close()

    def __execute(self, query, vars=None):
        """Run ``query`` and roll back on a psycopg2 error before re-raising.

        The rollback lets later statements on this connection run. If the
        rollback itself fails with InterfaceError, the cursor is recreated.
        """
        try:
            result = self.cursor.cursor_execute(query=query, vars=vars)
        except psycopg2.Error as error:
            logger.error(f"!!! Error of type:{type(error)} while executing query:")
            logger.error(query)
            logger.info("starting rollback to allow future execution")
            try:
                self.connection.rollback()
            except psycopg2.InterfaceError as e:
                logger.error("!!! Error while rollbacking connection", exc_info=e)
                logger.error("!!! Trying to recreate the cursor")
                self.recreate_cursor()
            raise error
        return result

    def recreate_cursor(self, rollback=False):
        """Close the current cursor (optionally rolling back first) and return a new one.

        Callers must use the returned cursor; the old object is closed.
        Errors during rollback or close are logged, not raised.
        """
        if rollback:
            try:
                self.connection.rollback()
            except Exception as error:
                logger.error("Error while rollbacking connection for recreation", exc_info=error)
        try:
            self.cursor.close()
        except Exception as error:
            logger.error("Error while closing cursor for recreation", exc_info=error)
        self.cursor = None
        return self.__enter__()

    @classmethod
    async def health_check(cls) -> Dict[str, Any]:
        """Check that a connection can be obtained and can run ``SELECT 1``.

        Can be called on the class: ``await PostgresClient.health_check()``.
        The following all run in the default executor, so the event loop is
        never blocked:

        * acquiring the connection (which may wait on the pool or on the
          network);
        * running the query;
        * releasing the connection.

        Returns:
            On success, ``{"status": "ok", "message": ..., "ping_time_ms": float}``.
            ``ping_time_ms`` is the round trip of the query itself.
            On any failure, ``{"status": "error", "message": ...}``.
        """
        def check_db() -> float:
            client = cls()
            try:
                cursor = client.connection.cursor()
                try:
                    start_time = time.monotonic()
                    cursor.execute("SELECT 1")
                    return (time.monotonic() - start_time) * 1000
                finally:
                    cursor.close()
            finally:
                # Always release: a failed check must not leak a pool slot.
                # This also picks the right path (putconn vs close) when
                # PG_POOL is disabled.
                client._release_connection()

        try:
            ping_time_ms = await asyncio.get_running_loop().run_in_executor(None, check_db)
            return {
                "status": "ok",
                "message": "PostgreSQL connection is healthy",
                "ping_time_ms": round(ping_time_ms, 2)
            }
        except Exception as e:
            logger.error(f"PostgreSQL health check failed: {e}")
            return {
                "status": "error",
                "message": f"Failed to connect to PostgreSQL: {str(e)}"
            }
def get_client(long_query=False, unlimited_query=False, use_pool=True) -> PostgresClient:
    """Factory returning a new, not-yet-entered PostgresClient.

    Args:
        long_query: use a dedicated connection with the PG_TIMEOUT_LONG
            statement timeout.
        unlimited_query: use a dedicated connection without a statement
            timeout option.
        use_pool: when False, open a dedicated connection instead of
            borrowing one from the shared pool.

    Returns:
        PostgresClient instance; use it as a context manager to get a cursor.
    """
    client_options = dict(long_query=long_query,
                          unlimited_query=unlimited_query,
                          use_pool=use_pool)
    return PostgresClient(**client_options)
async def init():
    """Initialise the PostgreSQL layer at application start-up.

    Builds the shared pool when PG_POOL is enabled; pool-creation errors
    propagate. Then runs a one-off health check whose failure is only logged.
    """
    logger.info(f">use PG_POOL:{config('PG_POOL', default=True)}")
    if config('PG_POOL', cast=bool, default=True):
        # make_pool() is synchronous and time.sleep()s between retries; run it
        # in a worker thread so start-up does not freeze the event loop.
        await asyncio.get_running_loop().run_in_executor(None, make_pool)
    # Do a health check at initialization
    try:
        health_status = await PostgresClient.health_check()
        if health_status["status"] == "ok":
            logger.info(f"PostgreSQL connection verified. Ping: {health_status.get('ping_time_ms', 'N/A')}ms")
        else:
            logger.warning(f"PostgreSQL connection check failed: {health_status['message']}")
    except Exception as e:
        logger.error(f"Error during initialization health check: {str(e)}")
async def terminate():
    """Close every connection of the shared pool, if one exists.

    Errors raised while closing are logged, not propagated. The module keeps
    its reference to the (now closed) pool object.
    """
    pool_to_close = postgreSQL_pool
    if pool_to_close is None:
        return
    try:
        pool_to_close.closeall()
    except Exception as error:
        logger.error("Error while closing all connexions to PostgreSQL", exc_info=error)
    else:
        logger.info("Closed all connexions to PostgreSQL")
async def health_check() -> Dict[str, Any]:
    """Application-facing PostgreSQL health probe.

    Delegates to ``PostgresClient.health_check()``.

    Returns:
        Dict with keys "status" ("ok" or "error") and "message", plus
        "ping_time_ms" on success.
    """
    status = await PostgresClient.health_check()
    return status