fix(connector): Redshift pool fixes (#1393)

* fix(connector): add an env variable for the replace interval and default it to one minute

* style(connector): different log for redshift and pg

* style(connector): different log for redshift select and replace

* fix(connector): replacing from oldest to newest to avoid blocking

* fix(connector): replace empty strings with 'NN' and fix str issue in the cron job

* Changing methods in connector

* fix(connectors): fix issues when replacing with NULL and reduce the number of queries to Redshift

* fix(connectors): fixed save method

* fix(connectors): fixed issue while saving event object
MauricioGarciaS 2023-07-07 15:33:22 +02:00 committed by GitHub
parent 18c4dcd475
commit 4cf50b3066
2 changed files with 101 additions and 36 deletions

View file

@@ -5,7 +5,8 @@ from apscheduler.triggers.interval import IntervalTrigger
from utils import pg_client
from decouple import config, Choices
import asyncio
from time import time
from time import time, sleep
import logging
DATABASE = config('CLOUD_SERVICE')
@@ -28,62 +29,102 @@ else:
cluster_info = dict()
for _d in ci:
k,v = _d.split('=')
cluster_info[k]=v
pdredshift.connect_to_redshift(dbname=cluster_info['DBNAME'],
cluster_info[k] = v
class RDSHFT:
def __init__(self):
self.pdredshift = pdredshift
self.pdredshift.connect_to_redshift(dbname=cluster_info['DBNAME'],
host=cluster_info['HOST'],
port=cluster_info['PORT'],
user=cluster_info['USER'],
password=cluster_info['PASSWORD'],
sslmode=sslmode)
def restart(self):
self.close()
self.__init__()
def redshift_to_pandas(self, query):
return self.pdredshift.redshift_to_pandas(query)
def exec_commit(self, base_query):
try:
self.pdredshift.exec_commit(base_query)
except Exception as e:
logging.warning('[FILL Exception] %s', repr(e))
self.pdredshift.connect.rollback()
raise
def close(self):
self.pdredshift.close_up_shop()
api = RDSHFT()
def try_method(f, params, on_exception=None, _try=0):
try:
res = f(params)
return res
except Exception as e:
if _try > 3:
if on_exception is None:
return
on_exception.close()
else:
logging.warning('[FILL Exception] %s, retrying...', repr(e))
sleep(1)
return try_method(f=f, params=params, on_exception=on_exception, _try=_try+1)
return
async def main():
limit = config('FILL_QUERY_LIMIT', default=100, cast=int)
t = time()
query = "SELECT sessionid FROM {table} WHERE user_id = 'NULL' LIMIT {limit}"
try:
res = pdredshift.redshift_to_pandas(query.format(table=table, limit=limit))
except Exception as e:
print('[FILL Exception]',repr(e))
res = list()
query = "SELECT sessionid FROM {table} WHERE user_id = 'NULL' ORDER BY session_start_timestamp ASC LIMIT {limit}"
res = api.redshift_to_pandas(query.format(table=table, limit=limit))
if res is None:
logging.info('[FILL INFO] response is None')
return
elif len(res) == 0:
logging.info('[FILL INFO] zero length response')
return
# logging.info(f'[FILL INFO] {len(res)} length response')
sessionids = list(map(lambda k: str(k), res['sessionid']))
with pg_client.PostgresClient() as conn:
conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
session_id_list = ','.join(sessionids))
session_id_list=','.join(sessionids))
)
pg_res = conn.fetchall()
logging.info(f'response from pg, length {len(pg_res)}')
df = pd.DataFrame(pg_res)
df.dropna(inplace=True)
df.fillna('NN', inplace=True)
df = df.groupby('user_id').agg({'session_id': lambda x: list(x)})
base_query = "UPDATE {table} SET user_id = CASE".format(table=table)
template = "\nWHEN sessionid IN ({session_ids}) THEN '{user_id}'"
all_ids = list()
# logging.info(f'[FILL INFO] {pg_res[:5]}')
for i in range(len(df)):
user = df.iloc[i].name
if user == '' or user == 'None' or user == 'NULL':
continue
aux = [str(sess) for sess in df.iloc[i].session_id]
aux = [str(sess) for sess in df.iloc[i].session_id if sess != 'NN']
all_ids += aux
if len(aux) == 0:
continue
base_query += template.format(user_id=user, session_ids=','.join(aux))
base_query += f"\nEND WHERE sessionid IN ({','.join(all_ids)})"
if len(all_ids) == 0:
logging.info('[FILL INFO] No ids obtained')
return
try:
pdredshift.exec_commit(base_query)
except Exception as e:
print('[FILL Exception]',repr(e))
print(f'[FILL-INFO] {time()-t} - for {len(sessionids)} elements')
# logging.info(f'[FILL INFO] {base_query}')
api.exec_commit(base_query)
logging.info(f'[FILL-INFO] {time()-t} - for {len(sessionids)} elements')
cron_jobs = [
{"func": main, "trigger": IntervalTrigger(seconds=15), "misfire_grace_time": 60, "max_instances": 1},
{"func": main, "trigger": IntervalTrigger(seconds=config('REPLACE_INTERVAL_USERID', default=60, cast=int)), "misfire_grace_time": 60, "max_instances": 1},
]
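
The fill job above collapses many per-user updates into one batched statement, which is what "reduced number of queries to redshift" refers to. A minimal sketch of that idea, where build_fill_query, the table name, and the user-to-sessions mapping are illustrative stand-ins rather than repo code:

    # Build a single UPDATE ... CASE statement instead of one UPDATE per user.
    def build_fill_query(table, user_sessions):
        base_query = "UPDATE {table} SET user_id = CASE".format(table=table)
        template = "\nWHEN sessionid IN ({session_ids}) THEN '{user_id}'"
        all_ids = []
        for user_id, session_ids in user_sessions.items():
            # Skip placeholder user ids, mirroring the checks in main()
            if user_id in ('', 'None', 'NULL'):
                continue
            ids = [str(s) for s in session_ids if s != 'NN']
            if not ids:
                continue
            all_ids += ids
            base_query += template.format(user_id=user_id, session_ids=','.join(ids))
        if not all_ids:
            return None  # nothing to update, so no query is sent
        return base_query + "\nEND WHERE sessionid IN ({})".format(','.join(all_ids))

    # Example: two users and three sessions become a single Redshift round trip.
    print(build_fill_query('connector_user_events', {'u-1': [101, 102], 'u-2': [103]}))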

View file

@@ -308,6 +308,10 @@ class WorkerPool:
self.pointer = 0
self.n_workers = n_workers
self.project_filter_class = ProjectFilter(project_filter)
self.sessions_update_batch = dict()
self.sessions_insert_batch = dict()
self.events_batch = list()
self.n_of_loops = config('LOOPS_BEFORE_UPLOAD', default=4, cast=int)
def get_worker(self, session_id: int) -> int:
if session_id in self.assigned_worker.keys():
@@ -319,9 +323,6 @@
return worker_id
def _pool_response_handler(self, pool_results):
events_batch = list()
sessions_update_batch = list()
sessions_insert_batch = list()
count = 0
for js_response in pool_results:
flag = js_response.pop('flag')
@@ -329,19 +330,21 @@
worker_events, worker_memory, end_sessions = js_response['value']
if worker_memory is None:
continue
events_batch += worker_events
self.events_batch += worker_events
for session_id in worker_memory.keys():
self.sessions[session_id] = dict_to_session(worker_memory[session_id])
self.project_filter_class.sessions_lifespan.add(session_id)
for session_id in end_sessions:
if self.sessions[session_id].session_start_timestamp:
old_status = self.project_filter_class.sessions_lifespan.close(session_id)
if old_status == 'UPDATE':
sessions_update_batch.append(deepcopy(self.sessions[session_id]))
if (old_status == 'UPDATE' or old_status == 'CLOSE') and session_id not in self.sessions_insert_batch.keys():
self.sessions_update_batch[session_id] = deepcopy(self.sessions[session_id])
elif (old_status == 'UPDATE' or old_status == 'CLOSE') and session_id in self.sessions_insert_batch.keys():
self.sessions_insert_batch[session_id] = deepcopy(self.sessions[session_id])
elif old_status == 'OPEN':
sessions_insert_batch.append(deepcopy(self.sessions[session_id]))
self.sessions_insert_batch[session_id] = deepcopy(self.sessions[session_id])
else:
print('[WORKER-WARN] Closed session should not be closed again')
print(f'[WORKER Exception] Unknown session status: {old_status}')
elif flag == 'reader':
count += 1
if count > 1:
@@ -360,7 +363,7 @@ class WorkerPool:
del self.assigned_worker[sess_id]
except KeyError:
...
return events_batch, sessions_insert_batch, sessions_update_batch, session_ids, messages
return session_ids, messages
def run_workers(self, database_api):
global sessions_table_name, table_name, EVENT_TYPE
@@ -371,7 +374,9 @@ class WorkerPool:
'project_filter': self.project_filter_class}
kafka_reader_process = Process(target=read_from_kafka, args=(reader_conn, kafka_task_params))
kafka_reader_process.start()
current_loop_number = 0
while signal_handler.KEEP_PROCESSING:
current_loop_number = (current_loop_number + 1) % self.n_of_loops
# Setup of parameters for workers
if not kafka_reader_process.is_alive():
print('[WORKER-INFO] Restarting reader task')
@@ -394,8 +399,6 @@
# Hand tasks to workers
async_results = list()
# for params in kafka_task_params:
# async_results.append(self.pool.apply_async(work_assigner, args=[params]))
for params in decoding_params:
if params['message']:
async_results.append(self.pool.apply_async(work_assigner, args=[params]))
@@ -406,10 +409,14 @@
except TimeoutError as e:
print('[WORKER-TimeoutError] Decoding of messages is taking longer than expected')
raise e
events_batch, sessions_insert_batch, sessions_update_batch, session_ids, messages = self._pool_response_handler(
session_ids, messages = self._pool_response_handler(
pool_results=results)
insertBatch(events_batch, sessions_insert_batch, sessions_update_batch, database_api, sessions_table_name,
table_name, EVENT_TYPE)
if current_loop_number == 0:
insertBatch(self.events_batch, self.sessions_insert_batch.values(), self.sessions_update_batch.values(),
database_api, sessions_table_name, table_name, EVENT_TYPE)
self.sessions_update_batch = dict()
self.sessions_insert_batch = dict()
self.events_batch = list()
self.save_snapshot(database_api)
main_conn.send('CONTINUE')
print('[WORKER-INFO] Sending close signal')
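
The main loop above now accumulates batches on the instance and uploads them only once every LOOPS_BEFORE_UPLOAD iterations instead of on every pass. A small sketch of that flush-every-N-loops pattern, with BatchedUploader and flush_fn as illustrative names rather than the connector's API:

    # Accumulate results across loop iterations and flush every n_of_loops passes.
    class BatchedUploader:
        def __init__(self, flush_fn, loops_before_upload=4):
            self.flush_fn = flush_fn
            self.n_of_loops = loops_before_upload
            self.current_loop_number = 0
            self.events_batch = []

        def add(self, events):
            self.events_batch += events
            self.current_loop_number = (self.current_loop_number + 1) % self.n_of_loops
            if self.current_loop_number == 0:
                self.flush_fn(self.events_batch)  # one bulk insert instead of one per loop
                self.events_batch = []

    uploader = BatchedUploader(flush_fn=print, loops_before_upload=4)
    for i in range(8):
        uploader.add(['event-{}'.format(i)])  # flushes after the 4th and 8th call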
@@ -432,7 +439,21 @@ class WorkerPool:
for sessionId, session_dict in checkpoint['sessions']:
self.sessions[sessionId] = dict_to_session(session_dict)
self.project_filter_class.sessions_lifespan.session_project = checkpoint['cached_sessions']
elif checkpoint['version'] == 'v1.1':
for sessionId, session_dict in checkpoint['sessions']:
self.sessions[sessionId] = dict_to_session(session_dict)
self.project_filter_class.sessions_lifespan.session_project = checkpoint['cached_sessions']
for sessionId in checkpoint['sessions_update_batch']:
try:
self.sessions_update_batch[sessionId] = self.sessions[sessionId]
except Exception:
continue
for sessionId in checkpoint['sessions_insert_batch']:
try:
self.sessions_insert_batch[sessionId] = self.sessions[sessionId]
except Exception:
continue
self.events_batch = [dict_to_event(event) for event in checkpoint['events_batch']]
else:
raise Exception('Error in version of snapshot')
@@ -446,8 +467,11 @@ class WorkerPool:
for sessionId, session in self.sessions.items():
session_snapshot.append([sessionId, session_to_dict(session)])
checkpoint = {
'version': 'v1.0',
'version': 'v1.1',
'sessions': session_snapshot,
'cached_sessions': self.project_filter_class.sessions_lifespan.session_project,
'sessions_update_batch': list(self.sessions_update_batch.keys()),
'sessions_insert_batch': list(self.sessions_insert_batch.keys()),
'events_batch': [event_to_dict(event) for event in self.events_batch]
}
database_api.save_binary(binary_data=json.dumps(checkpoint).encode('utf-8'), name='checkpoint')
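
The v1.1 checkpoint persists only the keys of the pending insert/update batches plus the serialized events, and rebuilds the session objects from the sessions snapshot on restore. A rough sketch of the payload shape written by save_snapshot, using placeholder values for illustration:

    import json

    # Field names follow the diff above; the values here are placeholder data.
    checkpoint = {
        'version': 'v1.1',
        'sessions': [[42, {'session_start_timestamp': 1688740000}]],
        'cached_sessions': {'42': 1},
        'sessions_update_batch': [42],  # keys only; sessions are re-hydrated on load
        'sessions_insert_batch': [],
        'events_batch': [],
    }
    binary_data = json.dumps(checkpoint).encode('utf-8')

    restored = json.loads(binary_data)
    assert restored['version'] == 'v1.1'
    # Every pending batch key should resolve to a snapshotted session.
    assert set(restored['sessions_update_batch']) <= {s[0] for s in restored['sessions']}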