# openreplay/ee/connectors/utils/worker.py
from multiprocessing import Pool, Process, Pipe, TimeoutError
from multiprocessing.connection import Connection
from db.api import DBConnection
from msgcodec import MessageCodec
from messages import SessionEnd
from utils.uploader import insertBatch
from utils.cache import CachedSessions
from db.models import DetailedEvent, Event, Session, events_detailed_table_name, events_table_name, sessions_table_name
from handler import handle_normal_message, handle_message, handle_session
from datetime import datetime
from decouple import config
from utils import pg_client
from utils.signal_handler import signal_handler
from copy import deepcopy
from confluent_kafka import Consumer
import pandas as pd
from time import time
import logging
import json
import asyncio
EVENT_TYPE = config('EVENT_TYPE')
DATABASE = config('CLOUD_SERVICE')
UPLOAD_RATE = config('upload_rate', default=30, cast=int)
if EVENT_TYPE == 'detailed':
    table_name = events_detailed_table_name
elif EVENT_TYPE == 'normal':
    table_name = events_table_name
else:
    raise ValueError(f"Unsupported EVENT_TYPE '{EVENT_TYPE}': expected 'normal' or 'detailed'")
TOPICS = config("TOPICS", default="saas-raw").split(',')
ssl_protocol = config('KAFKA_USE_SSL', default=True, cast=bool)
consumer_settings = {
"bootstrap.servers": config('KAFKA_SERVERS'),
"group.id": f"connector_{DATABASE}",
"auto.offset.reset": "earliest",
"enable.auto.commit": False
}
if ssl_protocol:
consumer_settings['security.protocol'] = 'SSL'
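# Numeric message-type ids (as understood by MessageCodec) that contribute to session
# state vs. individual event rows. Their union is handed to the codec as allowed_messages,
# which presumably skips every other message type; the meaning of each id is defined in
# the msgcodec/messages packages.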
session_messages = [1, 25, 28, 29, 30, 31, 32, 54, 56, 62, 69, 78, 125, 126]
if EVENT_TYPE == 'normal':
events_messages = [21, 22, 25, 27, 64, 69, 78, 125]
elif EVENT_TYPE == 'detailed':
events_messages = [1, 4, 21, 22, 25, 27, 31, 32, 39, 48, 59, 64, 69, 78, 125, 126]
allowed_messages = list(set(session_messages + events_messages))
codec = MessageCodec(allowed_messages)
max_kafka_read = config('MAX_KAFKA_READ', default=60000, cast=int)
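# Kafka consumer helpers. The reader process builds its own Consumer (consumer group
# "connector_<CLOUD_SERVICE>") and tears it down when the parent asks it to shut down.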
def init_consumer():
global DATABASE, consumer_settings
consumer = Consumer(consumer_settings)
consumer.subscribe(TOPICS)
return consumer
def close_consumer(consumer):
consumer.unsubscribe()
consumer.close()
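# Session/Event ORM objects are flattened to plain dicts before they are shipped between
# processes and rebuilt on the other side; dropping '_sa_instance_state' removes the
# SQLAlchemy bookkeeping attribute, which is presumably neither needed nor picklable.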
def session_to_dict(sess: Session):
_dict = sess.__dict__
try:
del _dict['_sa_instance_state']
except KeyError:
...
return _dict
def dict_to_session(session_dict: dict):
n = Session()
n.__dict__ |= session_dict
return n
def event_to_dict(event: Event | DetailedEvent):
_dict = event.__dict__
try:
del _dict['_sa_instance_state']
except KeyError:
...
return _dict
def dict_to_event(event_dict: dict):
global EVENT_TYPE
if EVENT_TYPE == 'detailed':
n = DetailedEvent()
else:
n = Event()
n.__dict__ |= event_dict
return n
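# ProjectFilter decides which session ids the connector should keep. With an empty
# project_filter everything passes; otherwise session ids are resolved to project ids
# through Postgres, and negative results are cached (and expired after
# MAX_UNWANTED_SESSION_LIFE seconds) so the lookup is not repeated for every message.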
class ProjectFilter:
def __init__(self, project_filter):
self.max_lifespan = config('MAX_UNWANTED_SESSION_LIFE', default=7800, cast=int)
self.project_filter = project_filter
self.sessions_lifespan = CachedSessions()
self.non_valid_sessions_cache = dict()
def is_valid(self, sessionId: int):
if len(self.project_filter) == 0:
return True
elif sessionId in self.sessions_lifespan.session_project.keys():
return True
elif sessionId in self.non_valid_sessions_cache.keys():
return False
else:
projectId = project_from_session(sessionId)
if projectId not in self.project_filter:
self.non_valid_sessions_cache[sessionId] = int(datetime.now().timestamp())
return False
else:
return True
def already_checked(self, sessionId):
if len(self.project_filter) == 0:
return True, True
elif sessionId in self.sessions_lifespan.session_project.keys():
return True, True
elif sessionId in self.non_valid_sessions_cache.keys():
return True, False
else:
return False, None
def are_valid(self, sessionIds: list[int]):
valid_sessions = list()
if len(self.project_filter) == 0:
return sessionIds
projects_session = project_from_sessions(list(set(sessionIds)))
current_datetime = int(datetime.now().timestamp())
for projectId, sessionId in projects_session:
if projectId not in self.project_filter:
self.non_valid_sessions_cache[sessionId] = current_datetime
else:
valid_sessions.append(sessionId)
return valid_sessions
def handle_clean(self):
if len(self.project_filter) == 0:
return
else:
current_timestamp = datetime.now().timestamp()
self.non_valid_sessions_cache = {sessionId: start_timestamp for sessionId, start_timestamp in
self.non_valid_sessions_cache.items() if
self.max_lifespan > current_timestamp - start_timestamp}
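# Reader task, run as a separate Process. It polls Kafka for up to UPLOAD_RATE seconds
# (or max_kafka_read messages), buffers messages whose project is not yet classified,
# resolves them in bulk, then sends (updated negative cache, session ids, raw payloads)
# to the parent over the Pipe. Kafka offsets are committed only after the parent answers
# 'CONTINUE'; a 'CLOSE' answer makes it shut the consumer down and exit.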
def read_from_kafka(pipe: Connection, params: dict):
global UPLOAD_RATE, max_kafka_read
# try:
# asyncio.run(pg_client.init())
kafka_consumer = init_consumer()
project_filter = params['project_filter']
capture_messages = list()
capture_sessions = list()
while True:
to_decode = list()
sessionIds = list()
start_time = datetime.now().timestamp()
broken_batchs = 0
n_messages = 0
while datetime.now().timestamp() - start_time < UPLOAD_RATE and max_kafka_read > n_messages:
try:
msg = kafka_consumer.poll(5.0)
            except Exception as e:
                print('[WORKER Exception]', e)
                continue
if msg is None:
continue
n_messages += 1
try:
sessionId = codec.decode_key(msg.key())
except Exception:
broken_batchs += 1
continue
checked, is_valid = project_filter.already_checked(sessionId)
if not checked:
capture_sessions.append(sessionId)
capture_messages.append(msg.value())
elif is_valid:
to_decode.append(msg.value())
sessionIds.append(sessionId)
# if project_filter.is_valid(sessionId):
# to_decode.append(msg.value())
# sessionIds.append(sessionId)
valid_sessions = project_filter.are_valid(list(set(capture_sessions)))
while capture_sessions:
sessId = capture_sessions.pop()
msg = capture_messages.pop()
if sessId in valid_sessions:
sessionIds.append(sessId)
to_decode.append(msg)
if n_messages != 0:
            print(f'[WORKER INFO-bg] Found {broken_batchs} broken batches out of {n_messages} '
                  f'read messages ({100 * broken_batchs / n_messages:.2f}%)')
else:
print('[WORKER WARN-bg] No messages read')
non_valid_updated = project_filter.non_valid_sessions_cache
pipe.send((non_valid_updated, sessionIds, to_decode))
continue_signal = pipe.recv()
if continue_signal == 'CLOSE':
print('[WORKER SHUTDOWN-reader] Reader shutting down')
break
kafka_consumer.commit()
print('[WORKER INFO] Closing consumer')
close_consumer(kafka_consumer)
print('[WORKER INFO] Closing pg connection')
# asyncio.run(pg_client.terminate())
print('[WORKER INFO] Successfully closed reader task')
# except Exception as e:
# print('[WARN]', repr(e))
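# Attach upload metadata to a freshly decoded event (owning session id, arrival time in
# ms, position inside the current batch) and append it to the batch.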
def into_batch(batch: list[Event | DetailedEvent], session_id: int, n: Session):
n.sessionid = session_id
n.received_at = int(datetime.now().timestamp() * 1000)
n.batch_order_number = len(batch)
batch.append(n)
return batch
def project_from_session(sessionId: int):
"""Search projectId of requested sessionId in PG table sessions"""
with pg_client.PostgresClient().get_live_session() as conn:
cur = conn.execute(
conn.mogrify("SELECT project_id FROM sessions WHERE session_id=%(sessionId)s LIMIT 1",
{'sessionId': sessionId})
)
res = cur.fetchone()
if res is None:
print(f'[WORKER WARN] sessionid {sessionId} not found in sessions table')
return None
return res['project_id']
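# Same lookup in bulk: session ids are resolved in chunks of up to 1000 per query to keep
# the IN (...) clause bounded.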
def project_from_sessions(sessionIds: list[int]):
"""Search projectId of requested sessionId in PG table sessions"""
response = list()
while sessionIds:
sessIds = sessionIds[-1000:]
try:
with pg_client.PostgresClient().get_live_session() as conn:
cur = conn.execute(
"SELECT session_id, project_id FROM sessions WHERE session_id IN ({sessionIds})".format(
sessionIds=','.join([str(sessId) for sessId in sessIds])
)
)
res = cur.fetchall()
except Exception as e:
print('[WORKER project_from_sessions]', repr(e))
raise e
if res is None:
print(f'[WORKER WARN] sessionids {",".join([str(sessId) for sessId in sessIds])} not found in sessions table')
else:
response += res
sessionIds = sessionIds[:-1000]
if not response:
return []
return [(e['project_id'], e['session_id']) for e in response]
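# Decoder task, run inside the worker pool. Decodes a list of (session_id, raw payload)
# pairs into event rows and per-session state, and reports which sessions saw a
# SessionEnd message. Session state crosses the process boundary as plain dicts
# (see session_to_dict / dict_to_session above).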
def decode_message(params: dict):
global codec, session_messages, events_messages, EVENT_TYPE
if len(params['message']) == 0:
return list(), None, list()
memory = {sessId: dict_to_session(sessObj) for sessId, sessObj in params['memory'].items()}
events_worker_batch = list()
sessionid_ended = list()
for session_id, encoded_message in params['message']:
messages = codec.decode_detailed(encoded_message)
if messages is None:
continue
for message in messages:
if message is None:
continue
if message.__id__ in events_messages and EVENT_TYPE != 'detailed':
n = handle_normal_message(message)
if n:
events_worker_batch = into_batch(batch=events_worker_batch, session_id=session_id, n=n)
elif message.__id__ in events_messages and EVENT_TYPE == 'detailed':
n = handle_message(message)
if n:
events_worker_batch = into_batch(batch=events_worker_batch, session_id=session_id, n=n)
if message.__id__ in session_messages:
try:
memory[session_id] = handle_session(memory[session_id], message)
except KeyError:
memory[session_id] = handle_session(None, message)
memory[session_id].sessionid = session_id
if isinstance(message, SessionEnd):
sessionid_ended.append(session_id)
memory = {sessId: session_to_dict(sessObj) for sessId, sessObj in memory.items()}
return events_worker_batch, memory, sessionid_ended
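# Backfill task for Redshift: picks sessions whose user_id was stored as the literal
# string 'NULL', looks the real user_id up in Postgres, and patches them with a single
# CASE-based UPDATE. Scheduled as a pool task once per main loop iteration.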
def fix_missing_redshift():
DATABASE = config('CLOUD_SERVICE')
table = sessions_table_name
database_api = DBConnection(DATABASE)
limit = config('FILL_QUERY_LIMIT', default=100, cast=int)
t = time()
query = "SELECT sessionid FROM {table} WHERE user_id = 'NULL' ORDER BY session_start_timestamp ASC LIMIT {limit}"
try:
res = database_api.pdredshift.redshift_to_pandas(query.format(table=table, limit=limit))
except Exception as e:
# logging.error(f'[ERROR] Error while selecting query. {repr(e)}')
database_api.close()
return
if res is None:
logging.info('[FILL INFO] response is None')
database_api.close()
return
elif len(res) == 0:
logging.info('[FILL INFO] zero length response')
database_api.close()
return
# logging.info(f'[FILL INFO] {len(res)} length response')
sessionids = list(map(lambda k: str(k), res['sessionid']))
# asyncio.run(pg_client.init())
try:
with pg_client.PostgresClient().get_live_session() as conn:
cur = conn.execute('SELECT session_id, user_id FROM sessions WHERE session_id IN ({session_id_list})'.format(
session_id_list=','.join(sessionids))
)
pg_res = cur.fetchall()
except Exception as e:
#logging.error(f'[ERROR] Error while selecting from pg: {repr(e)}')
# asyncio.run(pg_client.terminate())
return
logging.info(f'response from pg, length {len(pg_res)}')
df = pd.DataFrame(pg_res)
df.fillna('NN', inplace=True)
df = df.groupby('user_id').agg({'session_id': lambda x: list(x)})
base_query = "UPDATE {table} SET user_id = CASE".format(table=table)
template = "\nWHEN sessionid IN ({session_ids}) THEN '{user_id}'"
all_ids = list()
# logging.info(f'[FILL INFO] {pg_res[:5]}')
for i in range(len(df)):
user = df.iloc[i].name.replace("'", "''")
aux = [str(sess) for sess in df.iloc[i].session_id if sess != 'NN']
all_ids += aux
if len(aux) == 0:
continue
base_query += template.format(user_id=user, session_ids=','.join(aux))
base_query += f"\nEND WHERE sessionid IN ({','.join(all_ids)})"
if len(all_ids) == 0:
logging.info('[FILL INFO] No ids obtained')
database_api.close()
# asyncio.run(pg_client.terminate())
return
# logging.info(f'[FILL INFO] {base_query}')
try:
database_api.pdredshift.exec_commit(base_query)
except Exception as e:
logging.error(f'[ERROR] Error while executing query. {repr(e)}')
logging.error(f'[ERROR INFO] query: {base_query}')
database_api.close()
# asyncio.run(pg_client.terminate())
return
logging.info(f'[FILL-INFO] {time() - t} - for {len(sessionids)} elements')
database_api.close()
# asyncio.run(pg_client.terminate())
return
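# Entry point executed by pool workers: dispatches a task dict to the decoder or to the
# Redshift backfill depending on its 'flag'.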
def work_assigner(params):
flag = params.pop('flag')
if flag == 'decoder':
return {'flag': 'decoder', 'value': decode_message(params)}
elif flag == 'fix':
return {'flag': 'fix', 'value': fix_missing_redshift()}
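# WorkerPool ties everything together: a multiprocessing Pool of decoder workers, one
# reader Process fed and acknowledged over a Pipe, per-session affinity so a session is
# always decoded by the same worker, batched uploads every LOOPS_BEFORE_UPLOAD
# iterations, and JSON snapshots ("checkpoints") for crash recovery.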
class WorkerPool:
def __init__(self, n_workers: int, project_filter: list[int]):
self.pool = Pool(n_workers)
self.sessions = dict()
self.assigned_worker = dict()
self.pointer = 0
self.n_workers = n_workers
self.project_filter_class = ProjectFilter(project_filter)
self.sessions_update_batch = dict()
self.sessions_insert_batch = dict()
self.events_batch = list()
self.n_of_loops = config('LOOPS_BEFORE_UPLOAD', default=4, cast=int)
def get_worker(self, session_id: int) -> int:
if session_id in self.assigned_worker.keys():
worker_id = self.assigned_worker[session_id]
else:
worker_id = self.pointer
self.pointer = (self.pointer + 1) % self.n_workers
self.assigned_worker[session_id] = worker_id
return worker_id
def _pool_response_handler(self, pool_results):
count = 0
for js_response in pool_results:
flag = js_response.pop('flag')
if flag == 'decoder':
worker_events, worker_memory, end_sessions = js_response['value']
if worker_memory is None:
continue
self.events_batch += worker_events
for session_id in worker_memory.keys():
self.sessions[session_id] = dict_to_session(worker_memory[session_id])
self.project_filter_class.sessions_lifespan.add(session_id)
for session_id in end_sessions:
if self.sessions[session_id].session_start_timestamp:
old_status = self.project_filter_class.sessions_lifespan.close(session_id)
if (old_status == 'UPDATE' or old_status == 'CLOSE') and session_id not in self.sessions_insert_batch.keys():
self.sessions_update_batch[session_id] = deepcopy(self.sessions[session_id])
elif (old_status == 'UPDATE' or old_status == 'CLOSE') and session_id in self.sessions_insert_batch.keys():
self.sessions_insert_batch[session_id] = deepcopy(self.sessions[session_id])
elif old_status == 'OPEN':
self.sessions_insert_batch[session_id] = deepcopy(self.sessions[session_id])
else:
print(f'[WORKER Exception] Unknown session status: {old_status}')
elif flag == 'reader':
count += 1
if count > 1:
raise Exception('Pool only accepts one reader task')
non_valid_updated, session_ids, messages = js_response['value']
self.project_filter_class.non_valid_sessions_cache = non_valid_updated
self.project_filter_class.handle_clean()
sessions_to_delete = self.project_filter_class.sessions_lifespan.clear_sessions()
for sess_id in sessions_to_delete:
try:
del self.sessions[sess_id]
except KeyError:
...
try:
del self.assigned_worker[sess_id]
except KeyError:
...
return session_ids, messages
def run_workers(self, database_api):
global sessions_table_name, table_name, EVENT_TYPE
session_ids = list()
messages = list()
main_conn, reader_conn = Pipe()
kafka_task_params = {'flag': 'reader',
'project_filter': self.project_filter_class}
kafka_reader_process = Process(target=read_from_kafka, args=(reader_conn, kafka_task_params))
kafka_reader_process.start()
current_loop_number = 0
n_kafka_restarts = 0
while signal_handler.KEEP_PROCESSING:
current_loop_number = (current_loop_number + 1) % self.n_of_loops
# Setup of parameters for workers
if not kafka_reader_process.is_alive():
if n_kafka_restarts > 3:
break
print('[WORKER-INFO] Restarting reader task')
del kafka_reader_process
kafka_reader_process = Process(target=read_from_kafka, args=(reader_conn, kafka_task_params))
kafka_reader_process.start()
n_kafka_restarts += 1
decoding_params = [{'flag': 'decoder',
'message': list(),
'memory': dict()} for _ in range(self.n_workers)
]
for i in range(len(session_ids)):
session_id = session_ids[i]
worker_id = self.get_worker(session_id)
decoding_params[worker_id]['message'].append([session_id, messages[i]])
if session_id not in decoding_params[worker_id]['memory'].keys():
try:
decoding_params[worker_id]['memory'][session_id] = session_to_dict(self.sessions[session_id])
except KeyError:
...
# Hand tasks to workers
async_results = list()
for params in decoding_params:
if params['message']:
async_results.append(self.pool.apply_async(work_assigner, args=[params]))
results = [{'flag': 'reader', 'value': main_conn.recv()}]
async_results.append(self.pool.apply_async(work_assigner, args=[{'flag': 'fix'}]))
for async_result in async_results:
try:
results.append(async_result.get(timeout=32 * UPLOAD_RATE))
except TimeoutError as e:
print('[WORKER-TimeoutError] Decoding of messages is taking longer than expected')
raise e
except Exception as e:
print(f'[Exception] {e}')
self.sessions_update_batch = dict()
self.sessions_insert_batch = dict()
self.events_batch = list()
continue
session_ids, messages = self._pool_response_handler(
pool_results=results)
if current_loop_number == 0:
insertBatch(self.events_batch, self.sessions_insert_batch.values(), self.sessions_update_batch.values(),
database_api, sessions_table_name, table_name, EVENT_TYPE)
self.sessions_update_batch = dict()
self.sessions_insert_batch = dict()
self.events_batch = list()
self.save_snapshot(database_api)
main_conn.send('CONTINUE')
print('[WORKER-INFO] Sending close signal')
main_conn.send('CLOSE')
self.terminate(database_api)
kafka_reader_process.terminate()
print('[WORKER-SHUTDOWN] Process terminated')
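    # Restore state from the last snapshot. Three formats are handled: the legacy
    # unversioned layout (caches only), 'v1.0' (sessions + caches) and 'v1.1'
    # (sessions, caches and the pending insert/update/event batches).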
def load_checkpoint(self, database_api):
file = database_api.load_binary(name='checkpoint')
checkpoint = json.loads(file.getvalue().decode('utf-8'))
file.close()
if 'version' not in checkpoint.keys():
sessions_cache_list = checkpoint['cache']
reload_default_time = datetime.now().timestamp()
self.project_filter_class.non_valid_sessions_cache = {int(sessId): reload_default_time for sessId, value in
sessions_cache_list.items() if not value[1]}
self.project_filter_class.sessions_lifespan.session_project = checkpoint['cached_sessions']
elif checkpoint['version'] == 'v1.0':
for sessionId, session_dict in checkpoint['sessions']:
self.sessions[sessionId] = dict_to_session(session_dict)
self.project_filter_class.sessions_lifespan.session_project = checkpoint['cached_sessions']
elif checkpoint['version'] == 'v1.1':
for sessionId, session_dict in checkpoint['sessions']:
self.sessions[sessionId] = dict_to_session(session_dict)
self.project_filter_class.sessions_lifespan.session_project = checkpoint['cached_sessions']
for sessionId in checkpoint['sessions_update_batch']:
try:
self.sessions_update_batch[sessionId] = self.sessions[sessionId]
except Exception:
continue
for sessionId in checkpoint['sessions_insert_batch']:
try:
self.sessions_insert_batch[sessionId] = self.sessions[sessionId]
except Exception:
continue
self.events_batch = [dict_to_event(event) for event in checkpoint['events_batch']]
else:
            raise Exception(f"Unknown snapshot version: {checkpoint['version']}")
def terminate(self, database_api):
self.pool.close()
self.save_snapshot(database_api)
database_api.close()
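    # Persist sessions, the session-project cache and the pending batches as a JSON
    # checkpoint so a restarted connector can resume where it left off.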
def save_snapshot(self, database_api):
session_snapshot = list()
for sessionId, session in self.sessions.items():
session_snapshot.append([sessionId, session_to_dict(session)])
checkpoint = {
'version': 'v1.1',
'sessions': session_snapshot,
'cached_sessions': self.project_filter_class.sessions_lifespan.session_project,
'sessions_update_batch': list(self.sessions_update_batch.keys()),
'sessions_insert_batch': list(self.sessions_insert_batch.keys()),
'events_batch': [event_to_dict(event) for event in self.events_batch]
}
database_api.save_binary(binary_data=json.dumps(checkpoint).encode('utf-8'), name='checkpoint')
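# A minimal driving sketch, assuming the connector's real entrypoint (which lives in a
# different module) wires things up roughly like this; N_WORKERS is a hypothetical setting:
#
#   asyncio.run(pg_client.init())
#   database_api = DBConnection(DATABASE)
#   pool = WorkerPool(n_workers=config('N_WORKERS', default=4, cast=int),
#                     project_filter=[])
#   pool.load_checkpoint(database_api)   # only if a previous snapshot exists
#   pool.run_workers(database_api)       # blocks until signal_handler stops the loop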