From 464b9b1b4774e1121ea92489677ccc917f6af362 Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Fri, 18 Apr 2025 10:37:31 +0200 Subject: [PATCH] reformat files and remove unnecessary imports --- ee/api/chalicelib/core/roles.py | 107 ++++-- ee/api/chalicelib/core/users.py | 549 ++++++++++++++++++--------- ee/api/chalicelib/utils/scim_auth.py | 23 +- ee/api/routers/scim.py | 124 +++--- ee/api/routers/scim_constants.py | 63 +-- ee/api/routers/scim_helpers.py | 22 +- 6 files changed, 588 insertions(+), 300 deletions(-) diff --git a/ee/api/chalicelib/core/roles.py b/ee/api/chalicelib/core/roles.py index 0bad7aade..321ca1102 100644 --- a/ee/api/chalicelib/core/roles.py +++ b/ee/api/chalicelib/core/roles.py @@ -9,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool: with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""SELECT EXISTS(SELECT 1 + query = cur.mogrify( + f"""SELECT EXISTS(SELECT 1 FROM public.roles WHERE tenant_id = %(tenant_id)s AND name ILIKE %(name)s AND deleted_at ISNULL {"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""", - {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id}) + {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id}, + ) cur.execute(query=query) row = cur.fetchone() return row["exists"] @@ -27,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema): if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists." + ) if not data.all_projects and (data.projects is None or len(data.projects) == 0): return {"errors": ["must specify a project or all projects"]} if data.projects is not None and len(data.projects) > 0 and not data.all_projects: - data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) + data.projects = projects.is_authorized_batch( + project_ids=data.projects, tenant_id=tenant_id + ) with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.roles WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = TRUE LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is protected"]} - query = cur.mogrify("""UPDATE public.roles + query = cur.mogrify( + """UPDATE public.roles SET name= %(name)s, description= %(description)s, permissions= %(permissions)s, @@ -56,23 +65,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema): RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id) FROM roles_projects WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""", - {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()}) + {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()}, + ) cur.execute(query=query) row = cur.fetchone() row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) if not data.all_projects: d_projects = [i for i in row["projects"] if i not in data.projects] if len(d_projects) > 0: - query = cur.mogrify("""DELETE FROM roles_projects + query = cur.mogrify( + """DELETE FROM roles_projects WHERE role_id=%(role_id)s AND project_id IN %(project_ids)s""", - {"role_id": role_id, "project_ids": tuple(d_projects)}) + {"role_id": role_id, "project_ids": tuple(d_projects)}, + ) cur.execute(query=query) n_projects = [i for i in data.projects if i not in row["projects"]] if len(n_projects) > 0: - query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) + query = cur.mogrify( + f"""INSERT INTO roles_projects(role_id, project_id) VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""", - {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}}) + { + "role_id": role_id, + **{f"project_id_{i}": p for i, p in enumerate(n_projects)}, + }, + ) cur.execute(query=query) row["projects"] = data.projects @@ -86,28 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema): return {"errors": ["unauthorized"]} if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists." + ) if not data.all_projects and (data.projects is None or len(data.projects) == 0): return {"errors": ["must specify a project or all projects"]} if data.projects is not None and len(data.projects) > 0 and not data.all_projects: - data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) + data.projects = projects.is_authorized_batch( + project_ids=data.projects, tenant_id=tenant_id + ) with pg_client.PostgresClient() as cur: - query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects) + query = cur.mogrify( + """INSERT INTO roles(tenant_id, name, description, permissions, all_projects) VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s) RETURNING *;""", - {"tenant_id": tenant_id, "name": data.name, "description": data.description, - "permissions": data.permissions, "all_projects": data.all_projects}) + { + "tenant_id": tenant_id, + "name": data.name, + "description": data.description, + "permissions": data.permissions, + "all_projects": data.all_projects, + }, + ) cur.execute(query=query) row = cur.fetchone() row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) row["projects"] = [] if not data.all_projects: role_id = row["role_id"] - query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) + query = cur.mogrify( + f"""INSERT INTO roles_projects(role_id, project_id) VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))} RETURNING project_id;""", - {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}}) + { + "role_id": role_id, + **{f"project_id_{i}": p for i, p in enumerate(data.projects)}, + }, + ) cur.execute(query=query) row["projects"] = [r["project_id"] for r in cur.fetchall()] return helper.dict_to_camel_case(row) @@ -115,7 +148,8 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema): def get_roles(tenant_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects + query = cur.mogrify( + """SELECT roles.*, COALESCE(projects, '{}') AS projects FROM public.roles LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects FROM roles_projects @@ -126,7 +160,8 @@ def get_roles(tenant_id): AND deleted_at IS NULL AND not service_role ORDER BY role_id;""", - {"tenant_id": tenant_id}) + {"tenant_id": tenant_id}, + ) cur.execute(query=query) rows = cur.fetchall() for r in rows: @@ -136,12 +171,14 @@ def get_roles(tenant_id): def get_role_by_name(tenant_id, name): with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT * + query = cur.mogrify( + """SELECT * FROM public.roles WHERE tenant_id =%(tenant_id)s AND deleted_at IS NULL AND name ILIKE %(name)s;""", - {"tenant_id": tenant_id, "name": name}) + {"tenant_id": tenant_id, "name": name}, + ) cur.execute(query=query) row = cur.fetchone() if row is not None: @@ -155,45 +192,53 @@ def delete(tenant_id, user_id, role_id): if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.roles WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = TRUE LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is protected"]} - query = cur.mogrify("""SELECT 1 + query = cur.mogrify( + """SELECT 1 FROM public.users WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is already attached to other user(s)"]} - query = cur.mogrify("""UPDATE public.roles + query = cur.mogrify( + """UPDATE public.roles SET deleted_at = timezone('utc'::text, now()) WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s AND protected = FALSE;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) return get_roles(tenant_id=tenant_id) def get_role(tenant_id, role_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT roles.* + query = cur.mogrify( + """SELECT roles.* FROM public.roles WHERE tenant_id =%(tenant_id)s AND deleted_at IS NULL AND not service_role AND role_id = %(role_id)s LIMIT 1;""", - {"tenant_id": tenant_id, "role_id": role_id}) + {"tenant_id": tenant_id, "role_id": role_id}, + ) cur.execute(query=query) row = cur.fetchone() if row is not None: diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 52bfb1485..a57a194e5 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -1,7 +1,7 @@ import json import logging import secrets -from typing import Any, Optional +from typing import Optional from decouple import config from fastapi import BackgroundTasks, HTTPException @@ -25,9 +25,12 @@ def __generate_invitation_token(): return secrets.token_urlsafe(64) -def create_new_member(tenant_id, email, invitation_token, admin, name, owner=False, role_id=None): +def create_new_member( + tenant_id, email, invitation_token, admin, name, owner=False, role_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, role_id) VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, @@ -53,10 +56,16 @@ def create_new_member(tenant_id, email, invitation_token, admin, name, owner=Fal roles.name AS role_name, TRUE AS has_password FROM au,u LEFT JOIN roles USING(tenant_id) WHERE roles.role_id IS NULL OR roles.role_id = (SELECT u.role_id FROM u);""", - {"tenant_id": tenant_id, "email": email, - "role": "owner" if owner else "admin" if admin else "member", "name": name, - "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), - "invitation_token": invitation_token, "role_id": role_id}) + { + "tenant_id": tenant_id, + "email": email, + "role": "owner" if owner else "admin" if admin else "member", + "name": name, + "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), + "invitation_token": invitation_token, + "role_id": role_id, + }, + ) cur.execute(query) row = helper.dict_to_camel_case(cur.fetchone()) if row: @@ -64,9 +73,12 @@ def create_new_member(tenant_id, email, invitation_token, admin, name, owner=Fal return row -def restore_member(tenant_id, user_id, email, invitation_token, admin, name, owner=False, role_id=None): +def restore_member( + tenant_id, user_id, email, invitation_token, admin, name, owner=False, role_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS (UPDATE public.users SET name= %(name)s, role = %(role)s, @@ -106,9 +118,16 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own TRUE AS has_password FROM au,u LEFT JOIN roles USING(tenant_id) WHERE roles.role_id IS NULL OR roles.role_id = (SELECT u.role_id FROM u);""", - {"tenant_id": tenant_id, "user_id": user_id, "email": email, - "role": "owner" if owner else "admin" if admin else "member", "name": name, - "role_id": role_id, "invitation_token": invitation_token}) + { + "tenant_id": tenant_id, + "user_id": user_id, + "email": email, + "role": "owner" if owner else "admin" if admin else "member", + "name": name, + "role_id": role_id, + "invitation_token": invitation_token, + }, + ) cur.execute(query) result = cur.fetchone() result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"]) @@ -118,7 +137,8 @@ def restore_member(tenant_id, user_id, email, invitation_token, admin, name, own def generate_new_invitation(user_id): invitation_token = __generate_invitation_token() with pg_client.PostgresClient() as cur: - query = cur.mogrify("""\ + query = cur.mogrify( + """\ UPDATE public.basic_authentication SET invitation_token = %(invitation_token)s, invited_at = timezone('utc'::text, now()), @@ -126,10 +146,9 @@ def generate_new_invitation(user_id): change_pwd_token = NULL WHERE user_id=%(user_id)s RETURNING invitation_token;""", - {"user_id": user_id, "invitation_token": invitation_token}) - cur.execute( - query + {"user_id": user_id, "invitation_token": invitation_token}, ) + cur.execute(query) return __get_invitation_link(cur.fetchone().pop("invitation_token")) @@ -168,14 +187,20 @@ def update_scim_user( "tenant_id": tenant_id, "user_id": user_id, "email": email, - } + }, ) ) return helper.dict_to_camel_case(cur.fetchone()) def update(tenant_id, user_id, changes, output=True): - AUTH_KEYS = ["password", "invitationToken", "invitedAt", "changePwdExpireAt", "changePwdToken"] + AUTH_KEYS = [ + "password", + "invitationToken", + "invitedAt", + "changePwdExpireAt", + "changePwdToken", + ] if len(changes.keys()) == 0: return None @@ -184,7 +209,9 @@ def update(tenant_id, user_id, changes, output=True): for key in changes.keys(): if key in AUTH_KEYS: if key == "password": - sub_query_bauth.append("password = crypt(%(password)s, gen_salt('bf', 12))") + sub_query_bauth.append( + "password = crypt(%(password)s, gen_salt('bf', 12))" + ) sub_query_bauth.append("changed_at = timezone('utc'::text, now())") else: sub_query_bauth.append(f"{helper.key_to_snake_case(key)} = %({key})s") @@ -193,7 +220,9 @@ def update(tenant_id, user_id, changes, output=True): sub_query_users.append("""role_id=(SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))""") - elif key == "data": # this is hardcoded, maybe a generic solution would be better + elif ( + key == "data" + ): # this is hardcoded, maybe a generic solution would be better sub_query_users.append(f"data = data || %({(key)})s") else: sub_query_users.append(f"{helper.key_to_snake_case(key)} = %({key})s") @@ -202,26 +231,35 @@ def update(tenant_id, user_id, changes, output=True): changes["data"] = Json(changes["data"]) with pg_client.PostgresClient() as cur: if len(sub_query_users) > 0: - query = cur.mogrify(f"""\ + query = cur.mogrify( + f"""\ UPDATE public.users SET {" ,".join(sub_query_users)} WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s;""", - {"tenant_id": tenant_id, "user_id": user_id, **changes}) + {"tenant_id": tenant_id, "user_id": user_id, **changes}, + ) cur.execute(query) if len(sub_query_bauth) > 0: - query = cur.mogrify(f"""\ + query = cur.mogrify( + f"""\ UPDATE public.basic_authentication SET {" ,".join(sub_query_bauth)} WHERE basic_authentication.user_id = %(user_id)s;""", - {"tenant_id": tenant_id, "user_id": user_id, **changes}) + {"tenant_id": tenant_id, "user_id": user_id, **changes}, + ) cur.execute(query) if not output: return None return get(user_id=user_id, tenant_id=tenant_id) -def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, background_tasks: BackgroundTasks): +def create_member( + tenant_id, + user_id, + data: schemas.CreateMemberSchema, + background_tasks: BackgroundTasks, +): admin = get(tenant_id=tenant_id, user_id=user_id) if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} @@ -235,7 +273,9 @@ def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, backgrou data.name = data.email role_id = data.roleId if role_id is None: - role_id = roles.get_role_by_name(tenant_id=tenant_id, name="member").get("roleId") + role_id = roles.get_role_by_name(tenant_id=tenant_id, name="member").get( + "roleId" + ) else: role = roles.get_role(tenant_id=tenant_id, role_id=role_id) if role is None: @@ -245,22 +285,46 @@ def create_member(tenant_id, user_id, data: schemas.CreateMemberSchema, backgrou invitation_token = __generate_invitation_token() user = get_deleted_user_by_email(email=data.email) if user is not None and user["tenantId"] == tenant_id: - new_member = restore_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - admin=data.admin, name=data.name, user_id=user["userId"], role_id=role_id) + new_member = restore_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + user_id=user["userId"], + role_id=role_id, + ) elif user is not None: __hard_delete_user(user_id=user["userId"]) - new_member = create_new_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - admin=data.admin, name=data.name, role_id=role_id) + new_member = create_new_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + role_id=role_id, + ) else: - new_member = create_new_member(tenant_id=tenant_id, email=data.email, invitation_token=invitation_token, - admin=data.admin, name=data.name, role_id=role_id) - new_member["invitationLink"] = __get_invitation_link(new_member.pop("invitationToken")) - background_tasks.add_task(email_helper.send_team_invitation, **{ - "recipient": data.email, - "invitation_link": new_member["invitationLink"], - "client_id": tenants.get_by_tenant_id(tenant_id)["name"], - "sender_name": admin["name"] - }) + new_member = create_new_member( + tenant_id=tenant_id, + email=data.email, + invitation_token=invitation_token, + admin=data.admin, + name=data.name, + role_id=role_id, + ) + new_member["invitationLink"] = __get_invitation_link( + new_member.pop("invitationToken") + ) + background_tasks.add_task( + email_helper.send_team_invitation, + **{ + "recipient": data.email, + "invitation_link": new_member["invitationLink"], + "client_id": tenants.get_by_tenant_id(tenant_id)["name"], + "sender_name": admin["name"], + }, + ) return {"data": new_member} @@ -271,14 +335,14 @@ def __get_invitation_link(invitation_token): def allow_password_change(user_id, delta_min=10): pass_token = secrets.token_urlsafe(8) with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.basic_authentication + query = cur.mogrify( + """UPDATE public.basic_authentication SET change_pwd_expire_at = timezone('utc'::text, now()+INTERVAL '%(delta)s MINUTES'), change_pwd_token = %(pass_token)s WHERE user_id = %(user_id)s""", - {"user_id": user_id, "delta": delta_min, "pass_token": pass_token}) - cur.execute( - query + {"user_id": user_id, "delta": delta_min, "pass_token": pass_token}, ) + cur.execute(query) return pass_token @@ -286,7 +350,7 @@ def get(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, email, @@ -310,11 +374,13 @@ def get(user_id, tenant_id): AND users.deleted_at IS NULL AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id}) + {"userId": user_id, "tenant_id": tenant_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) + def get_scim_user_by_id(user_id, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -341,12 +407,13 @@ def generate_new_api_key(user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""UPDATE public.users + """UPDATE public.users SET api_key=generate_api_key(20) WHERE users.user_id = %(userId)s AND deleted_at IS NULL RETURNING api_key;""", - {"userId": user_id}) + {"userId": user_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -356,7 +423,7 @@ def __get_account_info(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT users.name, + """SELECT users.name, tenants.name AS tenant_name, tenants.opt_out FROM public.users INNER JOIN public.tenants USING (tenant_id) @@ -364,14 +431,19 @@ def __get_account_info(tenant_id, user_id): AND tenants.tenant_id= %(tenantId)s AND tenants.deleted_at IS NULL AND users.deleted_at IS NULL;""", - {"tenantId": tenant_id, "userId": user_id}) + {"tenantId": tenant_id, "userId": user_id}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) def edit_account(user_id, tenant_id, changes: schemas.EditAccountSchema): - if changes.opt_out is not None or changes.tenantName is not None and len(changes.tenantName) > 0: + if ( + changes.opt_out is not None + or changes.tenantName is not None + and len(changes.tenantName) > 0 + ): user = get(user_id=user_id, tenant_id=tenant_id) if not user["superAdmin"] and not user["admin"]: return {"errors": ["unauthorized"]} @@ -391,7 +463,9 @@ def edit_account(user_id, tenant_id, changes: schemas.EditAccountSchema): return {"data": __get_account_info(tenant_id=tenant_id, user_id=user_id)} -def edit_member(user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, editor_id): +def edit_member( + user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, editor_id +): user = get_member(user_id=user_id_to_update, tenant_id=tenant_id) _changes = {} if editor_id != user_id_to_update: @@ -429,7 +503,12 @@ def edit_member(user_id_to_update, tenant_id, changes: schemas.EditMemberSchema, return {"errors": ["invalid role"]} if len(_changes.keys()) > 0: - update(tenant_id=tenant_id, user_id=user_id_to_update, changes=_changes, output=False) + update( + tenant_id=tenant_id, + user_id=user_id_to_update, + changes=_changes, + output=False, + ) return {"data": get_member(user_id=user_id_to_update, tenant_id=tenant_id)} return {"data": user} @@ -443,7 +522,7 @@ def get_existing_scim_user_by_unique_values(email): FROM public.users WHERE users.email = %(email)s """, - {"email": email} + {"email": email}, ) ) return helper.dict_to_camel_case(cur.fetchone()) @@ -453,7 +532,7 @@ def get_by_email_only(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, users.email, @@ -472,11 +551,13 @@ def get_by_email_only(email): WHERE users.email = %(email)s AND users.deleted_at IS NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) + def get_users_paginated(start_index, tenant_id, count=None): with pg_client.PostgresClient() as cur: cur.execute( @@ -490,11 +571,7 @@ def get_users_paginated(start_index, tenant_id, count=None): LIMIT %(limit)s OFFSET %(offset)s; """, - { - "offset": start_index - 1, - "limit": count, - "tenant_id": tenant_id - }, + {"offset": start_index - 1, "limit": count, "tenant_id": tenant_id}, ) ) r = cur.fetchall() @@ -505,7 +582,7 @@ def get_member(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.email, users.role, @@ -525,7 +602,8 @@ def get_member(tenant_id, user_id): LEFT JOIN public.roles USING (role_id) WHERE users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL AND users.user_id = %(user_id)s ORDER BY name, user_id""", - {"tenant_id": tenant_id, "user_id": user_id}) + {"tenant_id": tenant_id, "user_id": user_id}, + ) ) u = helper.dict_to_camel_case(cur.fetchone()) if u: @@ -542,7 +620,7 @@ def get_members(tenant_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.email, users.role, @@ -564,7 +642,8 @@ def get_members(tenant_id): AND users.deleted_at IS NULL AND NOT users.service_account ORDER BY name, user_id""", - {"tenant_id": tenant_id}) + {"tenant_id": tenant_id}, + ) ) r = cur.fetchall() if len(r): @@ -572,7 +651,9 @@ def get_members(tenant_id): for u in r: u["createdAt"] = TimeUTC.datetime_to_timestamp(u["createdAt"]) if u["invitationToken"]: - u["invitationLink"] = __get_invitation_link(u.pop("invitationToken")) + u["invitationLink"] = __get_invitation_link( + u.pop("invitationToken") + ) else: u["invitationLink"] = None return r @@ -597,20 +678,26 @@ def delete_member(user_id, tenant_id, id_to_delete): with pg_client.PostgresClient() as cur: cur.execute( - cur.mogrify(f"""UPDATE public.users + cur.mogrify( + """UPDATE public.users SET deleted_at = timezone('utc'::text, now()), jwt_iat= NULL, jwt_refresh_jti= NULL, jwt_refresh_iat= NULL, role_id=NULL WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) + {"user_id": id_to_delete, "tenant_id": tenant_id}, + ) + ) cur.execute( - cur.mogrify(f"""UPDATE public.basic_authentication + cur.mogrify( + """UPDATE public.basic_authentication SET password= NULL, invitation_token= NULL, invited_at= NULL, changed_at= NULL, change_pwd_expire_at= NULL, change_pwd_token= NULL WHERE user_id=%(user_id)s;""", - {"user_id": id_to_delete, "tenant_id": tenant_id})) + {"user_id": id_to_delete, "tenant_id": tenant_id}, + ) + ) return {"data": get_members(tenant_id=tenant_id)} @@ -618,11 +705,21 @@ def change_password(tenant_id, user_id, email, old_password, new_password): item = get(tenant_id=tenant_id, user_id=user_id) if item is None: return {"errors": ["access denied"]} - if item["origin"] is not None and config("enforce_SSO", cast=bool, default=False) \ - and not item["superAdmin"] and helper.is_saml2_available(): - return {"errors": ["Please use your SSO to change your password, enforced by admin"]} + if ( + item["origin"] is not None + and config("enforce_SSO", cast=bool, default=False) + and not item["superAdmin"] + and helper.is_saml2_available() + ): + return { + "errors": ["Please use your SSO to change your password, enforced by admin"] + } if item["origin"] is not None and item["hasPassword"] is False: - return {"errors": ["cannot change your password because you are logged-in from an SSO service"]} + return { + "errors": [ + "cannot change your password because you are logged-in from an SSO service" + ] + } if old_password == new_password: return {"errors": ["old and new password are the same"]} auth = authenticate(email, old_password, for_change_password=True) @@ -630,23 +727,7 @@ def change_password(tenant_id, user_id, email, old_password, new_password): return {"errors": ["wrong password"]} changes = {"password": new_password} user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) - r = authenticate(user['email'], new_password) - - return { - "jwt": r.pop("jwt"), - "refreshToken": r.pop("refreshToken"), - "refreshTokenMaxAge": r.pop("refreshTokenMaxAge"), - "spotJwt": r.pop("spotJwt"), - "spotRefreshToken": r.pop("spotRefreshToken"), - "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), - "tenantId": tenant_id - } - - -def set_password_invitation(tenant_id, user_id, new_password): - changes = {"password": new_password} - user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) - r = authenticate(user['email'], new_password) + r = authenticate(user["email"], new_password) return { "jwt": r.pop("jwt"), @@ -656,7 +737,23 @@ def set_password_invitation(tenant_id, user_id, new_password): "spotRefreshToken": r.pop("spotRefreshToken"), "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), "tenantId": tenant_id, - **r + } + + +def set_password_invitation(tenant_id, user_id, new_password): + changes = {"password": new_password} + user = update(tenant_id=tenant_id, user_id=user_id, changes=changes) + r = authenticate(user["email"], new_password) + + return { + "jwt": r.pop("jwt"), + "refreshToken": r.pop("refreshToken"), + "refreshTokenMaxAge": r.pop("refreshTokenMaxAge"), + "spotJwt": r.pop("spotJwt"), + "spotRefreshToken": r.pop("spotRefreshToken"), + "spotRefreshTokenMaxAge": r.pop("spotRefreshTokenMaxAge"), + "tenantId": tenant_id, + **r, } @@ -664,14 +761,15 @@ def email_exists(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT count(user_id) FROM public.users WHERE email = %(email)s AND deleted_at IS NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return r["count"] > 0 @@ -681,14 +779,15 @@ def get_deleted_user_by_email(email): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT * FROM public.users WHERE email = %(email)s AND deleted_at NOTNULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -698,7 +797,7 @@ def get_by_invitation_token(token, pass_token=None): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT *, DATE_PART('day',timezone('utc'::text, now()) \ - COALESCE(basic_authentication.invited_at,'2000-01-01'::timestamp ))>=1 AS expired_invitation, @@ -707,7 +806,8 @@ def get_by_invitation_token(token, pass_token=None): FROM public.users INNER JOIN public.basic_authentication USING(user_id) WHERE invitation_token = %(token)s {"AND change_pwd_token = %(pass_token)s" if pass_token else ""} LIMIT 1;""", - {"token": token, "pass_token": pass_token}) + {"token": token, "pass_token": pass_token}, + ) ) r = cur.fetchone() return helper.dict_to_camel_case(r) @@ -717,7 +817,7 @@ def auth_exists(user_id, tenant_id, jwt_iat) -> bool: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT user_id, + """SELECT user_id, EXTRACT(epoch FROM jwt_iat)::BIGINT AS jwt_iat, changed_at, service_account, @@ -728,26 +828,31 @@ def auth_exists(user_id, tenant_id, jwt_iat) -> bool: AND tenant_id = %(tenant_id)s AND deleted_at IS NULL LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id}) + {"userId": user_id, "tenant_id": tenant_id}, + ) ) r = cur.fetchone() - return r is not None \ - and (r["service_account"] and not r["has_basic_auth"] - or r.get("jwt_iat") is not None \ - and (abs(jwt_iat - r["jwt_iat"]) <= 1)) + return r is not None and ( + r["service_account"] + and not r["has_basic_auth"] + or r.get("jwt_iat") is not None + and (abs(jwt_iat - r["jwt_iat"]) <= 1) + ) def refresh_auth_exists(user_id, tenant_id, jwt_jti=None): with pg_client.PostgresClient() as cur: cur.execute( - cur.mogrify(f"""SELECT user_id + cur.mogrify( + """SELECT user_id FROM public.users WHERE user_id = %(userId)s AND tenant_id= %(tenant_id)s AND deleted_at IS NULL AND jwt_refresh_jti = %(jwt_jti)s LIMIT 1;""", - {"userId": user_id, "tenant_id": tenant_id, "jwt_jti": jwt_jti}) + {"userId": user_id, "tenant_id": tenant_id, "jwt_jti": jwt_jti}, + ) ) r = cur.fetchone() return r is not None @@ -785,7 +890,8 @@ class RefreshSpotJWTs(FullLoginJWTs): def change_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.users + query = cur.mogrify( + """UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), jwt_refresh_jti = 0, jwt_refresh_iat = timezone('utc'::text, now()-INTERVAL '10s'), @@ -799,7 +905,8 @@ def change_jwt_iat_jti(user_id): EXTRACT (epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat, spot_jwt_refresh_jti, EXTRACT (epoch FROM spot_jwt_refresh_iat)::BIGINT AS spot_jwt_refresh_iat;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) row = cur.fetchone() return FullLoginJWTs(**row) @@ -807,14 +914,16 @@ def change_jwt_iat_jti(user_id): def refresh_jwt_iat_jti(user_id): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""UPDATE public.users + query = cur.mogrify( + """UPDATE public.users SET jwt_iat = timezone('utc'::text, now()-INTERVAL '10s'), jwt_refresh_jti = jwt_refresh_jti + 1 WHERE user_id = %(user_id)s RETURNING EXTRACT (epoch FROM jwt_iat)::BIGINT AS jwt_iat, jwt_refresh_jti, EXTRACT (epoch FROM jwt_refresh_iat)::BIGINT AS jwt_refresh_iat;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) row = cur.fetchone() return RefreshLoginJWTs(**row) @@ -825,7 +934,7 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No return {"errors": ["must sign-in with SSO, enforced by admin"]} with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, users.role, @@ -845,19 +954,21 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No AND basic_authentication.user_id = (SELECT su.user_id FROM public.users AS su WHERE su.email=%(email)s AND su.deleted_at IS NULL LIMIT 1) AND (roles.role_id IS NULL OR roles.deleted_at IS NULL) LIMIT 1;""", - {"email": email, "password": password}) + {"email": email, "password": password}, + ) cur.execute(query) r = cur.fetchone() if r is None and helper.is_saml2_available(): query = cur.mogrify( - f"""SELECT 1 + """SELECT 1 FROM public.users WHERE users.email = %(email)s AND users.deleted_at IS NULL AND users.origin IS NOT NULL LIMIT 1;""", - {"email": email}) + {"email": email}, + ) cur.execute(query) if cur.fetchone() is not None: return {"errors": ["must sign-in with SSO"]} @@ -867,33 +978,51 @@ def authenticate(email, password, for_change_password=False) -> dict | bool | No return True r = helper.dict_to_camel_case(r) if r["serviceAccount"]: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="service account is not authorized to login") - elif config("enforce_SSO", cast=bool, default=False) and helper.is_saml2_available(): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="service account is not authorized to login", + ) + elif ( + config("enforce_SSO", cast=bool, default=False) + and helper.is_saml2_available() + ): return {"errors": ["must sign-in with SSO, enforced by admin"]} - j_r = change_jwt_iat_jti(user_id=r['userId']) + j_r = change_jwt_iat_jti(user_id=r["userId"]) response = { - "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], iat=j_r.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.jwt_refresh_iat, - aud=AUDIENCE, - jwt_jti=j_r.jwt_refresh_jti, - for_spot=False), + "jwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_iat, + aud=AUDIENCE, + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j_r.jwt_refresh_jti, + for_spot=False, + ), "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), "email": email, - "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE, for_spot=True), - "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.spot_jwt_refresh_iat, - aud=spot.AUDIENCE, - jwt_jti=j_r.spot_jwt_refresh_jti, - for_spot=True), + "spotJwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_iat, + aud=spot.AUDIENCE, + for_spot=True, + ), + "spotRefreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti, + for_spot=True, + ), "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int), - **r + **r, } return response @@ -904,7 +1033,7 @@ def get_user_role(tenant_id, user_id): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.email, users.role, @@ -918,14 +1047,16 @@ def get_user_role(tenant_id, user_id): AND users.user_id=%(user_id)s AND users.tenant_id=%(tenant_id)s LIMIT 1""", - {"tenant_id": tenant_id, "user_id": user_id}) + {"tenant_id": tenant_id, "user_id": user_id}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id=None): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, @@ -947,14 +1078,21 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, origin FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now()})}) - cur.execute( - query + { + "tenant_id": tenant_id, + "email": email, + "internal_id": internal_id, + "role": "admin" if admin else "member", + "name": name, + "origin": origin, + "role_id": role_id, + "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), + }, ) + cur.execute(query) return helper.dict_to_camel_case(cur.fetchone()) + def create_scim_user( email, name, @@ -984,7 +1122,7 @@ def create_scim_user( "tenant_id": tenant_id, "email": email, "name": name, - } + }, ) ) return helper.dict_to_camel_case(cur.fetchone()) @@ -1001,7 +1139,7 @@ def soft_delete_scim_user_by_id(user_id, tenant_id): users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s """, - {"user_id": user_id, "tenant_id": tenant_id} + {"user_id": user_id, "tenant_id": tenant_id}, ) ) @@ -1009,13 +1147,13 @@ def soft_delete_scim_user_by_id(user_id, tenant_id): def __hard_delete_user(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""DELETE FROM public.users + """DELETE FROM public.users WHERE users.user_id = %(user_id)s AND users.deleted_at IS NOT NULL ;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) - def logout(user_id: int): with pg_client.PostgresClient() as cur: query = cur.mogrify( @@ -1023,25 +1161,33 @@ def logout(user_id: int): SET jwt_iat = NULL, jwt_refresh_jti = NULL, jwt_refresh_iat = NULL, spot_jwt_iat = NULL, spot_jwt_refresh_jti = NULL, spot_jwt_refresh_iat = NULL WHERE user_id = %(user_id)s;""", - {"user_id": user_id}) + {"user_id": user_id}, + ) cur.execute(query) def refresh(user_id: int, tenant_id: int = -1) -> dict: j = refresh_jwt_iat_jti(user_id=user_id) return { - "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat, - aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti), - "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat), + "jwt": authorizers.generate_jwt( + user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat, aud=AUDIENCE + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=user_id, + tenant_id=tenant_id, + iat=j.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j.jwt_refresh_jti, + ), + "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) + - (j.jwt_iat - j.jwt_refresh_iat), } def authenticate_sso(email: str, internal_id: str): with pg_client.PostgresClient() as cur: query = cur.mogrify( - f"""SELECT + """SELECT users.user_id, users.tenant_id, users.role, @@ -1054,7 +1200,8 @@ def authenticate_sso(email: str, internal_id: str): service_account FROM public.users AS users WHERE users.email = %(email)s AND internal_id = %(internal_id)s;""", - {"email": email, "internal_id": internal_id}) + {"email": email, "internal_id": internal_id}, + ) cur.execute(query) r = cur.fetchone() @@ -1062,33 +1209,56 @@ def authenticate_sso(email: str, internal_id: str): if r is not None: r = helper.dict_to_camel_case(r) if r["serviceAccount"]: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="service account is not authorized to login") - j_r = change_jwt_iat_jti(user_id=r['userId']) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="service account is not authorized to login", + ) + j_r = change_jwt_iat_jti(user_id=r["userId"]) response = { - "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], iat=j_r.jwt_iat, - aud=AUDIENCE), - "refreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.jwt_refresh_iat, - aud=AUDIENCE, jwt_jti=j_r.jwt_refresh_jti), + "jwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_iat, + aud=AUDIENCE, + ), + "refreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.jwt_refresh_iat, + aud=AUDIENCE, + jwt_jti=j_r.jwt_refresh_jti, + ), "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), - "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE, for_spot=True), - "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.spot_jwt_refresh_iat, - aud=spot.AUDIENCE, - jwt_jti=j_r.spot_jwt_refresh_jti, for_spot=True), - "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) + "spotJwt": authorizers.generate_jwt( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_iat, + aud=spot.AUDIENCE, + for_spot=True, + ), + "spotRefreshToken": authorizers.generate_jwt_refresh( + user_id=r["userId"], + tenant_id=r["tenantId"], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti, + for_spot=True, + ), + "spotRefreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int), } return response - logger.warning(f"SSO user not found with email: {email} and internal_id: {internal_id}") + logger.warning( + f"SSO user not found with email: {email} and internal_id: {internal_id}" + ) return None -def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, internal_id=None): +def restore_sso_user( + user_id, tenant_id, email, admin, name, origin, role_id, internal_id=None +): with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ + query = cur.mogrify( + """\ WITH u AS ( UPDATE public.users SET tenant_id= %(tenant_id)s, @@ -1128,13 +1298,19 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, origin FROM u;""", - {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), - "user_id": user_id}) - cur.execute( - query + { + "tenant_id": tenant_id, + "email": email, + "internal_id": internal_id, + "role": "admin" if admin else "member", + "name": name, + "origin": origin, + "role_id": role_id, + "data": json.dumps({"lastAnnouncementView": TimeUTC.now()}), + "user_id": user_id, + }, ) + cur.execute(query) return helper.dict_to_camel_case(cur.fetchone()) @@ -1161,23 +1337,25 @@ def restore_scim_user( SELECT * FROM u; """, - {"tenant_id": tenant_id, "user_id": user_id} + {"tenant_id": tenant_id, "user_id": user_id}, ) ) return helper.dict_to_camel_case(cur.fetchone()) + def get_user_settings(user_id): # read user settings from users.settings:jsonb column with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""SELECT + """SELECT settings FROM public.users WHERE users.deleted_at IS NULL AND users.user_id=%(user_id)s LIMIT 1""", - {"user_id": user_id}) + {"user_id": user_id}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) @@ -1209,11 +1387,12 @@ def update_user_settings(user_id, settings): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( - f"""UPDATE public.users + """UPDATE public.users SET settings = %(settings)s WHERE users.user_id = %(user_id)s AND deleted_at IS NULL RETURNING settings;""", - {"user_id": user_id, "settings": json.dumps(settings)}) + {"user_id": user_id, "settings": json.dumps(settings)}, + ) ) return helper.dict_to_camel_case(cur.fetchone()) diff --git a/ee/api/chalicelib/utils/scim_auth.py b/ee/api/chalicelib/utils/scim_auth.py index fb73b9dbb..a8deaa136 100644 --- a/ee/api/chalicelib/utils/scim_auth.py +++ b/ee/api/chalicelib/utils/scim_auth.py @@ -13,8 +13,8 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY") ALGORITHM = config("SCIM_JWT_ALGORITHM") ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS")) REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS")) -AUDIENCE="okta_client" -ISSUER=config("JWT_ISSUER"), +AUDIENCE = "okta_client" +ISSUER = (config("JWT_ISSUER"),) # Simulated Okta Client Credentials # OKTA_CLIENT_ID = "okta-client" @@ -23,7 +23,7 @@ ISSUER=config("JWT_ISSUER"), # class TokenRequest(BaseModel): # client_id: str # client_secret: str - + # async def authenticate_client(token_request: TokenRequest): # """Validate Okta Client Credentials and issue JWT""" # if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET: @@ -31,6 +31,7 @@ ISSUER=config("JWT_ISSUER"), # return {"access_token": create_jwt(), "token_type": "bearer"} + def create_tokens(tenant_id): curr_time = time.time() access_payload = { @@ -38,7 +39,7 @@ def create_tokens(tenant_id): "sub": "scim_server", "aud": AUDIENCE, "iss": ISSUER, - "exp": "" + "exp": "", } access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS}) access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM) @@ -49,18 +50,24 @@ def create_tokens(tenant_id): return access_token, refresh_token + def verify_access_token(token: str): try: - payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + payload = jwt.decode( + token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE + ) return payload except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token") + def verify_refresh_token(token: str): try: - payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + payload = jwt.decode( + token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE + ) return payload except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") @@ -69,6 +76,8 @@ def verify_refresh_token(token: str): required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + # Authentication Dependency def auth_required(token: str = Depends(required_oauth2_scheme)): """Dependency to check Authorization header.""" @@ -78,6 +87,8 @@ def auth_required(token: str = Depends(required_oauth2_scheme)): optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + + def auth_optional(token: str | None = Depends(optional_oauth2_scheme)): if token is None: return None diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index ec287b511..40230c14f 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -1,7 +1,5 @@ import logging -import re -import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal import copy from datetime import datetime @@ -9,11 +7,15 @@ from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response, Request from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, field_serializer -import schemas -from chalicelib.core import users, roles, tenants -from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token +from chalicelib.core import users, tenants +from chalicelib.utils.scim_auth import ( + auth_optional, + auth_required, + create_tokens, + verify_refresh_token, +) from routers.base import get_routers from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG from routers import scim_helpers @@ -26,29 +28,41 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") """Authentication endpoints""" + class RefreshRequest(BaseModel): refresh_token: str + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + # Login endpoint to generate tokens @public_app.post("/token") -async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()): +async def login( + host: str = Header(..., alias="Host"), + form_data: OAuth2PasswordRequestForm = Depends(), +): subdomain = host.split(".")[0] # Missing authentication part, to add - if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"): + if form_data.username != config("SCIM_USER") or form_data.password != config( + "SCIM_PASSWORD" + ): raise HTTPException(status_code=401, detail="Invalid credentials") tenant = tenants.get_by_name(subdomain) access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) - return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + } + # Refresh token endpoint @public_app.post("/refresh") async def refresh_token(r: RefreshRequest): - payload = verify_refresh_token(r.refresh_token) new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) @@ -68,7 +82,7 @@ def _not_found_error_response(resource_id: str): "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], "detail": f"Resource {resource_id} not found", "status": "404", - } + }, ) @@ -80,7 +94,7 @@ def _uniqueness_error_response(): "detail": "One or more of the attribute values are already in use or are reserved.", "status": "409", "scimType": "uniqueness", - } + }, ) @@ -92,7 +106,7 @@ def _mutability_error_response(): "detail": "The attempted modification is not compatible with the target attribute's mutability or current state.", "status": "400", "scimType": "mutability", - } + }, ) @@ -105,7 +119,7 @@ async def get_resource_types(filter_param: str | None = Query(None, alias="filte "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], "detail": "Operation is not permitted based on the supplied authorization", "status": "403", - } + }, ) return JSONResponse( status_code=200, @@ -130,8 +144,7 @@ async def get_resource_type(resource_id: str): SCHEMA_IDS_TO_SCHEMA_DETAILS = { - schema_detail["id"]: schema_detail - for schema_detail in SCHEMAS + schema_detail["id"]: schema_detail for schema_detail in SCHEMAS } @@ -144,7 +157,7 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")): "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], "detail": "Operation is not permitted based on the supplied authorization", "status": "403", - } + }, ) return JSONResponse( status_code=200, @@ -154,9 +167,8 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")): "startIndex": 1, "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], "Resources": [ - value - for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items()) - ] + value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items()) + ], }, ) @@ -174,7 +186,9 @@ async def get_schema(schema_id: str): # note(jon): it was recommended to make this endpoint partially open # so that clients can view the `authenticationSchemes` prior to being authenticated. @public_app.get("/ServiceProviderConfig") -async def get_service_provider_config(r: Request, tenant_id: str | None = Depends(auth_optional)): +async def get_service_provider_config( + r: Request, tenant_id: str | None = Depends(auth_optional) +): content = copy.deepcopy(SERVICE_PROVIDER_CONFIG) content["meta"]["location"] = str(r.url) is_authenticated = tenant_id is not None @@ -193,6 +207,8 @@ async def get_service_provider_config(r: Request, tenant_id: str | None = Depend """ User endpoints """ + + class UserRequest(BaseModel): userName: str @@ -203,7 +219,9 @@ class PatchUserRequest(BaseModel): class ResourceMetaResponse(BaseModel): - resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None + resourceType: ( + Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None + ) = None created: datetime | None = None lastModified: datetime | None = None location: str | None = None @@ -231,12 +249,16 @@ class CommonResourceResponse(BaseModel): class UserResponse(CommonResourceResponse): - schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"] + schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ] userName: str | None = None class QueryResourceResponse(BaseModel): - schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"] + schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = [ + "urn:ietf:params:scim:api:messages:2.0:ListResponse" + ] totalResults: int # todo(jon): add the other schemas Resources: list[UserResponse] @@ -247,21 +269,33 @@ class QueryResourceResponse(BaseModel): MAX_USERS_PER_PAGE = 10 -def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse: - user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] +def _convert_db_user_to_scim_user( + db_user: dict[str, Any], + attributes: list[str] | None = None, + excluded_attributes: list[str] | None = None, +) -> UserResponse: + user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[ + "urn:ietf:params:scim:schemas:core:2.0:User" + ] all_attributes = scim_helpers.get_all_attribute_names(user_schema) attributes = attributes or all_attributes - always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema) + always_returned_attributes = ( + scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema) + ) included_attributes = list(set(attributes).union(set(always_returned_attributes))) excluded_attributes = excluded_attributes or [] - excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes))) + excluded_attributes = list( + set(excluded_attributes).difference(set(always_returned_attributes)) + ) scim_user = { "id": str(db_user["userId"]), "meta": { "resourceType": "User", "created": db_user["createdAt"], - "lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db - "location": f"Users/{db_user['userId']}" + "lastModified": db_user[ + "createdAt" + ], # todo(jon): we currently don't keep track of this in the db + "location": f"Users/{db_user['userId']}", }, "userName": db_user["email"], } @@ -272,14 +306,16 @@ def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] @public_app.get("/Users") async def get_users( - tenant_id = Depends(auth_required), + tenant_id=Depends(auth_required), requested_start_index: int = Query(1, alias="startIndex"), requested_items_per_page: int | None = Query(None, alias="count"), attributes: list[str] | None = Query(None), excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), ): start_index = max(1, requested_start_index) - items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE) + items_per_page = min( + max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE + ) # todo(jon): this might not be the most efficient thing to do. could be better to just do a count. # but this is the fastest thing at the moment just to test that it's working total_users = users.get_users_paginated(1, tenant_id) @@ -302,7 +338,7 @@ async def get_users( @public_app.get("/Users/{user_id}") def get_user( user_id: str, - tenant_id = Depends(auth_required), + tenant_id=Depends(auth_required), attributes: list[str] | None = Query(None), excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), ): @@ -311,13 +347,12 @@ def get_user( return _not_found_error_response(user_id) scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes) return JSONResponse( - status_code=200, - content=scim_user.model_dump(mode="json", exclude_none=True) + status_code=200, content=scim_user.model_dump(mode="json", exclude_none=True) ) @public_app.post("/Users") -async def create_user(r: UserRequest, tenant_id = Depends(auth_required)): +async def create_user(r: UserRequest, tenant_id=Depends(auth_required)): # note(jon): this method will return soft deleted users as well existing_db_user = users.get_existing_scim_user_by_unique_values(r.userName) if existing_db_user and existing_db_user["deletedAt"] is None: @@ -334,23 +369,26 @@ async def create_user(r: UserRequest, tenant_id = Depends(auth_required)): ) scim_user = _convert_db_user_to_scim_user(db_user) response = JSONResponse( - status_code=201, - content=scim_user.model_dump(mode="json", exclude_none=True) + status_code=201, content=scim_user.model_dump(mode="json", exclude_none=True) ) response.headers["Location"] = scim_user.meta.location return response @public_app.put("/Users/{user_id}") -def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required)): +def update_user(user_id: str, r: UserRequest, tenant_id=Depends(auth_required)): db_resource = users.get_scim_user_by_id(user_id, tenant_id) if not db_resource: return _not_found_error_response(user_id) - current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump(mode="json", exclude_none=True) + current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump( + mode="json", exclude_none=True + ) changes = r.model_dump(mode="json") schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] try: - valid_mutable_changes = scim_helpers.filter_mutable_attributes(schema, changes, current_scim_resource) + valid_mutable_changes = scim_helpers.filter_mutable_attributes( + schema, changes, current_scim_resource + ) except ValueError: # todo(jon): will need to add a test for this once we have an immutable field return _mutability_error_response() @@ -371,7 +409,7 @@ def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required) @public_app.delete("/Users/{user_id}") -def delete_user(user_id: str, tenant_id = Depends(auth_required)): +def delete_user(user_id: str, tenant_id=Depends(auth_required)): user = users.get_scim_user_by_id(user_id, tenant_id) if not user: return _not_found_error_response(user_id) diff --git a/ee/api/routers/scim_constants.py b/ee/api/routers/scim_constants.py index 5a6256ee3..74e00ee01 100644 --- a/ee/api/routers/scim_constants.py +++ b/ee/api/routers/scim_constants.py @@ -1,22 +1,22 @@ # note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants -from typing import Any, Literal +from typing import Any def _attribute_characteristics( - name: str, - description: str, - type: str="string", - sub_attributes: dict[str, Any] | None=None, - # note(jon): no default for multiValued is defined in the docs and it is marked as optional. - # from our side, we'll default it to False. - multi_valued: bool=False, - required: bool=False, - canonical_values: list[str] | None=None, - case_exact: bool=False, - mutability: str="readWrite", - returned: str="default", - uniqueness: str="none", - reference_types: list[str] | None=None, + name: str, + description: str, + type: str = "string", + sub_attributes: dict[str, Any] | None = None, + # note(jon): no default for multiValued is defined in the docs and it is marked as optional. + # from our side, we'll default it to False. + multi_valued: bool = False, + required: bool = False, + canonical_values: list[str] | None = None, + case_exact: bool = False, + mutability: str = "readWrite", + returned: str = "default", + uniqueness: str = "none", + reference_types: list[str] | None = None, ): characteristics = { "name": name, @@ -33,14 +33,16 @@ def _attribute_characteristics( "referenceTypes": reference_types, } characteristics_without_none = { - key: value - for key, value in characteristics.items() - if value is not None + key: value for key, value in characteristics.items() if value is not None } return characteristics_without_none -def _multi_valued_attributes(type_canonical_values: list[str], type_required: bool=False, type_mutability="readWrite"): +def _multi_valued_attributes( + type_canonical_values: list[str], + type_required: bool = False, + type_mutability="readWrite", +): return [ _attribute_characteristics( name="type", @@ -68,7 +70,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo name="$ref", type="reference", reference_types=["uri"], - description="The reference URI of a target resource." + description="The reference URI of a target resource.", ), ] @@ -77,7 +79,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo # in section 3.1 of RFC7643, it is specified that ResourceType and # ServiceProviderConfig are not included in the common attributes. but # in other references, they treat them as a resource. -def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none"): +def _common_resource_attributes(id_required: bool = True, id_uniqueness: str = "none"): return [ _attribute_characteristics( name="id", @@ -151,7 +153,6 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none ] - SERVICE_PROVIDER_CONFIG_SCHEMA = { "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "id": "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", @@ -339,7 +340,7 @@ SERVICE_PROVIDER_CONFIG_SCHEMA = { ), ], ), - ] + ], } @@ -409,9 +410,9 @@ RESOURCE_TYPE_SCHEMA = { required=True, mutability="readOnly", ), - ] + ], ), - ] + ], } SCHEMA_SCHEMA = { @@ -548,7 +549,7 @@ SCHEMA_SCHEMA = { canonical_values=[ # todo(jon): add "User" and "Group" once those are done. "external", - "uri" + "uri", ], case_exact=True, ), @@ -659,15 +660,15 @@ SCHEMA_SCHEMA = { canonical_values=[ # todo(jon): add "User" and "Group" once those are done. "external", - "uri" + "uri", ], case_exact=True, ), ], ), - ] - ) - ] + ], + ), + ], } @@ -749,7 +750,7 @@ SERVICE_PROVIDER_CONFIG = { # and then updating these timestamps from an api and such. for now, if we update # the configuration, we should update the timestamp here. "lastModified": "2025-04-15T15:45:00Z", - "location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint + "location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint }, } diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim_helpers.py index 7d1cf4b95..6c04ecab8 100644 --- a/ee/api/routers/scim_helpers.py +++ b/ee/api/routers/scim_helpers.py @@ -4,6 +4,7 @@ from copy import deepcopy def get_all_attribute_names(schema: dict[str, Any]) -> list[str]: result = [] + def _walk(attrs, prefix=None): for attr in attrs: name = attr["name"] @@ -12,12 +13,16 @@ def get_all_attribute_names(schema: dict[str, Any]) -> list[str]: if attr["type"] == "complex": sub = attr.get("subAttributes") or attr.get("attributes") or [] _walk(sub, path) + _walk(schema["attributes"]) return result -def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]: +def get_all_attribute_names_where_returned_is_always( + schema: dict[str, Any], +) -> list[str]: result = [] + def _walk(attrs, prefix=None): for attr in attrs: name = attr["name"] @@ -27,11 +32,14 @@ def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> if attr["type"] == "complex": sub = attr.get("subAttributes") or attr.get("attributes") or [] _walk(sub, path) + _walk(schema["attributes"]) return result -def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]: +def filter_attributes( + resource: dict[str, Any], include_list: list[str] +) -> dict[str, Any]: result = {} for attr in include_list: parts = attr.split(".", 1) @@ -63,7 +71,9 @@ def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict return result -def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]: +def exclude_attributes( + resource: dict[str, Any], exclude_list: list[str] +) -> dict[str, Any]: exclude_map = {} for attr in exclude_list: parts = attr.split(".", 1) @@ -105,7 +115,11 @@ def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dic return new_resource -def filter_mutable_attributes(schema: dict[str, Any], requested_changes: dict[str, Any], current: dict[str, Any]) -> dict[str, Any]: +def filter_mutable_attributes( + schema: dict[str, Any], + requested_changes: dict[str, Any], + current_values: dict[str, Any], +) -> dict[str, Any]: attributes = {attr.get("name"): attr for attr in schema.get("attributes", [])} valid_changes = {}