from sqlalchemy import create_engine
from sqlalchemy import MetaData
from sqlalchemy.orm import sessionmaker, Session
from contextlib import contextmanager
import logging
from decouple import config as _config
from decouple import Choices
from pathlib import Path
import io

DATABASE = _config('CLOUD_SERVICE')
sslmode = _config('DB_SSLMODE',
                  cast=Choices(['disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full']),
                  default='allow')
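
# Settings are read via python-decouple from the environment or a .env/keys
# file: CLOUD_SERVICE, DB_SSLMODE, CONNECTION_STRING, plus per-backend
# credentials (USER, PASSWORD, HOST, PORT, DBNAME/DATABASE, ACCOUNT,
# WAREHOUSE); the redshift backend additionally reads S3_USE_SSL,
# AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, BUCKET and SUBDIRECTORY.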
if DATABASE == 'redshift':
    import pandas_redshift as pr
    import botocore
    use_ssl = _config('S3_USE_SSL', default=True, cast=bool)

base_path = Path(__file__).parent.parent

from db.models import Base

logger = logging.getLogger(__file__)


def get_class_by_tablename(tablename):
    """Return the model class mapped to the given table name.

    Raises AttributeError if no class matches.

    :param tablename: String with the name of the table.
    :return: Class reference.
    """
    for c in Base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == tablename:
            return c
    raise AttributeError(f'No model with tablename "{tablename}"')

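# Example usage, assuming db.models defines a model whose __tablename__ is
# 'users' (hypothetical name):
#   User = get_class_by_tablename('users')
#   print(User.__table__.columns.keys())
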
class DBConnection:
    """
    Initializes a connection to a database.

    To regenerate the models file use:
        sqlacodegen --outfile models_universal.py mysql+pymysql://{USER}:{pwd}@{HOST}
    """

    _sessions = sessionmaker()

    def __init__(self, config) -> None:
        self.metadata = MetaData()
        self.config = config

        if config == 'redshift':
            self.pdredshift = pr
            ci = _config('cluster_info', default='')
            cluster_info = dict()
            if ci == '':
                cluster_info['USER'] = _config('USER')
                cluster_info['HOST'] = _config('HOST')
                cluster_info['PORT'] = _config('PORT')
                cluster_info['PASSWORD'] = _config('PASSWORD')
                cluster_info['DBNAME'] = _config('DBNAME')
            else:
                # Parse space-separated KEY=value pairs; split on the first
                # '=' only, so values may themselves contain '='.
                for _d in ci.split(' '):
                    k, v = _d.split('=', 1)
                    cluster_info[k] = v
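
            # The cluster_info setting, when set, is a single space-separated
            # string of KEY=value pairs, e.g. (hypothetical values):
            #   cluster_info='USER=admin HOST=example.redshift.amazonaws.com PORT=5439 PASSWORD=secret DBNAME=analytics'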
            self.pdredshift.connect_to_redshift(dbname=cluster_info['DBNAME'],
                                                host=cluster_info['HOST'],
                                                port=cluster_info['PORT'],
                                                user=cluster_info['USER'],
                                                password=cluster_info['PASSWORD'],
                                                sslmode=sslmode)

            self.pdredshift.connect_to_s3(aws_access_key_id=_config('AWS_ACCESS_KEY_ID'),
                                          aws_secret_access_key=_config('AWS_SECRET_ACCESS_KEY'),
                                          bucket=_config('BUCKET'),
                                          subdirectory=_config('SUBDIRECTORY', default=None),
                                          use_ssl=use_ssl)

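            # CONNECTION_STRING is a template with named placeholders; for
            # Redshift it is assumed to look something like
            # 'postgresql+psycopg2://{USER}:{PASSWORD}@{HOST}:{PORT}/{DBNAME}'.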
            self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                USER=cluster_info['USER'],
                PASSWORD=cluster_info['PASSWORD'],
                HOST=cluster_info['HOST'],
                PORT=cluster_info['PORT'],
                DBNAME=cluster_info['DBNAME']
            )
            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})

        elif config == 'clickhouse':
            self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                HOST=_config('HOST'),
                DATABASE=_config('DATABASE')
            )
            # NOTE: 'sslmode' is a libpq/psycopg2 parameter; non-Postgres
            # drivers (clickhouse, snowflake) may not accept it as a
            # connect_arg.
            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
        elif config == 'pg':
            self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                USER=_config('USER'),
                PASSWORD=_config('PASSWORD'),
                HOST=_config('HOST'),
                PORT=_config('PORT'),
                DATABASE=_config('DATABASE')
            )
            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
        elif config == 'bigquery':
            pass
        elif config == 'snowflake':
            self.CONNECTION_STRING = _config('CONNECTION_STRING').format(
                USER=_config('USER'),
                PASSWORD=_config('PASSWORD'),
                ACCOUNT=_config('ACCOUNT'),
                DATABASE=_config('DATABASE'),
                DBNAME=_config('DBNAME'),
                WAREHOUSE=_config('WAREHOUSE')
            )
            self.engine = create_engine(self.CONNECTION_STRING, connect_args={'sslmode': sslmode})
        else:
            raise ValueError("This db configuration doesn't exist. Add it to the keys file.")

    @contextmanager
    def get_test_session(self, **kwargs) -> Session:
        """
        Test session context; even commits won't be persisted to the db.

        :Keyword Arguments:
            * autoflush (``bool``) -- default: True
            * autocommit (``bool``) -- default: False
            * expire_on_commit (``bool``) -- default: True
        """
        connection = self.engine.connect()
        transaction = connection.begin()
        my_session = type(self)._sessions(bind=connection, **kwargs)
        try:
            yield my_session
        finally:
            # Clean up, roll back and close, whatever happens in the block.
            my_session.close()
            transaction.rollback()
            connection.close()

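    # get_test_session usage sketch (backend name and model are illustrative):
    #   db = DBConnection('pg')
    #   with db.get_test_session() as s:
    #       s.add(get_class_by_tablename('users')())
    #       s.commit()  # rolled back when the context exits
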
    @contextmanager
    def get_live_session(self) -> Session:
        """
        A session that can be committed; changes will be reflected
        in the database.
        """
        # Automatic transaction and connection handling in the session.
        connection = self.engine.connect()
        my_session = type(self)._sessions(bind=connection)
        try:
            yield my_session
        finally:
            my_session.close()
            connection.close()

    def restart(self):
        # Rebuild the engine (and any redshift/S3 connections) from scratch.
        self.close()
        self.__init__(config=self.config)

    def save_binary(self, binary_data, name, **kwargs):
        if self.config == 'redshift':
            try:
                # pandas_redshift keeps the boto3 S3 resource and bucket
                # settings as module-level state in pandas_redshift.core.
                self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
                                               key=self.pdredshift.core.s3_subdirectory_var + name).put(
                    Body=binary_data, **kwargs)
                logger.info(f'Content saved: {name}')
            except botocore.exceptions.ClientError as err:
                logger.error(repr(err))

    def load_binary(self, name):
        if self.config == 'redshift':
            try:
                s3_object = self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
                                                           key=self.pdredshift.core.s3_subdirectory_var + name)
                f = io.BytesIO()
                s3_object.download_fileobj(f)
                f.seek(0)  # rewind so callers can read from the start
                logger.info(f'Content downloaded: {name}')
                return f
            except botocore.exceptions.ClientError as err:
                logger.error(repr(err))

    def delete_binary(self, name):
        if self.config == 'redshift':
            try:
                s3_object = self.pdredshift.core.s3.Object(bucket_name=self.pdredshift.core.s3_bucket_var,
                                                           key=self.pdredshift.core.s3_subdirectory_var + name)
                s3_object.delete()
                logger.info(f's3 object {name} deleted')
            except botocore.exceptions.ClientError as err:
                logger.error(repr(err))

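    # Usage sketch for save_binary/load_binary/delete_binary (redshift config
    # only; the checkpoint name and payload are illustrative):
    #   db.save_binary(pickle.dumps(state), 'checkpoint.pkl')
    #   buf = db.load_binary('checkpoint.pkl')
    #   state = pickle.loads(buf.read()) if buf is not None else None
    #   db.delete_binary('checkpoint.pkl')
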
    def close(self):
        if self.config == 'redshift':
            self.pdredshift.close_up_shop()
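
# Minimal end-to-end sketch (configuration comes from the environment;
# the table name is hypothetical):
#   db = DBConnection(DATABASE)
#   with db.get_live_session() as s:
#       Model = get_class_by_tablename('users')
#       rows = s.query(Model).limit(10).all()
#   db.close()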