# openreplay/ee/connectors/db/api.py
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
import logging
import os

from sqlalchemy import create_engine, MetaData
from sqlalchemy.orm import sessionmaker, Session

DATABASE = os.environ['DATABASE_NAME']
if DATABASE == 'redshift':
    import pandas_redshift as pr

base_path = Path(__file__).parent.parent

from db.models import Base

logger = logging.getLogger(__file__)


def get_class_by_tablename(tablename):
    """Return the class reference mapped to a table.

    Raise an AttributeError if no class is found.

    :param tablename: String with the name of the table.
    :return: Class reference.
    """
    for c in Base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == tablename:
            return c
    raise AttributeError(f'No model with tablename "{tablename}"')
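
# A minimal usage sketch; the 'sessions' table name below is hypothetical,
# any table declared in db.models works the same way:
#
#     Sessions = get_class_by_tablename('sessions')
#     print(Sessions.__table__.columns.keys())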


class DBConnection:
    """
    Initializes a connection to a database.

    To regenerate the models file, use:
        sqlacodegen --outfile models_universal.py mysql+pymysql://{user}:{pwd}@{address}
    """
    _sessions = sessionmaker()

    def __init__(self, config) -> None:
        self.metadata = MetaData()
        self.config = config
        if config == 'redshift':
            # pandas_redshift holds module-level connections to Redshift and S3.
            self.pdredshift = pr
            self.pdredshift.connect_to_redshift(dbname=os.environ['schema'],
                                                host=os.environ['address'],
                                                port=os.environ['port'],
                                                user=os.environ['user'],
                                                password=os.environ['password'])
            self.pdredshift.connect_to_s3(aws_access_key_id=os.environ['aws_access_key_id'],
                                          aws_secret_access_key=os.environ['aws_secret_access_key'],
                                          bucket=os.environ['bucket'],
                                          subdirectory=os.environ['subdirectory'])
            self.connect_str = os.environ['connect_str'].format(
                user=os.environ['user'],
                password=os.environ['password'],
                address=os.environ['address'],
                port=os.environ['port'],
                schema=os.environ['schema']
            )
            self.engine = create_engine(self.connect_str)
        elif config == 'clickhouse':
            self.connect_str = os.environ['connect_str'].format(
                address=os.environ['address'],
                database=os.environ['database']
            )
            self.engine = create_engine(self.connect_str)
        elif config == 'pg':
            self.connect_str = os.environ['connect_str'].format(
                user=os.environ['user'],
                password=os.environ['password'],
                address=os.environ['address'],
                port=os.environ['port'],
                database=os.environ['database']
            )
            self.engine = create_engine(self.connect_str)
        elif config == 'bigquery':
            pass
        elif config == 'snowflake':
            self.connect_str = os.environ['connect_str'].format(
                user=os.environ['user'],
                password=os.environ['password'],
                account=os.environ['account'],
                database=os.environ['database'],
                schema=os.environ['schema'],
                warehouse=os.environ['warehouse']
            )
            self.engine = create_engine(self.connect_str)
        else:
            raise ValueError(f"Unknown db configuration '{config}'. Add it to the keys file.")
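
    # A sketch of the environment the 'pg' branch expects. The connect_str
    # template is defined outside this file, so the URL below is an
    # assumption, not the project's canonical value:
    #
    #     export DATABASE_NAME=pg
    #     export connect_str='postgresql://{user}:{password}@{address}:{port}/{database}'
    #     export user=... password=... address=... port=... database=...
    #
    #     db = DBConnection('pg')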

    @contextmanager
    def get_test_session(self, **kwargs) -> Iterator[Session]:
        """
        Test session context; even commits won't be persisted into the db.

        :Keyword Arguments:
            * autoflush (``bool``) -- default: True
            * autocommit (``bool``) -- default: False
            * expire_on_commit (``bool``) -- default: True
        """
        connection = self.engine.connect()
        transaction = connection.begin()
        my_session = type(self)._sessions(bind=connection, **kwargs)
        try:
            yield my_session
        finally:
            # Do cleanup: roll back and close, whatever happens in the caller.
            my_session.close()
            transaction.rollback()
            connection.close()
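
    # Usage sketch, assuming a configured connection; SomeModel is a
    # hypothetical mapped class from db.models:
    #
    #     with db.get_test_session() as s:
    #         s.add(SomeModel(id=1))
    #         s.commit()  # visible inside the outer transaction...
    #     # ...but rolled back on exit, so nothing reaches the db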

    @contextmanager
    def get_live_session(self) -> Iterator[Session]:
        """
        A session that can be committed; changes are reflected in the database.
        """
        # Transactions are handled by the session itself; we only manage the connection.
        connection = self.engine.connect()
        my_session = type(self)._sessions(bind=connection)
        try:
            yield my_session
        finally:
            my_session.close()
            connection.close()
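
    # Usage sketch (same hypothetical model as above); here the commit persists:
    #
    #     with db.get_live_session() as s:
    #         s.add(SomeModel(id=1))
    #         s.commit()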