diff --git a/ee/api/.gitignore b/ee/api/.gitignore index 7140a891d..80beeee41 100644 --- a/ee/api/.gitignore +++ b/ee/api/.gitignore @@ -283,4 +283,3 @@ Pipfile.lock /chalicelib/utils/contextual_validators.py /routers/subs/product_analytics.py /schemas/product_analytics.py -/ee/bin/* diff --git a/ee/api/Pipfile b/ee/api/Pipfile index cf41528a8..93bd8134a 100644 --- a/ee/api/Pipfile +++ b/ee/api/Pipfile @@ -26,6 +26,8 @@ xmlsec = "==1.3.14" python-multipart = "==0.0.20" redis = "==6.1.0" azure-storage-blob = "==12.25.1" +scim2-server = "*" +scim2-models = "*" [dev-packages] diff --git a/ee/api/app.py b/ee/api/app.py index a9d9c59cd..4d9be2fec 100644 --- a/ee/api/app.py +++ b/ee/api/app.py @@ -9,6 +9,7 @@ from decouple import config from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from fastapi.middleware.wsgi import WSGIMiddleware from psycopg import AsyncConnection from psycopg.rows import dict_row from starlette import status @@ -21,12 +22,20 @@ from chalicelib.utils import pg_client, ch_client from crons import core_crons, ee_crons, core_dynamic_crons from routers import core, core_dynamic from routers import ee -from routers.subs import insights, metrics, v1_api, health, usability_tests, spot, product_analytics +from routers.subs import ( + insights, + metrics, + v1_api, + health, + usability_tests, + spot, + product_analytics, +) from routers.subs import v1_api_ee if config("ENABLE_SSO", cast=bool, default=True): from routers import saml - from routers import scim + from routers.scim import api as scim loglevel = config("LOGLEVEL", default=logging.WARNING) print(f">Loglevel set to: {loglevel}") @@ -34,7 +43,6 @@ logging.basicConfig(level=loglevel) class ORPYAsyncConnection(AsyncConnection): - def __init__(self, *args, **kwargs): super().__init__(*args, row_factory=dict_row, **kwargs) @@ -43,7 +51,7 @@ class ORPYAsyncConnection(AsyncConnection): async def lifespan(app: FastAPI): # Startup logging.info(">>>>> starting up <<<<<") - ap_logger = logging.getLogger('apscheduler') + ap_logger = logging.getLogger("apscheduler") ap_logger.setLevel(loglevel) app.schedule = AsyncIOScheduler() @@ -53,12 +61,23 @@ async def lifespan(app: FastAPI): await events_queue.init() app.schedule.start() - for job in core_crons.cron_jobs + core_dynamic_crons.cron_jobs + traces.cron_jobs + ee_crons.ee_cron_jobs: + for job in ( + core_crons.cron_jobs + + core_dynamic_crons.cron_jobs + + traces.cron_jobs + + ee_crons.ee_cron_jobs + ): app.schedule.add_job(id=job["func"].__name__, **job) ap_logger.info(">Scheduled jobs:") for job in app.schedule.get_jobs(): - ap_logger.info({"Name": str(job.id), "Run Frequency": str(job.trigger), "Next Run": str(job.next_run_time)}) + ap_logger.info( + { + "Name": str(job.id), + "Run Frequency": str(job.trigger), + "Next Run": str(job.next_run_time), + } + ) database = { "host": config("pg_host", default="localhost"), @@ -69,9 +88,12 @@ async def lifespan(app: FastAPI): "application_name": "AIO" + config("APP_NAME", default="PY"), } - database = psycopg_pool.AsyncConnectionPool(kwargs=database, connection_class=ORPYAsyncConnection, - min_size=config("PG_AIO_MINCONN", cast=int, default=1), - max_size=config("PG_AIO_MAXCONN", cast=int, default=5), ) + database = psycopg_pool.AsyncConnectionPool( + kwargs=database, + connection_class=ORPYAsyncConnection, + min_size=config("PG_AIO_MINCONN", cast=int, default=1), + max_size=config("PG_AIO_MAXCONN", cast=int, default=5), + ) app.state.postgresql = 
database # App listening @@ -86,16 +108,24 @@ async def lifespan(app: FastAPI): await pg_client.terminate() -app = FastAPI(root_path=config("root_path", default="/api"), docs_url=config("docs_url", default=""), - redoc_url=config("redoc_url", default=""), lifespan=lifespan) +app = FastAPI( + root_path=config("root_path", default="/api"), + docs_url=config("docs_url", default=""), + redoc_url=config("redoc_url", default=""), + lifespan=lifespan, +) app.add_middleware(GZipMiddleware, minimum_size=1000) -@app.middleware('http') +@app.middleware("http") async def or_middleware(request: Request, call_next): from chalicelib.core import unlock + if not unlock.is_valid(): - return JSONResponse(content={"errors": ["expired license"]}, status_code=status.HTTP_403_FORBIDDEN) + return JSONResponse( + content={"errors": ["expired license"]}, + status_code=status.HTTP_403_FORBIDDEN, + ) if helper.TRACK_TIME: now = time.time() @@ -110,8 +140,10 @@ async def or_middleware(request: Request, call_next): now = time.time() - now if now > 2: now = round(now, 2) - logging.warning(f"Execution time: {now} s for {request.method}: {request.url.path}") - response.headers["x-robots-tag"] = 'noindex, nofollow' + logging.warning( + f"Execution time: {now} s for {request.method}: {request.url.path}" + ) + response.headers["x-robots-tag"] = "noindex, nofollow" return response @@ -162,3 +194,4 @@ if config("ENABLE_SSO", cast=bool, default=True): app.include_router(scim.public_app) app.include_router(scim.app) app.include_router(scim.app_apikey) + app.mount("/sso/scim/v2", WSGIMiddleware(scim.scim_app)) diff --git a/ee/api/chalicelib/core/roles.py b/ee/api/chalicelib/core/roles.py index 955c76af0..321ca1102 100644 --- a/ee/api/chalicelib/core/roles.py +++ b/ee/api/chalicelib/core/roles.py @@ -1,4 +1,3 @@ -import json from typing import Optional from fastapi import HTTPException, status @@ -10,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool: with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""SELECT EXISTS(SELECT 1 + query = cur.mogrify( + f"""SELECT EXISTS(SELECT 1 FROM public.roles WHERE tenant_id = %(tenant_id)s AND name ILIKE %(name)s AND deleted_at ISNULL {"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""", - {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id}) + {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id}, + ) cur.execute(query=query) row = cur.fetchone() return row["exists"] @@ -28,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema): if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists." 
+ ) if not data.all_projects and (data.projects is None or len(data.projects) == 0): return {"errors": ["must specify a project or all projects"]} if data.projects is not None and len(data.projects) > 0 and not data.all_projects: - data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) + data.projects = projects.is_authorized_batch( + project_ids=data.projects, tenant_id=tenant_id + ) with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.roles WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = TRUE LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is protected"]} - query = cur.mogrify("""UPDATE public.roles + query = cur.mogrify( + """UPDATE public.roles SET name= %(name)s, description= %(description)s, permissions= %(permissions)s, @@ -57,43 +65,36 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema): RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id) FROM roles_projects WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""", - {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()}) + {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()}, + ) cur.execute(query=query) row = cur.fetchone() row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) if not data.all_projects: d_projects = [i for i in row["projects"] if i not in data.projects] if len(d_projects) > 0: - query = cur.mogrify("""DELETE FROM roles_projects + query = cur.mogrify( + """DELETE FROM roles_projects WHERE role_id=%(role_id)s AND project_id IN %(project_ids)s""", - {"role_id": role_id, "project_ids": tuple(d_projects)}) + {"role_id": role_id, "project_ids": tuple(d_projects)}, + ) cur.execute(query=query) n_projects = [i for i in data.projects if i not in row["projects"]] if len(n_projects) > 0: - query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) + query = cur.mogrify( + f"""INSERT INTO roles_projects(role_id, project_id) VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""", - {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}}) + { + "role_id": role_id, + **{f"project_id_{i}": p for i, p in enumerate(n_projects)}, + }, + ) cur.execute(query=query) row["projects"] = data.projects return helper.dict_to_camel_case(row) -def update_group_name(tenant_id, group_id, name): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""UPDATE public.roles - SET name= %(name)s - WHERE roles.data->>'group_id' = %(group_id)s - AND tenant_id = %(tenant_id)s - AND deleted_at ISNULL - AND protected = FALSE - RETURNING *;""", - {"tenant_id": tenant_id, "group_id": group_id, "name": name }) - cur.execute(query=query) - row = cur.fetchone() - - return helper.dict_to_camel_case(row) - def create(tenant_id, user_id, data: schemas.RolePayloadSchema): admin = users.get(user_id=user_id, tenant_id=tenant_id) @@ -102,57 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema): return {"errors": ["unauthorized"]} if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists." 
+ ) if not data.all_projects and (data.projects is None or len(data.projects) == 0): return {"errors": ["must specify a project or all projects"]} if data.projects is not None and len(data.projects) > 0 and not data.all_projects: - data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) + data.projects = projects.is_authorized_batch( + project_ids=data.projects, tenant_id=tenant_id + ) with pg_client.PostgresClient() as cur: - query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects) + query = cur.mogrify( + """INSERT INTO roles(tenant_id, name, description, permissions, all_projects) VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s) RETURNING *;""", - {"tenant_id": tenant_id, "name": data.name, "description": data.description, - "permissions": data.permissions, "all_projects": data.all_projects}) + { + "tenant_id": tenant_id, + "name": data.name, + "description": data.description, + "permissions": data.permissions, + "all_projects": data.all_projects, + }, + ) cur.execute(query=query) row = cur.fetchone() row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) row["projects"] = [] if not data.all_projects: role_id = row["role_id"] - query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) + query = cur.mogrify( + f"""INSERT INTO roles_projects(role_id, project_id) VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))} RETURNING project_id;""", - {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}}) - cur.execute(query=query) - row["projects"] = [r["project_id"] for r in cur.fetchall()] - return helper.dict_to_camel_case(row) - -def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema): - - if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") - - if not data.all_projects and (data.projects is None or len(data.projects) == 0): - return {"errors": ["must specify a project or all projects"]} - if data.projects is not None and len(data.projects) > 0 and not data.all_projects: - data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects, data) - VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s, %(data)s) - RETURNING *;""", - {"tenant_id": tenant_id, "name": data.name, "description": data.description, - "permissions": data.permissions, "all_projects": data.all_projects, "data": json.dumps({ "group_id": group_id })}) - cur.execute(query=query) - row = cur.fetchone() - row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) - row["projects"] = [] - if not data.all_projects: - role_id = row["role_id"] - query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) - VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))} - RETURNING project_id;""", - {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}}) + { + "role_id": role_id, + **{f"project_id_{i}": p for i, p in enumerate(data.projects)}, + }, + ) cur.execute(query=query) row["projects"] = [r["project_id"] for r in cur.fetchall()] return helper.dict_to_camel_case(row) @@ -160,7 +148,8 @@ def 
create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema): def get_roles(tenant_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects + query = cur.mogrify( + """SELECT roles.*, COALESCE(projects, '{}') AS projects FROM public.roles LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects FROM roles_projects @@ -171,66 +160,25 @@ def get_roles(tenant_id): AND deleted_at IS NULL AND not service_role ORDER BY role_id;""", - {"tenant_id": tenant_id}) + {"tenant_id": tenant_id}, + ) cur.execute(query=query) rows = cur.fetchall() for r in rows: r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"]) return helper.list_to_camel_case(rows) -def get_roles_with_uuid(tenant_id): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects - FROM public.roles - LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects - FROM roles_projects - INNER JOIN projects USING (project_id) - WHERE roles_projects.role_id = roles.role_id - AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE) - WHERE tenant_id =%(tenant_id)s - AND data ? 'group_id' - AND deleted_at IS NULL - AND not service_role - ORDER BY role_id;""", - {"tenant_id": tenant_id}) - cur.execute(query=query) - rows = cur.fetchall() - for r in rows: - r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"]) - return helper.list_to_camel_case(rows) - -def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects - FROM public.roles - LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects - FROM roles_projects - INNER JOIN projects USING (project_id) - WHERE roles_projects.role_id = roles.role_id - AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE) - WHERE tenant_id =%(tenant_id)s - AND data ? 
'group_id' - AND deleted_at IS NULL - AND not service_role - AND name = COALESCE(%(name)s, name) - ORDER BY role_id - LIMIT %(count)s - OFFSET %(startIndex)s;""", - {"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count}) - cur.execute(query=query) - rows = cur.fetchall() - return helper.list_to_camel_case(rows) - def get_role_by_name(tenant_id, name): - ### "name" isn't unique in database with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT * + query = cur.mogrify( + """SELECT * FROM public.roles WHERE tenant_id =%(tenant_id)s AND deleted_at IS NULL AND name ILIKE %(name)s;""", - {"tenant_id": tenant_id, "name": name}) + {"tenant_id": tenant_id, "name": name}, + ) cur.execute(query=query) row = cur.fetchone() if row is not None: @@ -244,139 +192,55 @@ def delete(tenant_id, user_id, role_id): if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.roles WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = TRUE LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is protected"]} - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.users WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is already attached to other user(s)"]} - query = cur.mogrify("""UPDATE public.roles + query = cur.mogrify( + """UPDATE public.roles SET deleted_at = timezone('utc'::text, now()) WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = FALSE;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) return get_roles(tenant_id=tenant_id) -def delete_scim_group(tenant_id, group_uuid): - - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 - FROM public.roles - WHERE data->>'group_id' = %(group_uuid)s - AND tenant_id = %(tenant_id)s - AND protected = TRUE - LIMIT 1;""", - {"tenant_id": tenant_id, "group_uuid": group_uuid}) - cur.execute(query) - if cur.fetchone() is not None: - return {"errors": ["this role is protected"]} - - query = cur.mogrify( - f"""DELETE FROM public.roles - WHERE roles.data->>'group_id' = %(group_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL - {"group_uuid": group_uuid}) - cur.execute(query) - - return get_roles(tenant_id=tenant_id) - - def get_role(tenant_id, role_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.* + query = cur.mogrify( + """SELECT roles.* FROM public.roles WHERE tenant_id =%(tenant_id)s AND deleted_at IS NULL AND not service_role AND role_id = %(role_id)s LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) row = cur.fetchone() if row is not None: row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) return helper.dict_to_camel_case(row) - -def get_role_by_group_id(tenant_id, group_id): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.* - FROM public.roles - WHERE tenant_id =%(tenant_id)s - 
AND deleted_at IS NULL - AND not service_role - AND data->>'group_id' = %(group_id)s - LIMIT 1;""", - {"tenant_id": tenant_id, "group_id": group_id}) - cur.execute(query=query) - row = cur.fetchone() - if row is not None: - row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) - return helper.dict_to_camel_case(row) - -def get_users_by_group_uuid(tenant_id, group_id): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT - u.user_id, - u.name, - u.data - FROM public.roles r - LEFT JOIN public.users u USING (role_id, tenant_id) - WHERE u.tenant_id = %(tenant_id)s - AND u.deleted_at IS NULL - AND r.data->>'group_id' = %(group_id)s - """, - {"tenant_id": tenant_id, "group_id": group_id}) - cur.execute(query=query) - rows = cur.fetchall() - return helper.list_to_camel_case(rows) - -def get_member_permissions(tenant_id): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT - r.permissions - FROM public.roles r - WHERE r.tenant_id = %(tenant_id)s - AND r.name = 'Member' - AND r.deleted_at IS NULL - """, - {"tenant_id": tenant_id}) - cur.execute(query=query) - row = cur.fetchone() - return helper.dict_to_camel_case(row) - -def remove_group_membership(tenant_id, group_id, user_id): - with pg_client.PostgresClient() as cur: - query = cur.mogrify("""WITH r AS ( - SELECT role_id - FROM public.roles - WHERE data->>'group_id' = %(group_id)s - LIMIT 1 - ) - UPDATE public.users u - SET role_id= NULL - FROM r - WHERE u.data->>'user_id' = %(user_id)s - AND u.role_id = r.role_id - AND u.tenant_id = %(tenant_id)s - AND u.deleted_at IS NULL - RETURNING *;""", - {"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id}) - cur.execute(query=query) - row = cur.fetchone() - - return helper.dict_to_camel_case(row) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 94d1e8d41..d80907bad 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -25,12 +25,15 @@ def __generate_invitation_token(): return secrets.token_urlsafe(64) -def create_new_member(tenant_id, email, invitation_token, admin, name, owner=False, role_id=None): +def create_new_member( + tenant_id, email, invitation_token, admin, name, owner=False, role_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, role_id) - VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, + VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) @@ -53,10 +56,16 @@ def create_new_member(tenant_id, email, invitation_token, admin, name, owner=Fal roles.name AS role_name, TRUE AS has_password FROM au,u LEFT JOIN roles USING(tenant_id) WHERE roles.role_id IS NULL OR roles.role_id = (SELECT u.role_id FROM u);""", - {"tenant_id": tenant_id, "email": email, - "role": "owner" if owner else "admin" if admin else "member", "name": name, - "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), - "invitation_token": invitation_token, "role_id": role_id}) + { + "tenant_id": tenant_id, + "email": email, + "role": "owner" if owner else "admin" if admin else "member", + "name": name, + "data": json.dumps({"lastAnnouncementView": 
TimeUTC.now()}), + "invitation_token": invitation_token, + "role_id": role_id, + }, + ) cur.execute(query) row = helper.dict_to_camel_case(cur.fetchone()) if row: @@ -64,9 +73,12 @@ def create_new_member(tenant_id, email, invitation_token, admin, name, owner=Fal return row -def restore_member(tenant_id, user_id, email, invitation_token, admin, name, owner=False, role_id=None): +def restore_member( + tenant_id, user_id, email, invitation_token, admin, name, owner=False, role_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS (UPDATE public.users SET name= %(name)s, role = %(role)s, @@ -78,7 +90,7 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))) WHERE user_id=%(user_id)s - RETURNING + RETURNING tenant_id, user_id, email, @@ -104,11 +116,18 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own u.role_id, roles.name AS role_name, TRUE AS has_password - FROM au,u LEFT JOIN roles USING(tenant_id) + FROM au,u LEFT JOIN roles USING(tenant_id) WHERE roles.role_id IS NULL OR roles.role_id = (SELECT u.role_id FROM u);""", - {"tenant_id": tenant_id, "user_id": user_id, "email": email, - "role": "owner" if owner else "admin" if admin else "member", "name": name, - "role_id": role_id, "invitation_token": invitation_token}) + { + "tenant_id": tenant_id, + "user_id": user_id, + "email": email, + "role": "owner" if owner else "admin" if admin else "member", + "name": name, + "role_id": role_id, + "invitation_token": invitation_token, + }, + ) cur.execute(query) result = cur.fetchone() result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"]) @@ -118,7 +137,8 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own def generate_new_invitation(user_id): invitation_token = __generate_invitation_token() with pg_client.PostgresClient() as cur: - query = cur.mogrify("""\ + query = cur.mogrify( + """\ UPDATE public.basic_authentication SET invitation_token = %(invitation_token)s, invited_at = timezone('utc'::text, now()), @@ -126,10 +146,9 @@ def generate_new_invitation(user_id): change_pwd_token = NULL WHERE user_id=%(user_id)s RETURNING invitation_token;""", - {"user_id": user_id, "invitation_token": invitation_token}) - cur.execute( - query + {"user_id": user_id, "invitation_token": invitation_token}, ) + cur.execute(query) return __get_invitation_link(cur.fetchone().pop("invitation_token")) @@ -144,7 +163,13 @@ def reset_member(tenant_id, editor_id, user_id_to_update): def update(tenant_id, user_id, changes, output=True): - AUTH_KEYS = ["password", "invitationToken", "invitedAt", "changePwdExpireAt", "changePwdToken"] + AUTH_KEYS = [ + "password", + "invitationToken", + "invitedAt", + "changePwdExpireAt", + "changePwdToken", + ] if len(changes.keys()) == 0: return None @@ -153,7 +178,9 @@ def update(tenant_id, user_id, changes, output=True): for key in changes.keys(): if key in AUTH_KEYS: if key == "password": - sub_query_bauth.append("password = crypt(%(password)s, gen_salt('bf', 12))") + sub_query_bauth.append( + "password = crypt(%(password)s, gen_salt('bf', 12))" + ) sub_query_bauth.append("changed_at = timezone('utc'::text, now())") else: sub_query_bauth.append(f"{helper.key_to_snake_case(key)} = %({key})s") @@ -162,7 +189,9 @@ def update(tenant_id, 
user_id, changes, output=True): sub_query_users.append("""role_id=(SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))""") - elif key == "data": # this is hardcoded, maybe a generic solution would be better + elif ( + key == "data" + ): # this is hardcoded, maybe a generic solution would be better sub_query_users.append(f"data = data || %({(key)})s") else: sub_query_users.append(f"{helper.key_to_snake_case(key)} = %({key})s") @@ -171,26 +200,35 @@ def update(tenant_id, user_id, changes, output=True): changes["data"] = Json(changes["data"]) with pg_client.PostgresClient() as cur: if len(sub_query_users) > 0: - query = cur.mogrify(f"""\ + query = cur.mogrify( + f"""\ UPDATE public.users SET {" ,".join(sub_query_users)} WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s;""", - {"tenant_id": tenant_id, "user_id": user_id, **changes}) + {"tenant_id": tenant_id, "user_id": user_id, **changes}, + ) cur.execute(query) if len(sub_query_bauth) > 0: - query = cur.mogrify(f"""\ + query = cur.mogrify( + f"""\ UPDATE public.basic_authentication SET {" ,".join(sub_query_bauth)} WHERE basic_authentication.user_id = %(user_id)s;""", - {"tenant_id": tenant_id, "user_id": user_id, **changes}) + {"tenant_id": tenant_id, "user_id": user_id, **changes}, + ) cur.execute(query) if not output: return None return get(user_id=user_id, tenant_id=tenant_id) -def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, background_tasks: BackgroundTasks): +def create_member( + tenant_id, + user_id, + data: schemas.CreateMemberSchema, + background_tasks: BackgroundTasks, +): admin = get(tenant_id=tenant_id, user_id=user_id) if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} @@ -204,7 +242,9 @@ def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, backgrou data.name = data.email role_id = data.roleId if role_id is None: - role_id = roles.get_role_by_name(tenant_id=tenant_id, name="member").get("roleId") + role_id = roles.get_role_by_name(tenant_id=tenant_id, name="member").get( + "roleId" + ) else: role = roles.get_role(tenant_id=tenant_id, role_id=role_id) if role is None: @@ -214,22 +254,46 @@ def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, backgrou invitation_token = __generate_invitation_token() user = get_deleted_user_by_email(email=data.email) if user is not None and user["tenantId"] == tenant_id: - new_member = restore_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - admin=data.admin, name=data.name, user_id=user["userId"], role_id=role_id) + new_member = restore_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + user_id=user["userId"], + role_id=role_id, + ) elif user is not None: __hard_delete_user(user_id=user["userId"]) - new_member = create_new_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - admin=data.admin, name=data.name, role_id=role_id) + new_member = create_new_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + role_id=role_id, + ) else: - new_member = create_new_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - 
admin=data.admin, name=data.name, role_id=role_id) - new_member["invitationLink"] = __get_invitation_link(new_member.pop("invitationToken")) - background_tasks.add_task(email_helper.send_team_invitation, **{ - "recipient": data.email, - "invitation_link": new_member["invitationLink"], - "client_id": tenants.get_by_tenant_id(tenant_id)["name"], - "sender_name": admin["name"] - }) + new_member = create_new_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + role_id=role_id, + ) + new_member["invitationLink"] = __get_invitation_link( + new_member.pop("invitationToken") + ) + background_tasks.add_task( + email_helper.send_team_invitation, + **{ + "recipient": data.email, + "invitation_link": new_member["invitationLink"], + "client_id": tenants.get_by_tenant_id(tenant_id)["name"], + "sender_name": admin["name"], + }, + ) return {"data": new_member} @@ -240,14 +304,14 @@ def __get_invitation_link(invitation_token): def allow_password_change(user_id, delta_min=10): pass_token = secrets.token_urlsafe(8) with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.basic_authentication + query = cur.mogrify( + """UPDATE public.basic_authentication SET change_pwd_expire_at = timezone('utc'::text, now()+INTERVAL '%(delta)s MINUTES'), change_pwd_token = %(pass_token)s WHERE user_id = %(user_id)s""", - {"user_id": user_id, "delta": delta_min, "pass_token": pass_token}) - cur.execute( - query + {"user_id": user_id, "delta": delta_min, "pass_token": pass_token}, ) + cur.execute(query) return pass_token @@ -255,11 +319,11 @@ def get(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, - email, - role, + email, + role, users.name, (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -279,93 +343,24 @@ def get(user_id, tenant_id): AND users.deleted_at IS NULL AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id}) + {"userId": user_id, "tenant_id": tenant_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) - -def get_by_uuid(user_uuid, tenant_id): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - f"""SELECT - users.user_id, - users.tenant_id, - email, - role, - users.name, - users.data, - users.internal_id, - (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - role_id, - roles.name AS role_name, - roles.permissions, - roles.all_projects, - basic_authentication.password IS NOT NULL AS has_password, - users.service_account - FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id - LEFT JOIN public.roles USING (role_id) - WHERE - users.data->>'user_id' = %(user_uuid)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) - LIMIT 1;""", - {"user_uuid": user_uuid, "tenant_id": tenant_id}) - ) - r = cur.fetchone() - return helper.dict_to_camel_case(r) - -def get_deleted_by_uuid(user_uuid, tenant_id): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - f"""SELECT - users.user_id, - 
users.tenant_id, - email, - role, - users.name, - users.data, - users.internal_id, - (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - role_id, - roles.name AS role_name, - roles.permissions, - roles.all_projects, - basic_authentication.password IS NOT NULL AS has_password, - users.service_account - FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id - LEFT JOIN public.roles USING (role_id) - WHERE - users.data->>'user_id' = %(user_uuid)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NOT NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) - LIMIT 1;""", - {"user_uuid": user_uuid, "tenant_id": tenant_id}) - ) - r = cur.fetchone() - return helper.dict_to_camel_case(r) - def generate_new_api_key(user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""UPDATE public.users + """UPDATE public.users SET api_key=generate_api_key(20) WHERE users.user_id = %(userId)s AND deleted_at IS NULL RETURNING api_key;""", - {"userId": user_id}) + {"userId": user_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -375,22 +370,27 @@ def __get_account_info(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT users.name, - tenants.name AS tenant_name, + """SELECT users.name, + tenants.name AS tenant_name, tenants.opt_out FROM public.users INNER JOIN public.tenants USING (tenant_id) WHERE users.user_id = %(userId)s AND tenants.tenant_id= %(tenantId)s AND tenants.deleted_at IS NULL AND users.deleted_at IS NULL;""", - {"tenantId": tenant_id, "userId": user_id}) + {"tenantId": tenant_id, "userId": user_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) def edit_account(user_id, tenant_id, changes: schemas.EditAccountSchema): - if changes.opt_out is not None or changes.tenantName is not None and len(changes.tenantName) > 0: + if ( + changes.opt_out is not None + or changes.tenantName is not None + and len(changes.tenantName) > 0 + ): user = get(user_id=user_id, tenant_id=tenant_id) if not user["superAdmin"] and not user["admin"]: return {"errors": ["unauthorized"]} @@ -410,7 +410,9 @@ def edit_account(user_id, tenant_id, changes: schemas.EditAccountSchema): return {"data": __get_account_info(tenant_id=tenant_id, user_id=user_id)} -def edit_member(user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, editor_id): +def edit_member( + user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, editor_id +): user = get_member(user_id=user_id_to_update, tenant_id=tenant_id) _changes = {} if editor_id != user_id_to_update: @@ -448,7 +450,12 @@ def edit_member(user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, return {"errors": ["invalid role"]} if len(_changes.keys()) > 0: - update(tenant_id=tenant_id, user_id=user_id_to_update, changes=_changes, output=False) + update( + tenant_id=tenant_id, + user_id=user_id_to_update, + changes=_changes, + output=False, + ) return {"data": get_member(user_id=user_id_to_update, tenant_id=tenant_id)} return {"data": user} @@ -457,11 +464,11 @@ def get_by_email_only(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, - users.email, - users.role, + users.email, + users.role, users.name, (CASE WHEN users.role = 
'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -473,58 +480,25 @@ def get_by_email_only(email): roles.name AS role_name FROM public.users LEFT JOIN public.basic_authentication USING(user_id) INNER JOIN public.roles USING(role_id) - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND users.deleted_at IS NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) -def get_users_paginated(start_index, count=None, email=None): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - f"""SELECT - users.user_id AS id, - users.tenant_id, - users.email AS email, - users.data AS data, - users.role, - users.name AS name, - (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - basic_authentication.password IS NOT NULL AS has_password, - role_id, - internal_id, - roles.name AS role_name - FROM public.users LEFT JOIN public.basic_authentication USING(user_id) - INNER JOIN public.roles USING(role_id) - WHERE users.deleted_at IS NULL - AND users.data ? 'user_id' - AND email = COALESCE(%(email)s, email) - LIMIT %(count)s - OFFSET %(startIndex)s;;""", - {"startIndex": start_index - 1, "count": count, "email": email}) - ) - r = cur.fetchall() - if len(r): - r = helper.list_to_camel_case(r) - return r - return [] - def get_member(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -535,12 +509,13 @@ def get_member(tenant_id, user_id): invitation_token, role_id, roles.name AS role_name - FROM public.users + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id LEFT JOIN public.roles USING (role_id) WHERE users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL AND users.user_id = %(user_id)s ORDER BY name, user_id""", - {"tenant_id": tenant_id, "user_id": user_id}) + {"tenant_id": tenant_id, "user_id": user_id}, + ) ) u = helper.dict_to_camel_case(cur.fetchone()) if u: @@ -557,11 +532,11 @@ def get_members(tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, @@ -572,14 +547,15 @@ def get_members(tenant_id): invitation_token, role_id, roles.name AS role_name - FROM public.users + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id LEFT JOIN public.roles USING (role_id) - WHERE users.tenant_id = %(tenant_id)s + WHERE users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL AND NOT users.service_account ORDER BY name, user_id""", - {"tenant_id": tenant_id}) + {"tenant_id": tenant_id}, + ) ) r = cur.fetchall() if len(r): @@ -587,7 +563,9 @@ def get_members(tenant_id): for u in r: u["createdAt"] = TimeUTC.datetime_to_timestamp(u["createdAt"]) if 
u["invitationToken"]: - u["invitationLink"] = __get_invitation_link(u.pop("invitationToken")) + u["invitationLink"] = __get_invitation_link( + u.pop("invitationToken") + ) else: u["invitationLink"] = None return r @@ -612,96 +590,48 @@ def delete_member(user_id, tenant_id, id_to_delete): with pg_client.PostgresClient() as cur: cur.execute( - cur.mogrify(f"""UPDATE public.users + cur.mogrify( + """UPDATE public.users SET deleted_at = timezone('utc'::text, now()), - jwt_iat= NULL, jwt_refresh_jti= NULL, + jwt_iat= NULL, jwt_refresh_jti= NULL, jwt_refresh_iat= NULL, role_id=NULL WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) - cur.execute( - cur.mogrify(f"""UPDATE public.basic_authentication - SET password= NULL, invitation_token= NULL, - invited_at= NULL, changed_at= NULL, - change_pwd_expire_at= NULL, change_pwd_token= NULL - WHERE user_id=%(user_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) - return {"data": get_members(tenant_id=tenant_id)} - - -def delete_member_as_admin(tenant_id, id_to_delete): - - with pg_client.PostgresClient() as cur: + {"user_id": id_to_delete, "tenant_id": tenant_id}, + ) + ) cur.execute( cur.mogrify( - f"""SELECT - users.user_id AS user_id, - users.tenant_id, - email, - role, - users.name, - origin, - role_id, - roles.name AS role_name, - (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, - roles.permissions, - roles.all_projects, - basic_authentication.password IS NOT NULL AS has_password, - users.service_account - FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id - LEFT JOIN public.roles USING (role_id) - WHERE - role = 'owner' - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) - LIMIT 1;""", - {"tenant_id": tenant_id, "user_uuid": id_to_delete}) - ) - r = cur.fetchone() - - if r["user_id"] == id_to_delete: - return {"errors": ["unauthorized, cannot delete self"]} - - if r["member"]: - return {"errors": ["unauthorized"]} - - to_delete = get(user_id=id_to_delete, tenant_id=tenant_id) - if to_delete is None: - return {"errors": ["not found"]} - - if to_delete["superAdmin"]: - return {"errors": ["cannot delete super admin"]} - - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify(f"""UPDATE public.users - SET deleted_at = timezone('utc'::text, now()), - jwt_iat= NULL, jwt_refresh_jti= NULL, - jwt_refresh_iat= NULL, - role_id=NULL - WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) - cur.execute( - cur.mogrify(f"""UPDATE public.basic_authentication + """UPDATE public.basic_authentication SET password= NULL, invitation_token= NULL, invited_at= NULL, changed_at= NULL, change_pwd_expire_at= NULL, change_pwd_token= NULL WHERE user_id=%(user_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) + {"user_id": id_to_delete, "tenant_id": tenant_id}, + ) + ) return {"data": get_members(tenant_id=tenant_id)} - def change_password(tenant_id, user_id, email, old_password, new_password): item = get(tenant_id=tenant_id, user_id=user_id) if item is None: return {"errors": ["access denied"]} - if item["origin"] is not None and config("enforce_SSO", cast=bool, default=False) \ - and not item["superAdmin"] and helper.is_saml2_available(): - return {"errors": ["Please use your SSO to change your password, enforced by admin"]} + if ( + 
item["origin"] is not None + and config("enforce_SSO", cast=bool, default=False) + and not item["superAdmin"] + and helper.is_saml2_available() + ): + return { + "errors": ["Please use your SSO to change your password, enforced by admin"] + } if item["origin"] is not None and item["hasPassword"] is False: - return {"errors": ["cannot change your password because you are logged-in from an SSO service"]} + return { + "errors": [ + "cannot change your password because you are logged-in from an SSO service" + ] + } if old_password == new_password: return {"errors": ["old and new password are the same"]} auth = authenticate(email, old_password, for_change_password=True) @@ -709,23 +639,7 @@ def change_password(tenant_id, user_id, email, old_password, new_password): return {"errors": ["wrong password"]} changes = {"password": new_password} user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) - r = authenticate(user['email'], new_password) - - return { - "jwt": r.pop("jwt"), - "refreshToken": r.pop("refreshToken"), - "refreshTokenMaxAge": r.pop("refreshTokenMaxAge"), - "spotJwt": r.pop("spotJwt"), - "spotRefreshToken": r.pop("spotRefreshToken"), - "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), - "tenantId": tenant_id - } - - -def set_password_invitation(tenant_id, user_id, new_password): - changes = {"password": new_password} - user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) - r = authenticate(user['email'], new_password) + r = authenticate(user["email"], new_password) return { "jwt": r.pop("jwt"), @@ -735,7 +649,23 @@ def set_password_invitation(tenant_id, user_id, new_password): "spotRefreshToken": r.pop("spotRefreshToken"), "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), "tenantId": tenant_id, - **r + } + + +def set_password_invitation(tenant_id, user_id, new_password): + changes = {"password": new_password} + user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) + r = authenticate(user["email"], new_password) + + return { + "jwt": r.pop("jwt"), + "refreshToken": r.pop("refreshToken"), + "refreshTokenMaxAge": r.pop("refreshTokenMaxAge"), + "spotJwt": r.pop("spotJwt"), + "spotRefreshToken": r.pop("spotRefreshToken"), + "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), + "tenantId": tenant_id, + **r, } @@ -743,14 +673,15 @@ def email_exists(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - count(user_id) + """SELECT + count(user_id) FROM public.users WHERE email = %(email)s AND deleted_at IS NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return r["count"] > 0 @@ -760,14 +691,15 @@ def get_deleted_user_by_email(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT - * + """SELECT + * FROM public.users WHERE email = %(email)s AND deleted_at NOTNULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -777,7 +709,7 @@ def get_by_invitation_token(token, pass_token=None): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT *, DATE_PART('day',timezone('utc'::text, now()) \ - COALESCE(basic_authentication.invited_at,'2000-01-01'::timestamp ))>=1 AS expired_invitation, @@ -786,7 +718,8 @@ def get_by_invitation_token(token, pass_token=None): FROM public.users INNER JOIN public.basic_authentication USING(user_id) WHERE invitation_token = %(token)s {"AND change_pwd_token = %(pass_token)s" if pass_token else ""} LIMIT 
1;""", - {"token": token, "pass_token": pass_token}) + {"token": token, "pass_token": pass_token}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -796,37 +729,42 @@ def auth_exists(user_id, tenant_id, jwt_iat) -> bool: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT user_id, - EXTRACT(epoch FROM jwt_iat)::BIGINT AS jwt_iat, + """SELECT user_id, + EXTRACT(epoch FROM jwt_iat)::BIGINT AS jwt_iat, changed_at, service_account, basic_authentication.user_id IS NOT NULL AS has_basic_auth - FROM public.users - LEFT JOIN public.basic_authentication USING(user_id) - WHERE user_id = %(userId)s - AND tenant_id = %(tenant_id)s - AND deleted_at IS NULL + FROM public.users + LEFT JOIN public.basic_authentication USING(user_id) + WHERE user_id = %(userId)s + AND tenant_id = %(tenant_id)s + AND deleted_at IS NULL LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id}) + {"userId": user_id, "tenant_id": tenant_id}, + ) ) r = cur.fetchone() - return r is not None \ - and (r["service_account"] and not r["has_basic_auth"] - or r.get("jwt_iat") is not None \ - and (abs(jwt_iat - r["jwt_iat"]) <= 1)) + return r is not None and ( + r["service_account"] + and not r["has_basic_auth"] + or r.get("jwt_iat") is not None + and (abs(jwt_iat - r["jwt_iat"]) <= 1) + ) def refresh_auth_exists(user_id, tenant_id, jwt_jti=None): with pg_client.PostgresClient() as cur: cur.execute( - cur.mogrify(f"""SELECT user_id - FROM public.users - WHERE user_id = %(userId)s + cur.mogrify( + """SELECT user_id + FROM public.users + WHERE user_id = %(userId)s AND tenant_id= %(tenant_id)s AND deleted_at IS NULL AND jwt_refresh_jti = %(jwt_jti)s LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id, "jwt_jti": jwt_jti}) + {"userId": user_id, "tenant_id": tenant_id, "jwt_jti": jwt_jti}, + ) ) r = cur.fetchone() return r is not None @@ -864,21 +802,23 @@ class RefreshSpotJWTs(FullLoginJWTs): def change_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.users + query = cur.mogrify( + """UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - jwt_refresh_jti = 0, + jwt_refresh_jti = 0, jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s'), spot_jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - spot_jwt_refresh_jti = 0, - spot_jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s') - WHERE user_id = %(user_id)s - RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, - jwt_refresh_jti, + spot_jwt_refresh_jti = 0, + spot_jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s') + WHERE user_id = %(user_id)s + RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, + jwt_refresh_jti, EXTRACT (epoch FROM jwt_refresh_iat)::BIGINT AS jwt_refresh_iat, - EXTRACT (epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat, - spot_jwt_refresh_jti, + EXTRACT (epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat, + spot_jwt_refresh_jti, EXTRACT (epoch FROM spot_jwt_refresh_iat)::BIGINT AS spot_jwt_refresh_iat;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) row = cur.fetchone() return FullLoginJWTs(**row) @@ -886,14 +826,16 @@ def change_jwt_iat_jti(user_id): def refresh_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.users + query = cur.mogrify( + """UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), - jwt_refresh_jti = jwt_refresh_jti + 1 - WHERE user_id = %(user_id)s - RETURNING EXTRACT (epoch FROM 
jwt_iat)::BIGINT AS jwt_iat, - jwt_refresh_jti, + jwt_refresh_jti = jwt_refresh_jti + 1 + WHERE user_id = %(user_id)s + RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, + jwt_refresh_jti, EXTRACT (epoch FROM jwt_refresh_iat)::BIGINT AS jwt_refresh_iat;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) row = cur.fetchone() return RefreshLoginJWTs(**row) @@ -904,7 +846,7 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No return {"errors": ["must sign-in with SSO, enforced by admin"]} with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, users.role, @@ -919,24 +861,26 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No users.service_account FROM public.users AS users INNER JOIN public.basic_authentication USING(user_id) LEFT JOIN public.roles ON (roles.role_id = users.role_id AND roles.tenant_id = users.tenant_id) - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND basic_authentication.password = crypt(%(password)s, basic_authentication.password) AND basic_authentication.user_id = (SELECT su.user_id FROM public.users AS su WHERE su.email=%(email)s AND su.deleted_at IS NULL LIMIT 1) AND (roles.role_id IS NULL OR roles.deleted_at IS NULL) LIMIT 1;""", - {"email": email, "password": password}) + {"email": email, "password": password}, + ) cur.execute(query) r = cur.fetchone() if r is None and helper.is_saml2_available(): query = cur.mogrify( - f"""SELECT 1 + """SELECT 1 FROM public.users - WHERE users.email = %(email)s + WHERE users.email = %(email)s AND users.deleted_at IS NULL AND users.origin IS NOT NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) cur.execute(query) if cur.fetchone() is not None: return {"errors": ["must sign-in with SSO"]} @@ -946,33 +890,51 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No return True r = helper.dict_to_camel_case(r) if r["serviceAccount"]: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="service account is not authorized to login") - elif config("enforce_SSO", cast=bool, default=False) and helper.is_saml2_available(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="service account is not authorized to login", + ) + elif ( + config("enforce_SSO", cast=bool, default=False) + and helper.is_saml2_available() + ): return {"errors": ["must sign-in with SSO, enforced by admin"]} - j_r = change_jwt_iat_jti(user_id=r['userId']) + j_r = change_jwt_iat_jti(user_id=r["userId"]) response = { - "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], iat=j_r.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.jwt_refresh_iat, - aud=AUDIENCE, - jwt_jti=j_r.jwt_refresh_jti, - for_spot=False), + "jwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_iat, + aud=AUDIENCE, + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j_r.jwt_refresh_jti, + for_spot=False, + ), "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), "email": email, - "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE, for_spot=True), - "spotRefreshToken": 
authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.spot_jwt_refresh_iat, - aud=spot.AUDIENCE, - jwt_jti=j_r.spot_jwt_refresh_jti, - for_spot=True), + "spotJwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_iat, + aud=spot.AUDIENCE, + for_spot=True, + ), + "spotRefreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti, + for_spot=True, + ), "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int), - **r + **r, } return response @@ -983,73 +945,30 @@ def get_user_role(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, - users.email, - users.role, - users.name, + users.email, + users.role, + users.name, users.created_at, (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member - FROM public.users - WHERE users.deleted_at IS NULL + FROM public.users + WHERE users.deleted_at IS NULL AND users.user_id=%(user_id)s AND users.tenant_id=%(tenant_id)s LIMIT 1""", - {"tenant_id": tenant_id, "user_id": user_id}) + {"tenant_id": tenant_id, "user_id": user_id}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id=None): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ - WITH u AS ( - INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) - VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, - (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) - RETURNING * - ), - au AS ( - INSERT INTO public.basic_authentication(user_id) - VALUES ((SELECT user_id FROM u)) - ) - SELECT u.user_id AS id, - u.email, - u.role, - u.name, - (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin - FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now()})}) - cur.execute( - query - ) - return helper.dict_to_camel_case(cur.fetchone()) - -def create_scim_user( - tenant_id, - user_uuid, - email, - admin, - display_name, - full_name: dict, - emails, - origin, - locale, - role_id, - internal_id=None, -): - - with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, @@ -1066,36 +985,33 @@ def create_scim_user( u.email, u.role, u.name, - u.data, (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, (CASE WHEN u.role = 'member' 
THEN TRUE ELSE FALSE END) AS member, origin FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": display_name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})}) - cur.execute( - query + { + "tenant_id": tenant_id, + "email": email, + "internal_id": internal_id, + "role": "admin" if admin else "member", + "name": name, + "origin": origin, + "role_id": role_id, + "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), + }, ) + cur.execute(query) return helper.dict_to_camel_case(cur.fetchone()) - def __hard_delete_user(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""DELETE FROM public.users + """DELETE FROM public.users WHERE users.user_id = %(user_id)s AND users.deleted_at IS NOT NULL ;""", - {"user_id": user_id}) - cur.execute(query) - -def __hard_delete_user_uuid(user_uuid): - with pg_client.PostgresClient() as cur: - query = cur.mogrify( - f"""DELETE FROM public.users - WHERE users.data->>'user_id' = %(user_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL - {"user_uuid": user_uuid}) + {"user_id": user_id}, + ) cur.execute(query) @@ -1106,25 +1022,33 @@ def logout(user_id: int): SET jwt_iat = NULL, jwt_refresh_jti = NULL, jwt_refresh_iat = NULL, spot_jwt_iat = NULL, spot_jwt_refresh_jti = NULL, spot_jwt_refresh_iat = NULL WHERE user_id = %(user_id)s;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) def refresh(user_id: int, tenant_id: int = -1) -> dict: j = refresh_jwt_iat_jti(user_id=user_id) return { - "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat, - aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti), - "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat), + "jwt": authorizers.generate_jwt( + user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, aud=AUDIENCE + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=user_id, + tenant_id=tenant_id, + iat=j.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j.jwt_refresh_jti, + ), + "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) + - (j.jwt_iat - j.jwt_refresh_iat), } def authenticate_sso(email: str, internal_id: str): with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, users.role, @@ -1137,7 +1061,8 @@ def authenticate_sso(email: str, internal_id: str): service_account FROM public.users AS users WHERE users.email = %(email)s AND internal_id = %(internal_id)s;""", - {"email": email, "internal_id": internal_id}) + {"email": email, "internal_id": internal_id}, + ) cur.execute(query) r = cur.fetchone() @@ -1145,41 +1070,64 @@ def authenticate_sso(email: str, internal_id: str): if r is not None: r = helper.dict_to_camel_case(r) if r["serviceAccount"]: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="service account is not authorized to login") - j_r = change_jwt_iat_jti(user_id=r['userId']) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="service account is not authorized to login", + ) + j_r = change_jwt_iat_jti(user_id=r["userId"]) response = { - "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], 
iat=j_r.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.jwt_refresh_iat, - aud=AUDIENCE, jwt_jti=j_r.jwt_refresh_jti), + "jwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_iat, + aud=AUDIENCE, + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j_r.jwt_refresh_jti, + ), "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), - "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE, for_spot=True), - "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.spot_jwt_refresh_iat, - aud=spot.AUDIENCE, - jwt_jti=j_r.spot_jwt_refresh_jti, for_spot=True), - "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) + "spotJwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_iat, + aud=spot.AUDIENCE, + for_spot=True, + ), + "spotRefreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti, + for_spot=True, + ), + "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int), } return response - logger.warning(f"SSO user not found with email: {email} and internal_id: {internal_id}") + logger.warning( + f"SSO user not found with email: {email} and internal_id: {internal_id}" + ) return None -def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, internal_id=None): +def restore_sso_user( + user_id, tenant_id, email, admin, name, origin, role_id, internal_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( - UPDATE public.users + UPDATE public.users SET tenant_id= %(tenant_id)s, - role= %(role)s, + role= %(role)s, name= %(name)s, - data= %(data)s, - origin= %(origin)s, - internal_id= %(internal_id)s, + data= %(data)s, + origin= %(origin)s, + internal_id= %(internal_id)s, role_id= (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))), @@ -1198,7 +1146,7 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in invited_at= default, change_pwd_token= default, change_pwd_expire_at= default, - changed_at= NULL + changed_at= NULL WHERE user_id = %(user_id)s RETURNING user_id ) @@ -1211,92 +1159,35 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, origin FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), - "user_id": user_id}) - cur.execute( - query + { + "tenant_id": tenant_id, + "email": email, + "internal_id": internal_id, + "role": "admin" if admin else "member", + "name": name, + "origin": origin, + "role_id": role_id, + "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), + "user_id": user_id, + }, ) + 
cur.execute(query) return helper.dict_to_camel_case(cur.fetchone()) -def restore_scim_user( - user_id, - tenant_id, - user_uuid, - email, - admin, - display_name, - full_name: dict, - emails, - origin, - locale, - role_id, - internal_id=None): - with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ - WITH u AS ( - UPDATE public.users - SET tenant_id= %(tenant_id)s, - role= %(role)s, - name= %(name)s, - data= %(data)s, - origin= %(origin)s, - internal_id= %(internal_id)s, - role_id= (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))), - deleted_at= NULL, - created_at= default, - api_key= default, - jwt_iat= NULL, - weekly_report= default - WHERE user_id = %(user_id)s - RETURNING * - ), - au AS ( - UPDATE public.basic_authentication - SET password= default, - invitation_token= default, - invited_at= default, - change_pwd_token= default, - change_pwd_expire_at= default, - changed_at= NULL - WHERE user_id = %(user_id)s - RETURNING user_id - ) - SELECT u.user_id AS id, - u.email, - u.role, - u.name, - u.data, - (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin - FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": display_name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails}), - "user_id": user_id}) - cur.execute( - query - ) - return helper.dict_to_camel_case(cur.fetchone()) - def get_user_settings(user_id): # read user settings from users.settings:jsonb column with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT settings - FROM public.users - WHERE users.deleted_at IS NULL + FROM public.users + WHERE users.deleted_at IS NULL AND users.user_id=%(user_id)s LIMIT 1""", - {"user_id": user_id}) + {"user_id": user_id}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) @@ -1328,11 +1219,12 @@ def update_user_settings(user_id, settings): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""UPDATE public.users + """UPDATE public.users SET settings = %(settings)s WHERE users.user_id = %(user_id)s AND deleted_at IS NULL RETURNING settings;""", - {"user_id": user_id, "settings": json.dumps(settings)}) + {"user_id": user_id, "settings": json.dumps(settings)}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) diff --git a/ee/api/chalicelib/utils/SAML2_helper.py b/ee/api/chalicelib/utils/SAML2_helper.py index cbfcccaab..5c484e5c3 100644 --- a/ee/api/chalicelib/utils/SAML2_helper.py +++ b/ee/api/chalicelib/utils/SAML2_helper.py @@ -23,20 +23,18 @@ SAML2 = { "entityId": config("SITE_URL") + API_PREFIX + "/sso/saml2/metadata/", "assertionConsumerService": { "url": config("SITE_URL") + API_PREFIX + "/sso/saml2/acs/", - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", }, "singleLogoutService": { "url": config("SITE_URL") + API_PREFIX + "/sso/saml2/sls/", - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" + "binding": 
"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", }, "NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", "x509cert": config("sp_crt", default=""), "privateKey": config("sp_key", default=""), }, - "security": { - "requestedAuthnContext": False - }, - "idp": None + "security": {"requestedAuthnContext": False}, + "idp": None, } # in case tenantKey is included in the URL @@ -50,25 +48,29 @@ if config("SAML2_MD_URL", default=None) is not None and len(config("SAML2_MD_URL print("SAML2_MD_URL provided, getting IdP metadata config") from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser - idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(config("SAML2_MD_URL", default=None)) + idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote( + config("SAML2_MD_URL", default=None) + ) idp = idp_data.get("idp") if SAML2["idp"] is None: - if len(config("idp_entityId", default="")) > 0 \ - and len(config("idp_sso_url", default="")) > 0 \ - and len(config("idp_x509cert", default="")) > 0: + if ( + len(config("idp_entityId", default="")) > 0 + and len(config("idp_sso_url", default="")) > 0 + and len(config("idp_x509cert", default="")) > 0 + ): idp = { "entityId": config("idp_entityId"), "singleSignOnService": { "url": config("idp_sso_url"), - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", }, - "x509cert": config("idp_x509cert") + "x509cert": config("idp_x509cert"), } if len(config("idp_sls_url", default="")) > 0: idp["singleLogoutService"] = { "url": config("idp_sls_url"), - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", } if idp is None: @@ -106,8 +108,8 @@ async def prepare_request(request: Request): session = {} # If server is behind proxys or balancers use the HTTP_X_FORWARDED fields headers = request.headers - proto = headers.get('x-forwarded-proto', 'http') - url_data = urlparse('%s://%s' % (proto, headers['host'])) + proto = headers.get("x-forwarded-proto", "http") + url_data = urlparse("%s://%s" % (proto, headers["host"])) path = request.url.path site_url = urlparse(config("SITE_URL")) # to support custom port without changing IDP config @@ -117,21 +119,21 @@ async def prepare_request(request: Request): # add / to /acs if not path.endswith("/"): - path = path + '/' + path = path + "/" if len(API_PREFIX) > 0 and not path.startswith(API_PREFIX): path = API_PREFIX + path return { - 'https': 'on' if proto == 'https' else 'off', - 'http_host': request.headers['host'] + host_suffix, - 'server_port': url_data.port, - 'script_name': path, - 'get_data': request.args.copy(), + "https": "on" if proto == "https" else "off", + "http_host": request.headers["host"] + host_suffix, + "server_port": url_data.port, + "script_name": path, + "get_data": request.args.copy(), # Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144 # 'lowercase_urlencoding': True, - 'post_data': request.form.copy(), - 'cookie': {"session": session}, - 'request': request + "post_data": request.form.copy(), + "cookie": {"session": session}, + "request": request, } @@ -140,8 +142,11 @@ def is_saml2_available(): def get_saml2_provider(): - return config("idp_name", default="saml2") if is_saml2_available() and len( - config("idp_name", default="saml2")) > 0 else None + return ( + config("idp_name", default="saml2") + if is_saml2_available() and len(config("idp_name", default="saml2")) > 0 + else None 
+ ) def get_landing_URL(query_params: dict = None, redirect_to_link2=False): @@ -152,11 +157,14 @@ def get_landing_URL(query_params: dict = None, redirect_to_link2=False): if redirect_to_link2: if len(config("sso_landing_override", default="")) == 0: - logging.warning("SSO trying to redirect to custom URL, but sso_landing_override env var is empty") + logging.warning( + "SSO trying to redirect to custom URL, but sso_landing_override env var is empty" + ) else: return config("sso_landing_override") + query_params - return config("SITE_URL") + config("sso_landing", default="/login") + query_params + base_url = config("SITE_URLx") if config("LOCAL_DEV") else config("SITE_URL") + return base_url + config("sso_landing", default="/login") + query_params environ["hastSAML2"] = str(is_saml2_available()) diff --git a/ee/api/chalicelib/utils/scim_auth.py b/ee/api/chalicelib/utils/scim_auth.py index 83e779c40..c31dcd058 100644 --- a/ee/api/chalicelib/utils/scim_auth.py +++ b/ee/api/chalicelib/utils/scim_auth.py @@ -13,23 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY") ALGORITHM = config("SCIM_JWT_ALGORITHM") ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS")) REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS")) -AUDIENCE="okta_client" -ISSUER=config("JWT_ISSUER"), +AUDIENCE = config("SCIM_AUDIENCE") +ISSUER = (config("JWT_ISSUER"),) -# Simulated Okta Client Credentials -# OKTA_CLIENT_ID = "okta-client" -# OKTA_CLIENT_SECRET = "okta-secret" - -# class TokenRequest(BaseModel): -# client_id: str -# client_secret: str - -# async def authenticate_client(token_request: TokenRequest): -# """Validate Okta Client Credentials and issue JWT""" -# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET: -# raise HTTPException(status_code=401, detail="Invalid client credentials") - -# return {"access_token": create_jwt(), "token_type": "bearer"} def create_tokens(tenant_id): curr_time = time.time() @@ -38,7 +24,7 @@ def create_tokens(tenant_id): "sub": "scim_server", "aud": AUDIENCE, "iss": ISSUER, - "exp": "" + "exp": "", } access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS}) access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM) @@ -47,20 +33,26 @@ def create_tokens(tenant_id): refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS}) refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM) - return access_token, refresh_token + return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS + def verify_access_token(token: str): try: - payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + payload = jwt.decode( + token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE + ) return payload except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token") + def verify_refresh_token(token: str): try: - payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + payload = jwt.decode( + token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE + ) return payload except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") @@ -68,10 +60,25 @@ def verify_refresh_token(token: str): raise HTTPException(status_code=401, detail="Invalid token") -oauth2_scheme = 
OAuth2PasswordBearer(tokenUrl="token") +required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + # Authentication Dependency -def auth_required(token: str = Depends(oauth2_scheme)): +def auth_required(token: str = Depends(required_oauth2_scheme)): """Dependency to check Authorization header.""" if config("SCIM_AUTH_TYPE") == "OAuth2": payload = verify_access_token(token) return payload["tenant_id"] + + +optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + + +def auth_optional(token: str | None = Depends(optional_oauth2_scheme)): + if token is None: + return None + try: + tenant_id = auth_required(token) + return tenant_id + except HTTPException: + return None diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py deleted file mode 100644 index 8a0975492..000000000 --- a/ee/api/routers/scim.py +++ /dev/null @@ -1,466 +0,0 @@ -import logging -import re -import uuid -from typing import Optional - -from decouple import config -from fastapi import Depends, HTTPException, Header, Query, Response -from fastapi.responses import JSONResponse -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from pydantic import BaseModel, Field - -import schemas -from chalicelib.core import users, roles, tenants -from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token -from routers.base import get_routers - - -logger = logging.getLogger(__name__) - -public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") - - -"""Authentication endpoints""" - -class RefreshRequest(BaseModel): - refresh_token: str - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - -# Login endpoint to generate tokens -@public_app.post("/token") -async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()): - subdomain = host.split(".")[0] - - # Missing authentication part, to add - if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"): - raise HTTPException(status_code=401, detail="Invalid credentials") - - subdomain = "Openreplay EE" - tenant = tenants.get_by_name(subdomain) - access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) - - return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} - -# Refresh token endpoint -@public_app.post("/refresh") -async def refresh_token(r: RefreshRequest): - - payload = verify_refresh_token(r.refresh_token) - new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) - - return {"access_token": new_access_token, "token_type": "Bearer"} - -""" -User endpoints -""" - -class Name(BaseModel): - givenName: str - familyName: str - -class Email(BaseModel): - primary: bool - value: str - type: str - -class UserRequest(BaseModel): - schemas: list[str] - userName: str - name: Name - emails: list[Email] - displayName: str - locale: str - externalId: str - groups: list[dict] - password: str = Field(default=None) - active: bool - -class UserResponse(BaseModel): - schemas: list[str] - id: str - userName: str - name: Name - emails: list[Email] # ignore for now - displayName: str - locale: str - externalId: str - active: bool - groups: list[dict] - meta: dict = Field(default={"resourceType": "User"}) - -class PatchUserRequest(BaseModel): - schemas: list[str] - Operations: list[dict] - - -@public_app.get("/Users", dependencies=[Depends(auth_required)]) -async def get_users( - start_index: int = Query(1, alias="startIndex"), - count: Optional[int] = 
Query(None, alias="count"), - email: Optional[str] = Query(None, alias="filter"), -): - """Get SCIM Users""" - if email: - email = email.split(" ")[2].strip('"') - result_users = users.get_users_paginated(start_index, count, email) - - serialized_users = [] - for user in result_users: - serialized_users.append( - UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = user["email"], - name = Name.model_validate(user["data"]["name"]), - emails = [Email.model_validate(user["data"]["emails"])], - displayName = user["name"], - locale = user["data"]["locale"], - externalId = user["internalId"], - active = True, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ).model_dump(mode='json') - ) - return JSONResponse( - status_code=200, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], - "totalResults": len(serialized_users), - "startIndex": start_index, - "itemsPerPage": len(serialized_users), - "Resources": serialized_users, - }, - ) - -@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)]) -def get_user(user_id: str): - """Get SCIM User""" - tenant_id = 1 - user = users.get_by_uuid(user_id, tenant_id) - if not user: - return JSONResponse( - status_code=404, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "User not found", - "status": 404, - } - ) - - res = UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = user["email"], - name = Name.model_validate(user["data"]["name"]), - emails = [Email.model_validate(user["data"]["emails"])], - displayName = user["name"], - locale = user["data"]["locale"], - externalId = user["internalId"], - active = True, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ) - return JSONResponse(status_code=201, content=res.model_dump(mode='json')) - - -@public_app.post("/Users", dependencies=[Depends(auth_required)]) -async def create_user(r: UserRequest): - """Create SCIM User""" - tenant_id = 1 - existing_user = users.get_by_email_only(r.userName) - deleted_user = users.get_deleted_user_by_email(r.userName) - - if existing_user: - return JSONResponse( - status_code = 409, - content = { - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "User already exists in the database.", - "status": 409, - } - ) - elif deleted_user: - user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id) - user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, - display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), - origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) - else: - try: - user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, - display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), - origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - res = UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = r.userName, - name = r.name, - emails = r.emails, - displayName = r.displayName, - locale = r.locale, - externalId = 
r.externalId, - active = r.active, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ) - return JSONResponse(status_code=201, content=res.model_dump(mode='json')) - - - -@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) -def update_user(user_id: str, r: UserRequest): - """Update SCIM User""" - tenant_id = 1 - user = users.get_by_uuid(user_id, tenant_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - - changes = r.model_dump(mode='json', exclude={"schemas", "emails", "name", "locale", "groups", "password", "active"}) # some of these should be added later if necessary - nested_changes = r.model_dump(mode='json', include={"name", "emails"}) - mapping = {"userName": "email", "displayName": "name", "externalId": "internal_id"} # mapping between scim schema field names and local database model, can be done as config? - for k, v in mapping.items(): - if k in changes: - changes[v] = changes.pop(k) - changes["data"] = {} - for k, v in nested_changes.items(): - value_to_insert = v[0] if k == "emails" else v - changes["data"][k] = value_to_insert - try: - users.update(tenant_id, user["userId"], changes) - res = UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], - userName = r.userName, - name = r.name, - emails = r.emails, - displayName = r.displayName, - locale = r.locale, - externalId = r.externalId, - active = r.active, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ) - - return JSONResponse(status_code=201, content=res.model_dump(mode='json')) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)]) -def deactivate_user(user_id: str, r: PatchUserRequest): - """Deactivate user, soft-delete""" - tenant_id = 1 - active = r.model_dump(mode='json')["Operations"][0]["value"]["active"] - if active: - raise HTTPException(status_code=404, detail="Activating user is not supported") - user = users.get_by_uuid(user_id, tenant_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - users.delete_member_as_admin(tenant_id, user["userId"]) - - return Response(status_code=204, content="") - -@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)]) -def delete_user(user_uuid: str): - """Delete user from database, hard-delete""" - tenant_id = 1 - user = users.get_by_uuid(user_uuid, tenant_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - - users.__hard_delete_user_uuid(user_uuid) - return Response(status_code=204, content="") - - - -""" -Group endpoints -""" - -class Operation(BaseModel): - op: str - path: str = Field(default=None) - value: list[dict] | dict = Field(default=None) - -class GroupGetResponse(BaseModel): - schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"]) - totalResults: int - startIndex: int - itemsPerPage: int - resources: list = Field(alias="Resources") - -class GroupRequest(BaseModel): - schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"]) - displayName: str = Field(default=None) - members: list = Field(default=None) - operations: list[Operation] = Field(default=None, alias="Operations") - -class GroupPatchRequest(BaseModel): - schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:PatchOp"]) - operations: list[Operation] = 
Field(alias="Operations") - -class GroupResponse(BaseModel): - schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"]) - id: str - displayName: str - members: list - meta: dict = Field(default={"resourceType": "Group"}) - - -@public_app.get("/Groups", dependencies=[Depends(auth_required)]) -def get_groups( - start_index: int = Query(1, alias="startIndex"), - count: Optional[int] = Query(None, alias="count"), - group_name: Optional[str] = Query(None, alias="filter"), - ): - """Get groups""" - tenant_id = 1 - res = [] - if group_name: - group_name = group_name.split(" ")[2].strip('"') - - groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name) - res = [{ - "id": group["data"]["groupId"], - "meta": { - "created": group["createdAt"], - "lastModified": "", # not currently a field - "version": "v1.0" - }, - "displayName": group["name"] - } for group in groups - ] - return JSONResponse( - status_code=200, - content=GroupGetResponse( - totalResults=len(groups), - startIndex=start_index, - itemsPerPage=len(groups), - Resources=res - ).model_dump(mode='json')) - -@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)]) -def get_group(group_id: str): - """Get a group by id""" - tenant_id = 1 - group = roles.get_role_by_group_id(tenant_id, group_id) - if not group: - raise HTTPException(status_code=404, detail="Group not found") - members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"]) - members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members] - - return JSONResponse( - status_code=200, - content=GroupResponse( - id=group["data"]["groupId"], - displayName=group["name"], - members=members, - ).model_dump(mode='json')) - -@public_app.post("/Groups", dependencies=[Depends(auth_required)]) -def create_group(r: GroupRequest): - """Create a group""" - tenant_id = 1 - member_role = roles.get_member_permissions(tenant_id) - try: - data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"]) # permissions by default are same as for member role - group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - added_members = [] - for member in r.members: - user = users.get_by_uuid(member["value"], tenant_id) - if user: - users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) - added_members.append({ - "value": user["data"]["userId"], - "display": user["name"] - }) - - return JSONResponse( - status_code=200, - content=GroupResponse( - id=group["data"]["groupId"], - displayName=group["name"], - members=added_members, - ).model_dump(mode='json')) - - -@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)]) -def update_put_group(group_id: str, r: GroupRequest): - """Update a group or members of the group (not used by anything yet)""" - tenant_id = 1 - group = roles.get_role_by_group_id(tenant_id, group_id) - if not group: - raise HTTPException(status_code=404, detail="Group not found") - - if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None: - roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"]) - return Response(status_code=200, content="") - - members = r.members - modified_members = [] - for member in members: - user = users.get_by_uuid(member["value"], tenant_id) - if user: - users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) - 
modified_members.append({ - "value": user["data"]["userId"], - "display": user["name"] - }) - - return JSONResponse( - status_code=200, - content=GroupResponse( - id=group_id, - displayName=group["name"], - members=modified_members, - ).model_dump(mode='json')) - - -@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)]) -def update_patch_group(group_id: str, r: GroupPatchRequest): - """Update a group or members of the group, used by AIW""" - tenant_id = 1 - group = roles.get_role_by_group_id(tenant_id, group_id) - if not group: - raise HTTPException(status_code=404, detail="Group not found") - if r.operations[0].op == "replace" and r.operations[0].path is None: - roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"]) - return Response(status_code=200, content="") - - modified_members = [] - for op in r.operations: - if op.op == "add" or op.op == "replace": - # Both methods work as "replace" - for u in op.value: - user = users.get_by_uuid(u["value"], tenant_id) - if user: - users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) - modified_members.append({ - "value": user["data"]["userId"], - "display": user["name"] - }) - elif op.op == "remove": - user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1) - roles.remove_group_membership(tenant_id, group_id, user_id) - return JSONResponse( - status_code=200, - content=GroupResponse( - id=group_id, - displayName=group["name"], - members=modified_members, - ).model_dump(mode='json')) - - -@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) -def delete_group(group_id: str): - """Delete a group, hard-delete""" - # possibly need to set the user's roles to default member role, instead of null - tenant_id = 1 - group = roles.get_role_by_group_id(tenant_id, group_id) - if not group: - raise HTTPException(status_code=404, detail="Group not found") - roles.delete_scim_group(tenant_id, group["data"]["groupId"]) - - return Response(status_code=200, content="") diff --git a/ee/api/routers/scim/__init__.py b/ee/api/routers/scim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ee/api/routers/scim/api.py b/ee/api/routers/scim/api.py new file mode 100644 index 000000000..9a58e4d37 --- /dev/null +++ b/ee/api/routers/scim/api.py @@ -0,0 +1,164 @@ +from scim2_server import utils + + +from routers.base import get_routers +from routers.scim.providers import MultiTenantProvider +from routers.scim.backends import PostgresBackend +from routers.scim.postgres_resource import PostgresResource +from routers.scim import users, groups, helpers +from urllib.parse import urlencode + +from chalicelib.utils import pg_client + +from fastapi import HTTPException, Request +from fastapi.responses import RedirectResponse +from chalicelib.utils.scim_auth import ( + create_tokens, + verify_refresh_token, +) + + +b = PostgresBackend() +b.register_postgres_resource( + "User", + PostgresResource( + query_resources=users.query_resources, + get_resource=users.get_resource, + create_resource=users.create_resource, + search_existing=users.search_existing, + restore_resource=users.restore_resource, + delete_resource=users.delete_resource, + update_resource=users.update_resource, + ), +) +b.register_postgres_resource( + "Group", + PostgresResource( + query_resources=groups.query_resources, + get_resource=groups.get_resource, + create_resource=groups.create_resource, + search_existing=groups.search_existing, + restore_resource=None, + 
delete_resource=groups.delete_resource, + update_resource=groups.update_resource, + ), +) + +scim_app = MultiTenantProvider(b) + +for schema in utils.load_default_schemas().values(): + scim_app.register_schema(schema) +for schema in helpers.load_custom_schemas().values(): + scim_app.register_schema(schema) +for resource_type in helpers.load_custom_resource_types().values(): + scim_app.register_resource_type(resource_type) + + +public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") + + +@public_app.post("/token") +async def post_token(r: Request): + form = await r.form() + + client_id = form.get("client_id") + client_secret = form.get("client_secret") + with pg_client.PostgresClient() as cur: + try: + cur.execute( + cur.mogrify( + """ + SELECT tenant_id + FROM public.tenants + WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s + """, + {"tenant_id": int(client_id), "tenant_key": client_secret}, + ) + ) + except ValueError: + raise HTTPException(status_code=401, detail="Invalid credentials") + + tenant = cur.fetchone() + if not tenant: + raise HTTPException(status_code=401, detail="Invalid credentials") + + grant_type = form.get("grant_type") + if grant_type == "refresh_token": + refresh_token = form.get("refresh_token") + verify_refresh_token(refresh_token) + else: + code = form.get("code") + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT * + FROM public.scim_auth_codes + WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE + """, + {"auth_code": code, "tenant_id": int(client_id)}, + ) + ) + if cur.fetchone() is None: + raise HTTPException( + status_code=401, detail="Invalid code/client_id pair" + ) + cur.execute( + cur.mogrify( + """ + UPDATE public.scim_auth_codes + SET used=TRUE + WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE + """, + {"auth_code": code, "tenant_id": int(client_id)}, + ) + ) + + access_token, refresh_token, expires_in = create_tokens( + tenant_id=tenant["tenant_id"] + ) + + return { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": expires_in, + "refresh_token": refresh_token, + } + + +# note(jon): this might be specific to okta. 
if so, we should probably put specify that in the endpoint +@public_app.get("/authorize") +async def get_authorize( + r: Request, + response_type: str, + client_id: str, + redirect_uri: str, + state: str | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + UPDATE public.scim_auth_codes + SET used=TRUE + WHERE tenant_id=%(tenant_id)s + """, + {"tenant_id": int(client_id)}, + ) + ) + cur.execute( + cur.mogrify( + """ + INSERT INTO public.scim_auth_codes (tenant_id) + VALUES (%(tenant_id)s) + RETURNING auth_code + """, + {"tenant_id": int(client_id)}, + ) + ) + code = cur.fetchone()["auth_code"] + params = {"code": code} + if state: + params["state"] = state + url = f"{redirect_uri}?{urlencode(params)}" + return RedirectResponse(url) diff --git a/ee/api/routers/scim/backends.py b/ee/api/routers/scim/backends.py new file mode 100644 index 000000000..85daf6d7c --- /dev/null +++ b/ee/api/routers/scim/backends.py @@ -0,0 +1,203 @@ +from scim2_server import backend +from scim2_server.filter import evaluate_filter +from scim2_server.utils import SCIMException + +from scim2_models import ( + SearchRequest, + Resource, + Context, + Error, +) +from scim2_filter_parser import lexer +from scim2_filter_parser.parser import SCIMParser +from routers.scim.postgres_resource import PostgresResource +from scim2_server.operators import ResolveSortOperator +import operator + + +class PostgresBackend(backend.Backend): + def __init__(self): + super().__init__() + self._postgres_resources = {} + + def register_postgres_resource( + self, resource_type_id: str, postgres_resource: PostgresResource + ): + self._postgres_resources[resource_type_id] = postgres_resource + + def query_resources( + self, + search_request: SearchRequest, + tenant_id: int, + resource_type_id: str | None = None, + ) -> tuple[int, list[Resource]]: + """Query the backend for a set of resources. + + :param search_request: SearchRequest instance describing the + query. + :param resource_type_id: ID of the resource type to query. If + None, all resource types are queried. + :return: A tuple of "total results" and a List of found + Resources. The List must contain a copy of resources. + Mutating elements in the List must not modify the data + stored in the backend. + :raises SCIMException: If the backend only supports querying for + one resource type at a time, setting resource_type_id to + None the backend may raise a + SCIMException(Error.make_too_many_error()). + """ + start_index = (search_request.start_index or 1) - 1 + + tree = None + if search_request.filter is not None: + token_stream = lexer.SCIMLexer().tokenize(search_request.filter) + tree = SCIMParser().parse(token_stream) + + # todo(jon): handle the case when resource_type_id is None. + # we're assuming it's never None for now. + # but, this is fine to leave as it doesn't seem to used or reached in + # any of my tests yet. 
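        # A minimal sketch of that unhandled branch (an assumption, not part of
        # this change): if resource_type_id=None ever needs to work, one option
        # is to fan the query out over every registered PostgresResource, e.g.
        #
        #     if resource_type_id is None:
        #         pooled = []
        #         for type_id, pg_resource in self._postgres_resources.items():
        #             model = self.get_model(type_id)
        #             pooled += [
        #                 model.model_validate(r, scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
        #                 for r in pg_resource.query_resources(tenant_id)
        #             ]
        #         resources = pooled
        #
        # The guard below keeps the current behaviour and simply raises.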
+ if not resource_type_id: + raise NotImplementedError + + resources = self._postgres_resources[resource_type_id].query_resources( + tenant_id + ) + model = self.get_model(resource_type_id) + resources = [ + model.model_validate(r, scim_ctx=Context.RESOURCE_QUERY_RESPONSE) + for r in resources + ] + resources = [r for r in resources if (tree is None or evaluate_filter(r, tree))] + + if search_request.sort_by is not None: + descending = search_request.sort_order == SearchRequest.SortOrder.descending + sort_operator = ResolveSortOperator(search_request.sort_by) + + # To ensure that unset attributes are sorted last (when ascending, as defined in the RFC), + # we have to divide the result set into a set and unset subset. + unset_values = [] + set_values = [] + for resource in resources: + result = sort_operator(resource) + if result is None: + unset_values.append(resource) + else: + set_values.append((resource, result)) + + set_values.sort(key=operator.itemgetter(1), reverse=descending) + set_values = [value[0] for value in set_values] + if descending: + resources = unset_values + set_values + else: + resources = set_values + unset_values + + found_resources = resources[start_index:] + if search_request.count is not None: + found_resources = resources[: search_request.count] + + return len(resources), found_resources + + def get_resource( + self, tenant_id: int, resource_type_id: str, object_id: str + ) -> Resource | None: + """Query the backend for a resources by its ID. + + :param resource_type_id: ID of the resource type to get the + object from. + :param object_id: ID of the object to get. + :return: The resource object if it exists, None otherwise. The + resource must be a copy, modifying it must not change the + data stored in the backend. + """ + resource = self._postgres_resources[resource_type_id].get_resource( + object_id, tenant_id + ) + if resource: + model = self.get_model(resource_type_id) + resource = model.model_validate(resource) + return resource + + def delete_resource( + self, tenant_id: int, resource_type_id: str, object_id: str + ) -> bool: + """Delete a resource. + + :param resource_type_id: ID of the resource type to delete the + object from. + :param object_id: ID of the object to delete. + :return: True if the resource was deleted, False otherwise. + """ + resource = self.get_resource(tenant_id, resource_type_id, object_id) + if resource: + self._postgres_resources[resource_type_id].delete_resource( + object_id, tenant_id + ) + return True + return False + + def create_resource( + self, tenant_id: int, resource_type_id: str, resource: Resource + ) -> Resource | None: + """Create a resource. + + :param resource_type_id: ID of the resource type to create. + :param resource: Resource to create. + :return: The created resource. Creation should set system- + defined attributes (ID, Metadata). May be the same object + that is passed in. 
+ """ + model = self.get_model(resource_type_id) + existing = self._postgres_resources[resource_type_id].search_existing( + tenant_id, resource + ) + if existing: + existing = model.model_validate(existing) + if existing.active: + raise SCIMException(Error.make_uniqueness_error()) + resource = self._postgres_resources[resource_type_id].restore_resource( + tenant_id, resource + ) + else: + resource = self._postgres_resources[resource_type_id].create_resource( + tenant_id, resource + ) + resource = model.model_validate(resource) + return resource + + def update_resource( + self, tenant_id: int, resource_type_id: str, resource: Resource + ) -> Resource | None: + """Update a resource. The resource is identified by its ID. + + :param resource_type_id: ID of the resource type to update. + :param resource: Resource to update. + :return: The updated resource. Updating should update the + "meta.lastModified" data. May be the same object that is + passed in. + """ + model = self.get_model(resource_type_id) + existing = self._postgres_resources[resource_type_id].search_existing( + tenant_id, resource + ) + if existing: + existing = model.model_validate(existing) + if existing.active: + if existing.id != resource.id: + raise SCIMException(Error.make_uniqueness_error()) + resource = self._postgres_resources[resource_type_id].update_resource( + tenant_id, resource + ) + else: + self._postgres_resources[resource_type_id].delete_resource( + existing.id, tenant_id + ) + resource = self._postgres_resources[resource_type_id].update_resource( + resource.id, tenant_id, resource + ) + else: + resource = self._postgres_resources[resource_type_id].update_resource( + tenant_id, resource + ) + resource = model.model_validate(resource) + return resource diff --git a/ee/api/routers/scim/fixtures/custom_resource_types.json b/ee/api/routers/scim/fixtures/custom_resource_types.json new file mode 100644 index 000000000..0c6e718d3 --- /dev/null +++ b/ee/api/routers/scim/fixtures/custom_resource_types.json @@ -0,0 +1,36 @@ +[{ + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], + "id": "User", + "name": "User", + "endpoint": "/Users", + "description": "User Account", + "schema": "urn:ietf:params:scim:schemas:core:2.0:User", + "schemaExtensions": [ + { + "schema": + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User", + "required": true + }, + { + "schema": + "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User", + "required": true + } + ], + "meta": { + "location": "/v2/ResourceTypes/User", + "resourceType": "ResourceType" + } + }, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], + "id": "Group", + "name": "Group", + "endpoint": "/Groups", + "description": "Group", + "schema": "urn:ietf:params:scim:schemas:core:2.0:Group", + "meta": { + "location": "/v2/ResourceTypes/Group", + "resourceType": "ResourceType" + } + }] diff --git a/ee/api/routers/scim/fixtures/custom_schemas.json b/ee/api/routers/scim/fixtures/custom_schemas.json new file mode 100644 index 000000000..eb8e8da37 --- /dev/null +++ b/ee/api/routers/scim/fixtures/custom_schemas.json @@ -0,0 +1,32 @@ +[ + { + "id": "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User", + "name": "OpenreplayUser", + "description": "Openreplay User Account Extension", + "attributes": [ + { + "name": "permissions", + "type": "string", + "multiValued": true, + "description": "A list of permissions for the User that represent a thing the User is capable of doing.", + "required": false, + "canonicalValues": 
["SESSION_REPLAY", "DEV_TOOLS", "METRICS", "ASSIST_LIVE", "ASSIST_CALL", "SPOT", "SPOT_PUBLIC"], + "mutability": "readWrite", + "returned": "default" + }, + { + "name": "projectKeys", + "type": "string", + "multiValued": true, + "description": "A list of project keys for the User that represent a project the User is allowed to work on.", + "required": false, + "mutability": "readWrite", + "returned": "default" + } + ], + "meta": { + "resourceType": "Schema", + "location": "/v2/Schemas/urn:ietf:params:scim:schemas:extension:openreplay:2.0:User" + } + } +] diff --git a/ee/api/routers/scim/groups.py b/ee/api/routers/scim/groups.py new file mode 100644 index 000000000..dda2517a3 --- /dev/null +++ b/ee/api/routers/scim/groups.py @@ -0,0 +1,203 @@ +from typing import Any +from datetime import datetime +from psycopg2.extensions import AsIs + +from chalicelib.utils import pg_client +from routers.scim import helpers + +from scim2_models import Error, Resource +from scim2_server.utils import SCIMException + + +def convert_provider_resource_to_client_resource( + provider_resource: dict, +) -> dict: + members = provider_resource["users"] or [] + return { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], + "id": str(provider_resource["role_id"]), + "meta": { + "resourceType": "Group", + "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": provider_resource["updated_at"].strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + }, + "displayName": provider_resource["name"], + "members": [ + { + "value": str(member["user_id"]), + "$ref": f"Users/{member['user_id']}", + "type": "User", + } + for member in members + ], + } + + +def query_resources(tenant_id: int) -> list[dict]: + query = _main_select_query(tenant_id) + with pg_client.PostgresClient() as cur: + cur.execute(query) + items = cur.fetchall() + return [convert_provider_resource_to_client_resource(item) for item in items] + + +def get_resource(resource_id: str, tenant_id: int) -> dict | None: + query = _main_select_query(tenant_id, resource_id) + with pg_client.PostgresClient() as cur: + cur.execute(query) + item = cur.fetchone() + if item: + return convert_provider_resource_to_client_resource(item) + return None + + +def delete_resource(resource_id: str, tenant_id: int) -> None: + _update_resource_sql( + resource_id=resource_id, + tenant_id=tenant_id, + deleted_at=datetime.now(), + ) + + +def search_existing(tenant_id: int, resource: Resource) -> dict | None: + return None + + +def create_resource(tenant_id: int, resource: Resource) -> dict: + with pg_client.PostgresClient() as cur: + user_ids = ( + [int(x.value) for x in resource.members] if resource.members else None + ) + user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur) + try: + cur.execute( + cur.mogrify( + """ + INSERT INTO public.roles ( + name, + tenant_id + ) + VALUES ( + %(name)s, + %(tenant_id)s + ) + RETURNING role_id + """, + { + "name": resource.display_name, + "tenant_id": tenant_id, + }, + ) + ) + except Exception: + raise SCIMException(Error.make_invalid_value_error()) + role_id = cur.fetchone()["role_id"] + cur.execute( + f""" + UPDATE public.users + SET + updated_at = now(), + role_id = {role_id} + WHERE users.user_id = ANY({user_id_clause}) + """ + ) + cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1") + item = cur.fetchone() + return convert_provider_resource_to_client_resource(item) + + +def update_resource(tenant_id: int, resource: Resource) -> dict | None: + item = _update_resource_sql( + 
resource_id=resource.id, + tenant_id=tenant_id, + name=resource.display_name, + user_ids=[int(x.value) for x in resource.members], + deleted_at=None, + ) + return convert_provider_resource_to_client_resource(item) + + +def _main_select_query(tenant_id: int, resource_id: str | None = None) -> str: + where_and_clauses = [ + f"roles.tenant_id = {tenant_id}", + "roles.deleted_at IS NULL", + ] + if resource_id is not None: + where_and_clauses.append(f"roles.role_id = {resource_id}") + where_clause = " AND ".join(where_and_clauses) + return f""" + SELECT + roles.*, + COALESCE( + ( + SELECT json_agg(users) + FROM public.users + WHERE users.role_id = roles.role_id + ), + '[]' + ) AS users, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM public.roles + WHERE {where_clause} + """ + + +def _update_resource_sql( + resource_id: int, + tenant_id: int, + user_ids: list[int] | None = None, + **kwargs: dict[str, Any], +) -> dict[str, Any]: + with pg_client.PostgresClient() as cur: + kwargs["updated_at"] = datetime.now() + set_fragments = [ + cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8") + for k, v in kwargs.items() + ] + set_clause = ", ".join(set_fragments) + user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur) + cur.execute( + f""" + UPDATE public.users + SET + updated_at = now(), + role_id = NULL + WHERE + users.role_id = {resource_id} + AND users.user_id != ALL({user_id_clause}) + RETURNING * + """ + ) + cur.execute( + f""" + UPDATE public.users + SET + updated_at = now(), + role_id = {resource_id} + WHERE + (users.role_id != {resource_id} OR users.role_id IS NULL) + AND users.user_id = ANY({user_id_clause}) + RETURNING * + """ + ) + cur.execute( + f""" + UPDATE public.roles + SET {set_clause} + WHERE + roles.role_id = {resource_id} + AND roles.tenant_id = {tenant_id} + """ + ) + cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1") + return cur.fetchone() diff --git a/ee/api/routers/scim/helpers.py b/ee/api/routers/scim/helpers.py new file mode 100644 index 000000000..85c9c49f2 --- /dev/null +++ b/ee/api/routers/scim/helpers.py @@ -0,0 +1,44 @@ +from typing import Any, Literal +from chalicelib.utils import pg_client +from scim2_models import Schema, Resource, ResourceType +import os +import json + + +def safe_mogrify_array( + items: list[Any] | None, + array_type: Literal["varchar", "int"], + cursor: pg_client.PostgresClient, +) -> str: + items = items or [] + fragments = [cursor.mogrify("%s", (item,)).decode("utf-8") for item in items] + result = f"ARRAY[{', '.join(fragments)}]::{array_type}[]" + return result + + +def load_json_resource(json_name: str) -> dict: + with open(json_name) as f: + return json.load(f) + + +def load_scim_resource( + json_name: str, type_: type[Resource] +) -> dict[str, type[Resource]]: + ret = {} + definitions = load_json_resource(json_name) + for d in definitions: + model = type_.model_validate(d) + ret[model.id] = model + return ret + + +def load_custom_schemas() -> dict[str, Schema]: + json_name = os.path.join("routers", "scim", "fixtures", "custom_schemas.json") + return load_scim_resource(json_name, Schema) + + +def load_custom_resource_types() -> dict[str, ResourceType]: + json_name = os.path.join( + "routers", "scim", "fixtures", "custom_resource_types.json" + ) + return load_scim_resource(json_name, ResourceType) diff --git 
a/ee/api/routers/scim/postgres_resource.py b/ee/api/routers/scim/postgres_resource.py new file mode 100644 index 000000000..c06bc17f6 --- /dev/null +++ b/ee/api/routers/scim/postgres_resource.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Callable +from scim2_models import Resource + + +@dataclass +class PostgresResource: + query_resources: Callable[[int], list[dict]] + get_resource: Callable[[str, int], dict | None] + create_resource: Callable[[int, Resource], dict] + search_existing: Callable[[int, Resource], dict | None] + restore_resource: Callable[[int, Resource], dict] | None + delete_resource: Callable[[str, int], None] + update_resource: Callable[[int, Resource], dict] diff --git a/ee/api/routers/scim/providers.py b/ee/api/routers/scim/providers.py new file mode 100644 index 000000000..be24acc6e --- /dev/null +++ b/ee/api/routers/scim/providers.py @@ -0,0 +1,280 @@ +import traceback +from typing import Union + +from scim2_server import provider + +from scim2_models import ( + AuthenticationScheme, + ServiceProviderConfig, + Patch, + Bulk, + Filter, + Sort, + ETag, + Meta, + ChangePassword, + Error, + ResourceType, + Context, + ListResponse, + PatchOp, +) + +from werkzeug import Request, Response +from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed +from pydantic import ValidationError +from werkzeug.routing.exceptions import RequestRedirect +from scim2_server.utils import SCIMException, merge_resources + +from chalicelib.utils.scim_auth import verify_access_token + + +class MultiTenantProvider(provider.SCIMProvider): + def check_auth(self, request: Request): + auth = request.headers.get("Authorization") + if not auth or not auth.startswith("Bearer "): + return None + token = auth[len("Bearer ") :] + if not token: + return Response( + "Missing or invalid Authorization header", + status=401, + headers={"WWW-Authenticate": 'Bearer realm="login required"'}, + ) + payload = verify_access_token(token) + tenant_id = payload["tenant_id"] + return tenant_id + + def get_service_provider_config(self): + auth_schemes = [ + AuthenticationScheme( + type="oauthbearertoken", + name="OAuth Bearer Token", + description="Authentication scheme using the OAuth Bearer Token Standard. 
The access token should be sent in the 'Authorization' header using the Bearer schema.", + spec_uri="https://datatracker.ietf.org/doc/html/rfc6750", + ) + ] + return ServiceProviderConfig( + # todo(jon): write correct documentation uri + documentation_uri="https://www.example.com/", + patch=Patch(supported=True), + bulk=Bulk(supported=False), + filter=Filter(supported=True, max_results=1000), + change_password=ChangePassword(supported=False), + sort=Sort(supported=True), + etag=ETag(supported=False), + authentication_schemes=auth_schemes, + meta=Meta(resource_type="ServiceProviderConfig"), + ) + + def query_resource( + self, request: Request, tenant_id: int, resource: ResourceType | None + ): + search_request = self.build_search_request(request) + + kwargs = {} + if resource is not None: + kwargs["resource_type_id"] = resource.id + total_results, results = self.backend.query_resources( + search_request=search_request, tenant_id=tenant_id, **kwargs + ) + for r in results: + self.adjust_location(request, r) + + resources = [ + s.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + attributes=search_request.attributes, + excluded_attributes=search_request.excluded_attributes, + ) + for s in results + ] + + return ListResponse[Union[tuple(self.backend.get_models())]]( # noqa: UP007 + total_results=total_results, + items_per_page=search_request.count, + start_index=search_request.start_index, + resources=resources, + ) + + def call_resource( + self, request: Request, resource_endpoint: str, **kwargs + ) -> Response: + resource_type = self.backend.get_resource_type_by_endpoint( + "/" + resource_endpoint + ) + if not resource_type: + raise NotFound + + if "tenant_id" not in kwargs: + raise Exception + tenant_id = kwargs["tenant_id"] + + match request.method: + case "GET": + return self.make_response( + self.query_resource(request, tenant_id, resource_type).model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + ) + ) + case _: # "POST" + payload = request.json + resource = self.backend.get_model(resource_type.id).model_validate( + payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST + ) + created_resource = self.backend.create_resource( + tenant_id, + resource_type.id, + resource, + ) + self.adjust_location(request, created_resource) + return self.make_response( + created_resource.model_dump( + scim_ctx=Context.RESOURCE_CREATION_RESPONSE + ), + status=201, + headers={"Location": created_resource.meta.location}, + ) + + def call_single_resource( + self, request: Request, resource_endpoint: str, resource_id: str, **kwargs + ) -> Response: + find_endpoint = "/" + resource_endpoint + resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint) + if not resource_type: + raise NotFound + + if "tenant_id" not in kwargs: + raise Exception + tenant_id = kwargs["tenant_id"] + + match request.method: + case "GET": + if resource := self.backend.get_resource( + tenant_id, resource_type.id, resource_id + ): + if self.continue_etag(request, resource): + response_args = self.get_attrs_from_request(request) + self.adjust_location(request, resource) + return self.make_response( + resource.model_dump( + scim_ctx=Context.RESOURCE_QUERY_RESPONSE, + **response_args, + ) + ) + else: + return self.make_response(None, status=304) + raise NotFound + case "DELETE": + if self.backend.delete_resource( + tenant_id, resource_type.id, resource_id + ): + return self.make_response(None, 204) + else: + raise NotFound + case "PUT": + response_args = self.get_attrs_from_request(request) + resource = 
self.backend.get_resource( + tenant_id, resource_type.id, resource_id + ) + if resource is None: + raise NotFound + if not self.continue_etag(request, resource): + raise PreconditionFailed + + updated_attributes = self.backend.get_model( + resource_type.id + ).model_validate(request.json) + merge_resources(resource, updated_attributes) + updated = self.backend.update_resource( + tenant_id, resource_type.id, resource + ) + self.adjust_location(request, updated) + return self.make_response( + updated.model_dump( + scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE, + **response_args, + ) + ) + case _: # "PATCH" + payload = request.json + # MS Entra sometimes passes a "id" attribute + if "id" in payload: + del payload["id"] + operations = payload.get("Operations", []) + for operation in operations: + if "name" in operation: + # MS Entra sometimes passes a "name" attribute + del operation["name"] + + patch_operation = PatchOp.model_validate(payload) + response_args = self.get_attrs_from_request(request) + resource = self.backend.get_resource( + tenant_id, resource_type.id, resource_id + ) + if resource is None: + raise NotFound + if not self.continue_etag(request, resource): + raise PreconditionFailed + + self.apply_patch_operation(resource, patch_operation) + updated = self.backend.update_resource( + tenant_id, resource_type.id, resource + ) + + if response_args: + self.adjust_location(request, updated) + return self.make_response( + updated.model_dump( + scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE, + **response_args, + ) + ) + else: + # RFC 7644, section 3.5.2: + # A PATCH operation MAY return a 204 (no content) + # if no attributes were requested + return self.make_response( + None, 204, headers={"ETag": updated.meta.version} + ) + + def wsgi_app(self, request: Request, environ): + try: + if environ.get("PATH_INFO", "").endswith(".scim"): + # RFC 7644, Section 3.8 + # Just strip .scim suffix, the provider always returns application/scim+json + environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim") + urls = self.url_map.bind_to_environ(environ) + endpoint, args = urls.match() + + tenant_id = None + if endpoint != "service_provider_config": + # RFC7643, Section 5: skip authentication for ServiceProviderConfig + tenant_id = self.check_auth(request) + + # Wrap the entire call in a transaction. Should probably be optimized (use transaction only when necessary). 
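+ # ServiceProviderConfig and Schema calls are not tenant-scoped; every other endpoint + # receives the tenant_id that check_auth resolved from the bearer token, and auth + # failures raised there are turned into SCIM Error responses by the handlers below.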
+ with self.backend: + if endpoint == "service_provider_config" or endpoint == "schema": + response = getattr(self, f"call_{endpoint}")(request, **args) + else: + response = getattr(self, f"call_{endpoint}")( + request, **args, tenant_id=tenant_id + ) + return response + except RequestRedirect as e: + # urls.match may cause a redirect, handle it as a special case of HTTPException + self.log.exception(e) + return e.get_response(environ) + except HTTPException as e: + self.log.exception(e) + return self.make_error(Error(status=e.code, detail=e.description)) + except SCIMException as e: + self.log.exception(e) + return self.make_error(e.scim_error) + except ValidationError as e: + self.log.exception(e) + return self.make_error(Error(status=400, detail=str(e))) + except Exception as e: + self.log.exception(e) + tb = traceback.format_exc() + return self.make_error(Error(status=500, detail=str(e) + "\n" + tb)) diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py new file mode 100644 index 000000000..7e42c4ebd --- /dev/null +++ b/ee/api/routers/scim/users.py @@ -0,0 +1,371 @@ +from routers.scim import helpers + +from chalicelib.utils import pg_client +from scim2_models import Resource + + +def convert_provider_resource_to_client_resource( + provider_resource: dict, +) -> dict: + groups = [] + if provider_resource["role_id"]: + groups.append( + { + "value": str(provider_resource["role_id"]), + "$ref": f"Groups/{provider_resource['role_id']}", + } + ) + return { + "id": str(provider_resource["user_id"]), + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User", + "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User", + ], + "meta": { + "resourceType": "User", + "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": provider_resource["updated_at"].strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + }, + "userName": provider_resource["email"], + "externalId": provider_resource["internal_id"], + "name": { + "formatted": provider_resource["name"], + }, + "displayName": provider_resource["name"] or provider_resource["email"], + "active": provider_resource["deleted_at"] is None, + "groups": groups, + "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User": { + "permissions": provider_resource.get("permissions") or [], + "projectKeys": provider_resource.get("project_keys") or [], + }, + } + + +def query_resources(tenant_id: int) -> list[dict]: + with pg_client.PostgresClient() as cur: + cur.execute( + f""" + SELECT + users.*, + roles.permissions AS permissions, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM public.users + LEFT JOIN public.roles ON roles.role_id = users.role_id + WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL + """ + ) + items = cur.fetchall() + return [convert_provider_resource_to_client_resource(item) for item in items] + + +def get_resource(resource_id: str, tenant_id: int) -> dict | None: + with pg_client.PostgresClient() as cur: + cur.execute( + f""" + SELECT + users.*, + roles.permissions AS permissions, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM public.users + LEFT JOIN public.roles ON roles.role_id = 
users.role_id + WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL AND users.user_id = {int(resource_id)} + """ + ) + item = cur.fetchone() + if item: + return convert_provider_resource_to_client_resource(item) + return None + + +def delete_resource(resource_id: str, tenant_id: int) -> None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + UPDATE public.users + SET + deleted_at = now(), + updated_at = now() + WHERE users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + """, + {"user_id": resource_id, "tenant_id": tenant_id}, + ) + ) + + +def search_existing(tenant_id: int, resource: Resource) -> dict | None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT * + FROM public.users + WHERE email = %(email)s + """, + {"email": resource.user_name}, + ) + ) + item = cur.fetchone() + if item: + return convert_provider_resource_to_client_resource(item) + return None + + +def restore_resource(tenant_id: int, resource: Resource) -> dict | None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT role_id + FROM public.users + WHERE user_id = %(user_id)s + """, + {"user_id": resource.id}, + ) + ) + item = cur.fetchone() + if item and item["role_id"] is not None: + _update_role_projects_and_permissions( + item["role_id"], + resource.OpenreplayUser.project_keys, + resource.OpenreplayUser.permissions, + cur, + ) + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + tenant_id = %(tenant_id)s, + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + deleted_at = NULL, + created_at = now(), + updated_at = now(), + api_key = default, + jwt_iat = NULL, + weekly_report = default + WHERE users.email = %(email)s + RETURNING * + ) + SELECT + u.*, + roles.permissions AS permissions, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM u + LEFT JOIN public.roles ON roles.role_id = u.role_id + """, + { + "tenant_id": tenant_id, + "email": resource.user_name, + "name": " ".join( + [ + x + for x in [ + resource.name.honorific_prefix, + resource.name.given_name, + resource.name.middle_name, + resource.name.family_name, + resource.name.honorific_suffix, + ] + if x + ] + ) + if resource.name + else "", + "internal_id": resource.external_id, + }, + ) + ) + item = cur.fetchone() + return convert_provider_resource_to_client_resource(item) + + +def create_resource(tenant_id: int, resource: Resource) -> dict: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + INSERT INTO public.users ( + tenant_id, + email, + name, + internal_id + ) + VALUES ( + %(tenant_id)s, + %(email)s, + %(name)s, + %(internal_id)s + ) + RETURNING * + ) + SELECT * + FROM u + """, + { + "tenant_id": tenant_id, + "email": resource.user_name, + "name": " ".join( + [ + x + for x in [ + resource.name.honorific_prefix, + resource.name.given_name, + resource.name.middle_name, + resource.name.family_name, + resource.name.honorific_suffix, + ] + if x + ] + ) + if resource.name + else "", + "internal_id": resource.external_id, + }, + ) + ) + item = cur.fetchone() + return convert_provider_resource_to_client_resource(item) + + +def update_resource(tenant_id: int, resource: Resource) -> dict | None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT role_id + FROM public.users + WHERE user_id = %(user_id)s + """, +
{"user_id": resource.id}, + ) + ) + item = cur.fetchone() + if item and item["role_id"] is not None: + _update_role_projects_and_permissions( + item["role_id"], + resource.OpenreplayUser.project_keys, + resource.OpenreplayUser.permissions, + cur, + ) + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + tenant_id = %(tenant_id)s, + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + updated_at = now() + WHERE user_id = %(user_id)s + RETURNING * + ) + SELECT + u.*, + roles.permissions AS permissions, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM u + LEFT JOIN public.roles ON roles.role_id = u.role_id + """, + { + "user_id": resource.id, + "tenant_id": tenant_id, + "email": resource.user_name, + "name": " ".join( + [ + x + for x in [ + resource.name.honorific_prefix, + resource.name.given_name, + resource.name.middle_name, + resource.name.family_name, + resource.name.honorific_suffix, + ] + if x + ] + ) + if resource.name + else "", + "internal_id": resource.external_id, + }, + ) + ) + item = cur.fetchone() + return convert_provider_resource_to_client_resource(item) + + +def _update_role_projects_and_permissions( + role_id: int, + project_keys: list[str] | None, + permissions: list[str] | None, + cur: pg_client.PostgresClient, +) -> None: + all_projects = "true" if not project_keys else "false" + project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur) + permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur) + cur.execute( + f""" + UPDATE public.roles + SET + updated_at = now(), + all_projects = {all_projects}, + permissions = {permission_clause} + WHERE role_id = {role_id} + RETURNING * + """ + ) + cur.execute( + f""" + DELETE FROM public.roles_projects + WHERE roles_projects.role_id = {role_id} + """ + ) + cur.execute( + f""" + INSERT INTO public.roles_projects (role_id, project_id) + SELECT {role_id}, projects.project_id + FROM public.projects + WHERE projects.project_key = ANY({project_key_clause}) + """ + ) diff --git a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql index caf4e7467..f0ea95b84 100644 --- a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql +++ b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql @@ -108,6 +108,16 @@ CREATE TABLE public.tenants ); +CREATE TABLE public.scim_auth_codes +( + auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY, + tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE, + auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20), + created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), + used bool NOT NULL DEFAULT FALSE +); + + CREATE TABLE public.roles ( role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY, @@ -118,6 +128,7 @@ CREATE TABLE public.roles protected bool NOT NULL DEFAULT FALSE, all_projects bool NOT NULL DEFAULT TRUE, created_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()), + updated_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()), deleted_at timestamp NULL DEFAULT NULL, service_role bool NOT NULL DEFAULT FALSE ); @@ -132,6 +143,7 @@ CREATE TABLE public.users role user_role NOT NULL DEFAULT 'member', name text NOT NULL, created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 
'utc'), + updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), deleted_at timestamp without time zone NULL DEFAULT NULL, api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL, jwt_iat timestamp without time zone NULL DEFAULT NULL,