From fc357facf7f8191e0efbd1a8b534ba261311d41a Mon Sep 17 00:00:00 2001 From: Pavel Kim Date: Fri, 14 Feb 2025 17:52:43 +0100 Subject: [PATCH 1/4] Add user/group SCIM endpoints --- ee/api/.gitignore | 1 + ee/api/app.py | 4 + ee/api/chalicelib/core/roles.py | 109 +++++++- ee/api/chalicelib/core/users.py | 222 ++++++++++++++++ ee/api/routers/scim.py | 438 ++++++++++++++++++++++++++++++++ 5 files changed, 772 insertions(+), 2 deletions(-) create mode 100644 ee/api/routers/scim.py diff --git a/ee/api/.gitignore b/ee/api/.gitignore index 80beeee41..7140a891d 100644 --- a/ee/api/.gitignore +++ b/ee/api/.gitignore @@ -283,3 +283,4 @@ Pipfile.lock /chalicelib/utils/contextual_validators.py /routers/subs/product_analytics.py /schemas/product_analytics.py +/ee/bin/* diff --git a/ee/api/app.py b/ee/api/app.py index 5b3af9d80..a9d9c59cd 100644 --- a/ee/api/app.py +++ b/ee/api/app.py @@ -26,6 +26,7 @@ from routers.subs import v1_api_ee if config("ENABLE_SSO", cast=bool, default=True): from routers import saml + from routers import scim loglevel = config("LOGLEVEL", default=logging.WARNING) print(f">Loglevel set to: {loglevel}") @@ -158,3 +159,6 @@ if config("ENABLE_SSO", cast=bool, default=True): app.include_router(saml.public_app) app.include_router(saml.app) app.include_router(saml.app_apikey) + app.include_router(scim.public_app) + app.include_router(scim.app) + app.include_router(scim.app_apikey) diff --git a/ee/api/chalicelib/core/roles.py b/ee/api/chalicelib/core/roles.py index 5d92fbbc6..a879ff613 100644 --- a/ee/api/chalicelib/core/roles.py +++ b/ee/api/chalicelib/core/roles.py @@ -1,3 +1,4 @@ +import json from typing import Optional from fastapi import HTTPException, status @@ -78,6 +79,21 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema): return helper.dict_to_camel_case(row) +def update_group_name(tenant_id, group_id, name): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""UPDATE public.roles + SET name= %(name)s + WHERE roles.data->>'group_id' = %(group_id)s + AND tenant_id = %(tenant_id)s + AND deleted_at ISNULL + AND protected = FALSE + RETURNING *;""", + {"tenant_id": tenant_id, "group_id": group_id, "name": name }) + cur.execute(query=query) + row = cur.fetchone() + + return helper.dict_to_camel_case(row) + def create(tenant_id, user_id, data: schemas.RolePayloadSchema): admin = users.get(user_id=user_id, tenant_id=tenant_id) @@ -112,6 +128,35 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema): row["projects"] = [r["project_id"] for r in cur.fetchall()] return helper.dict_to_camel_case(row) +def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema): + + if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.") + + if not data.all_projects and (data.projects is None or len(data.projects) == 0): + return {"errors": ["must specify a project or all projects"]} + if data.projects is not None and len(data.projects) > 0 and not data.all_projects: + data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id) + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects, data) + VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s, %(data)s) + RETURNING *;""", + {"tenant_id": tenant_id, "name": data.name, "description": data.description, + "permissions": data.permissions, "all_projects": data.all_projects, "data": json.dumps({ "group_id": group_id })}) + cur.execute(query=query) + row = cur.fetchone() + row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) + row["projects"] = [] + if not data.all_projects: + role_id = row["role_id"] + query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id) + VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))} + RETURNING project_id;""", + {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}}) + cur.execute(query=query) + row["projects"] = [r["project_id"] for r in cur.fetchall()] + return helper.dict_to_camel_case(row) + def get_roles(tenant_id): with pg_client.PostgresClient() as cur: @@ -133,6 +178,27 @@ def get_roles(tenant_id): r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"]) return helper.list_to_camel_case(rows) +def get_roles_with_uuid(tenant_id): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects + FROM public.roles + LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects + FROM roles_projects + INNER JOIN projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE) + WHERE tenant_id =%(tenant_id)s + AND data ? 'group_id' + AND deleted_at IS NULL + AND not service_role + ORDER BY role_id;""", + {"tenant_id": tenant_id}) + cur.execute(query=query) + rows = cur.fetchall() + for r in rows: + r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"]) + return helper.list_to_camel_case(rows) + def get_role_by_name(tenant_id, name): with pg_client.PostgresClient() as cur: @@ -155,7 +221,7 @@ def delete(tenant_id, user_id, role_id): if not admin["admin"] and not admin["superAdmin"]: return {"errors": ["unauthorized"]} with pg_client.PostgresClient() as cur: - query = cur.mogrify("""SELECT 1 + query = cur.mogrify("""SELECT 1 FROM public.roles WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s @@ -165,7 +231,7 @@ def delete(tenant_id, user_id, role_id): cur.execute(query=query) if cur.fetchone() is not None: return {"errors": ["this role is protected"]} - query = cur.mogrify("""SELECT 1 + query = cur.mogrify("""SELECT 1 FROM public.users WHERE role_id = %(role_id)s AND tenant_id = %(tenant_id)s @@ -183,6 +249,29 @@ def delete(tenant_id, user_id, role_id): cur.execute(query=query) return get_roles(tenant_id=tenant_id) +def delete_scim_group(tenant_id, group_uuid): + + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT 1 + FROM public.roles + WHERE data->>'group_id' = %(group_uuid)s + AND tenant_id = %(tenant_id)s + AND protected = TRUE + LIMIT 1;""", + {"tenant_id": tenant_id, "group_uuid": group_uuid}) + cur.execute(query) + if cur.fetchone() is not None: + return {"errors": ["this role is protected"]} + + query = cur.mogrify( + f"""DELETE FROM public.roles + WHERE roles.data->>'group_id' = %(group_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL + {"group_uuid": group_uuid}) + cur.execute(query) + + return get_roles(tenant_id=tenant_id) + + def get_role(tenant_id, role_id): with pg_client.PostgresClient() as cur: @@ -199,3 +288,19 @@ def get_role(tenant_id, role_id): if row is not None: row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) return helper.dict_to_camel_case(row) + +def get_role_by_group_id(tenant_id, group_id): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT roles.* + FROM public.roles + WHERE tenant_id =%(tenant_id)s + AND deleted_at IS NULL + AND not service_role + AND data->>'group_id' = %(group_id)s + LIMIT 1;""", + {"tenant_id": tenant_id, "group_id": group_id}) + cur.execute(query=query) + row = cur.fetchone() + if row is not None: + row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) + return helper.dict_to_camel_case(row) \ No newline at end of file diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index d36f91227..da6e96531 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -5,6 +5,7 @@ from typing import Optional from decouple import config from fastapi import BackgroundTasks, HTTPException +from psycopg2.extras import Json from pydantic import BaseModel, model_validator from starlette import status @@ -161,9 +162,13 @@ def update(tenant_id, user_id, changes, output=True): sub_query_users.append("""role_id=(SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))""") + elif key == "data": # this is hardcoded, maybe a generic solution would be better + sub_query_users.append(f"data = data || %({(key)})s") else: sub_query_users.append(f"{helper.key_to_snake_case(key)} = %({key})s") changes["role_id"] = changes.get("roleId", changes.get("role_id")) + if "data" in changes: + changes["data"] = Json(changes["data"]) with pg_client.PostgresClient() as cur: if len(sub_query_users) > 0: query = cur.mogrify(f"""\ @@ -278,6 +283,42 @@ def get(user_id, tenant_id): ) r = cur.fetchone() return helper.dict_to_camel_case(r) + +def get_by_uuid(user_uuid, tenant_id): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + f"""SELECT + users.user_id, + users.tenant_id, + email, + role, + users.name, + users.data, + users.internal_id, + (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin, + role_id, + roles.name AS role_name, + roles.permissions, + roles.all_projects, + basic_authentication.password IS NOT NULL AS has_password, + users.service_account + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id + LEFT JOIN public.roles USING (role_id) + WHERE + users.data->>'user_id' = %(user_uuid)s + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) + LIMIT 1;""", + {"user_uuid": user_uuid, "tenant_id": tenant_id}) + ) + r = cur.fetchone() + return helper.dict_to_camel_case(r) + def generate_new_api_key(user_id): @@ -405,6 +446,68 @@ def get_by_email_only(email): r = cur.fetchone() return helper.dict_to_camel_case(r) +def get_by_email_with_uuid(email): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + f"""SELECT + users.user_id, + users.tenant_id, + users.email, + users.role, + users.name, + users.data, + (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin, + basic_authentication.password IS NOT NULL AS has_password, + role_id, + internal_id, + roles.name AS role_name + FROM public.users LEFT JOIN public.basic_authentication USING(user_id) + INNER JOIN public.roles USING(role_id) + WHERE users.email = %(email)s + AND users.deleted_at IS NULL + LIMIT 1;""", + {"email": email}) + ) + r = cur.fetchone() + return helper.dict_to_camel_case(r) + + +def get_users_paginated(start_index, count): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + f"""SELECT + users.user_id AS id, + users.tenant_id, + users.email AS email, + users.data AS data, + users.role, + users.name AS name, + (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin, + basic_authentication.password IS NOT NULL AS has_password, + role_id, + internal_id, + roles.name AS role_name + FROM public.users LEFT JOIN public.basic_authentication USING(user_id) + INNER JOIN public.roles USING(role_id) + WHERE users.deleted_at IS NULL AND users.data ? 'user_id' + LIMIT %(count)s + OFFSET %(startIndex)s;""", + {"startIndex": start_index - 1, "count": count}) + ) + r = cur.fetchall() + if len(r): + r = helper.list_to_camel_case(r) + return r + return [] + def get_member(tenant_id, user_id): with pg_client.PostgresClient() as cur: @@ -519,6 +622,70 @@ def delete_member(user_id, tenant_id, id_to_delete): return {"data": get_members(tenant_id=tenant_id)} +def delete_member_as_admin(tenant_id, id_to_delete): + + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + f"""SELECT + users.user_id AS user_id, + users.tenant_id, + email, + role, + users.name, + origin, + role_id, + roles.name AS role_name, + (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, + roles.permissions, + roles.all_projects, + basic_authentication.password IS NOT NULL AS has_password, + users.service_account + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id + LEFT JOIN public.roles USING (role_id) + WHERE + role = 'owner' + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) + LIMIT 1;""", + {"tenant_id": tenant_id, "user_uuid": id_to_delete}) + ) + r = cur.fetchone() + + if r["user_id"] == id_to_delete: + return {"errors": ["unauthorized, cannot delete self"]} + + if r["member"]: + return {"errors": ["unauthorized"]} + + to_delete = get(user_id=id_to_delete, tenant_id=tenant_id) + if to_delete is None: + return {"errors": ["not found"]} + + if to_delete["superAdmin"]: + return {"errors": ["cannot delete super admin"]} + + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify(f"""UPDATE public.users + SET deleted_at = timezone('utc'::text, now()), + jwt_iat= NULL, jwt_refresh_jti= NULL, + jwt_refresh_iat= NULL, + role_id=NULL + WHERE user_id=%(user_id)s AND tenant_id=%(tenant_id)s;""", + {"user_id": id_to_delete, "tenant_id": tenant_id})) + cur.execute( + cur.mogrify(f"""UPDATE public.basic_authentication + SET password= NULL, invitation_token= NULL, + invited_at= NULL, changed_at= NULL, + change_pwd_expire_at= NULL, change_pwd_token= NULL + WHERE user_id=%(user_id)s;""", + {"user_id": id_to_delete, "tenant_id": tenant_id})) + return {"data": get_members(tenant_id=tenant_id)} + + + def change_password(tenant_id, user_id, email, old_password, new_password): item = get(tenant_id=tenant_id, user_id=user_id) if item is None: @@ -859,6 +1026,53 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= query ) return helper.dict_to_camel_case(cur.fetchone()) + +def create_scim_user( + tenant_id, + user_uuid, + username, + admin, + display_name, + full_name: dict, + emails, + origin, + locale, + role_id, + internal_id=None, +): + + with pg_client.PostgresClient() as cur: + query = cur.mogrify(f"""\ + WITH u AS ( + INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) + VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, + (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) + RETURNING * + ), + au AS ( + INSERT INTO public.basic_authentication(user_id) + VALUES ((SELECT user_id FROM u)) + ) + SELECT u.user_id AS id, + u.email, + u.role, + u.name, + u.data, + (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin + FROM u;""", + {"tenant_id": tenant_id, "email": username, "internal_id": internal_id, + "role": "admin" if admin else "member", "name": display_name, "origin": origin, + "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})}) + cur.execute( + query + ) + return helper.dict_to_camel_case(cur.fetchone()) + def __hard_delete_user(user_id): @@ -869,6 +1083,14 @@ def __hard_delete_user(user_id): {"user_id": user_id}) cur.execute(query) +def __hard_delete_user_uuid(user_uuid): + with pg_client.PostgresClient() as cur: + query = cur.mogrify( + f"""DELETE FROM public.users + WHERE users.data->>'user_id' = %(user_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL + {"user_uuid": user_uuid}) + cur.execute(query) + def logout(user_id: int): with pg_client.PostgresClient() as cur: diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py new file mode 100644 index 000000000..54c8934a9 --- /dev/null +++ b/ee/api/routers/scim.py @@ -0,0 +1,438 @@ +import logging +import uuid +from typing import Optional + +from decouple import config +from fastapi import Depends, HTTPException, Header, Query, Response +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +import schemas +from chalicelib.core import users, roles +from routers.base import get_routers + +logger = logging.getLogger(__name__) + +""" +Models: + +USER + +schemas -> hardcoded +id -> from db +userName -> email, comes from Okta +name: + givenName -> from Okta + middleName -> from Okta + familyName -> from Okta +emails: + primary -> from Okta + value -> from Okta + type -> from Okta +displayName -> from Okta (potentially, givenName+" "+familyName) +locale -> from Okta (e.g. en-US) +externalId -> from Okta +active -> ! doesn't exist, but represent deleted users +groups -> users: {"display": group.displayName, "value": group.id} +meta -> hardcoded + + +GROUP + +schemas -> hardcoded +id -> from db +meta -> hardcoded +displayName -> from db +members -> users: {"display": user.userName, "value": user.id} + + +""" + +class Name(BaseModel): + givenName: str + familyName: str + +class Email(BaseModel): + primary: bool + value: str + type: str + +class UserRequest(BaseModel): + schemas: list[str] + userName: str + name: Name + emails: list[Email] # ignore for now + displayName: str + locale: str + externalId: str + groups: list[dict] + password: str = Field(default=None) + active: bool + + +class UserResponse(BaseModel): + schemas: list[str] + id: str + userName: str + name: Name + emails: list[Email] # ignore for now + displayName: str + locale: str + externalId: str + active: bool + groups: list[dict] + meta: dict = Field(default={"resourceType": "User"}) + +class PatchUserRequest(BaseModel): + schemas: list[str] + Operations: list[dict] + + +# Authentication Dependency +def auth_required(authorization: str = Header(..., alias="Authorization")): + """Dependency to check Authorization header.""" + token = authorization.replace("Bearer ", "") + if token != config("OCTA_TOKEN"): + raise HTTPException(status_code=403, detail="Unauthorized") + return token + + +public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") + +@public_app.get("/Users", dependencies=[Depends(auth_required)]) +async def get_users( + start_index: int = Query(1, alias="startIndex"), + count: Optional[int] = Query(1, alias="count"), + filter: Optional[str] = Query(None, alias="filter"), +): + """Get SCIM Users""" + if filter: + single_filter = filter.split(" ") + filter_value = single_filter[2].strip('"') + + filtered_users = users.get_by_email_with_uuid(filter_value) + filtered_users = [filtered_users] if filtered_users else [] + else: + filtered_users = users.get_users_paginated(start_index, count) + + serialized_users = [] + for user in filtered_users: + logger.info(user) + serialized_users.append( + UserResponse( + schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], + id = user["data"]["userId"], + userName = user["email"], + name = Name.model_validate(user["data"]["name"]), + emails = [Email.model_validate(user["data"]["emails"])], + displayName = user["name"], + locale = user["data"]["locale"], + externalId = user["internalId"], + active = True, # ignore for now, since, can't insert actual timestamp + groups = [], # ignore + ).model_dump(mode='json') + ) + return JSONResponse( + status_code=200, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "totalResults": len(serialized_users), + "startIndex": start_index, + "itemsPerPage": len(serialized_users), + "Resources": serialized_users, + }, + ) + +@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)]) +def get_user(user_id: str): + """Get SCIM User""" + user = users.get_by_uuid(user_id, 1) + if not user: + return JSONResponse( + status_code=404, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "User not found", + "status": 404, + } + ) + + res = UserResponse( + schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], + id = user["data"]["userId"], + userName = user["email"], + name = Name.model_validate(user["data"]["name"]), + emails = [Email.model_validate(user["data"]["emails"])], + displayName = user["name"], + locale = user["data"]["locale"], + externalId = user["internalId"], + active = True, # ignore for now, since, can't insert actual timestamp + groups = [], # ignore + ) + return JSONResponse(status_code=201, content=res.model_dump(mode='json')) + + +@public_app.post("/Users", dependencies=[Depends(auth_required)]) +async def create_user(r: UserRequest): + ## This needs to manage addition of previously deactivated users + """Create SCIM User""" + logger.info(r) + existing_user = users.get_by_email_only(r.userName) + + if existing_user: + return JSONResponse( + status_code = 409, + content = { + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "User already exists in the database.", + "status": 409, + } + ) + else: + try: + # Need to handle groups later, for now ignore them + user = users.create_scim_user(tenant_id=1, user_uuid=uuid.uuid4().hex, username=r.emails[0].value, admin=False, + display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), + origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) # role_id is set to 2 by default... + res = UserResponse( + schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], + id = user["data"]["userId"], # Transformed to camel case + userName = r.userName, + name = r.name, + emails = r.emails, + displayName = r.displayName, + locale = r.locale, + externalId = r.externalId, + active = r.active, # ignore for now, since, can't insert actual timestamp + groups = [], # ignore + ) + return JSONResponse(status_code=201, content=res.model_dump(mode='json')) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) # insert your header later +def update_user(user_id: str, r: UserRequest): + """Update SCIM User""" + logger.info(r) + user = users.get_by_uuid(user_id, 1) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + changes = r.model_dump(mode='json', exclude={"schemas", "emails", "name", "locale", "groups", "password", "active"}) # some of these should be added later if necessary + nested_changes = r.model_dump(mode='json', include={"name", "emails"}) + mapping = {"userName": "email", "displayName": "name", "externalId": "internal_id"} # mapping between scim schema field names and local database model, can be done as config? + for k, v in mapping.items(): + if k in changes: + changes[v] = changes.pop(k) + changes["data"] = {} + for k, v in nested_changes.items(): + value_to_insert = v[0] if k == "emails" else v + changes["data"][k] = value_to_insert + try: + # Need to handle groups later, for now ignore them + users.update(1, user["userId"], changes) + res = UserResponse( + schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], + id = user["data"]["userId"], + userName = r.userName, + name = r.name, + emails = r.emails, + displayName = r.displayName, + locale = r.locale, + externalId = r.externalId, + active = r.active, # ignore for now, since, can't insert actual timestamp + groups = [], # ignore + ) + + return JSONResponse(status_code=201, content=res.model_dump(mode='json')) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)]) +def deactivate_user(user_id: str, r: PatchUserRequest): + logger.info(r) + active = r.model_dump(mode='json')["Operations"][0]["value"]["active"] + logger.info(active) + if active: + raise HTTPException(status_code=404, detail="Activating user is not supported") + user = users.get_by_uuid(user_id, 1) + if not user: + raise HTTPException(status_code=404, detail="User not found") + logger.info(user) + users.delete_member_as_admin(1, user["userId"]) + + return Response(status_code=204, content="") + +@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)]) +def delete_user(user_uuid: str): + user = users.get_by_uuid(user_uuid, 1) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + users.__hard_delete_user_uuid(user_uuid) + return Response(status_code=204, content="") + + +""" +Group endpoints + +Potential issues: +1. Every user can be assigned only to single role +2. Deleting the group might be constrained by existing users linked to the role, + since those can't be left orphans +3. + +""" + +class Operation(BaseModel): + op: str + path: str = Field(default=None) + value: list[dict] | dict = Field(default=None) + +class GroupRequest(BaseModel): + schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"]) + displayName: str = Field(default=None) + members: list = Field(default=None) + operations: list[Operation] = Field(default=None, alias="Operations") + +class GroupPatchRequest(BaseModel): + schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:PatchOp"]) + operations: list[Operation] = Field(alias="Operations") + +class GroupResponse(BaseModel): + schemas: list[str] + id: str + meta: dict = Field(default={"resourceType": "Group"}) + displayName: str + members: list + +@public_app.get("/Groups", dependencies=[Depends(auth_required)]) +def get_groups(): # Might need to add query params later + groups = roles.get_roles_with_uuid(1) + res = [] + for group in groups: + res.append(GroupResponse( + schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], + id=group["data"]["groupId"], + displayName=group["name"], + members=[], # add later + ).model_dump(mode='json')) + return JSONResponse( + status_code=200, + content=res + ) + +@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)]) +def get_group(group_id: str): + group = roles.get_role_by_group_id(1, group_id) + if not group: + raise HTTPException(status_code=404, detail="Group not found") + + return JSONResponse( + status_code=200, + content=GroupResponse( + schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], + id=group["data"]["groupId"], + displayName=group["name"], + members=[], # add later + ).model_dump(mode='json')) + +@public_app.post("/Groups", dependencies=[Depends(auth_required)]) +def create_group(r: GroupRequest): + logger.info(r) + try: + data = schemas.RolePayloadSchema(name=r.displayName, permissions=[schemas.Permissions.SESSION_REPLAY]) # one permission for now + group = roles.create_as_admin(1, uuid.uuid4().hex, data) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return JSONResponse( + status_code=200, + content=GroupResponse( + schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], + id=group["data"]["groupId"], + displayName=group["name"], + members=[], # add later + ).model_dump(mode='json')) + + +@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)]) +def update_put_group(group_id: str, r: GroupRequest): + # Possibly need to change GroupRequest object to accept a different structure + logger.info(r) + group = roles.get_role_by_group_id(1, group_id) + if not group: + raise HTTPException(status_code=404, detail="Group not found") + + if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None: + roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"]) + return Response(status_code=200, content="") + + members = r.members + modified_members = [] + for member in members: + user = users.get_by_uuid(member["value"], 1) + if user: + users.update(1, user["userId"], {"role_id": group["roleId"]}) + modified_members.append({ + "value": user["data"]["userId"], + "display": user["name"] + }) + + return JSONResponse( + status_code=200, + content=GroupResponse( + schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], + id=group_id, + displayName=group["name"], + members=modified_members, + ).model_dump(mode='json')) + + +@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)]) +def update_patch_group(group_id: str, r: GroupPatchRequest): + logger.info(r) + group = roles.get_role_by_group_id(1, group_id) + if not group: + raise HTTPException(status_code=404, detail="Group not found") + if r.operations[0].op == "replace" and r.operations[0].path is None: + roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"]) + return Response(status_code=200, content="") + + if r.operations[0].op == "replace": + # find all members of that role, and for those that don't intersect with the list, set them to default role and return + pass + modified_members = [] + for op in r.operations: + if op.op == "add": + for u in op.value: + user = users.get_by_uuid(u["value"], 1) + if user: + users.update(1, user["userId"], {"role_id": group["roleId"]}) + modified_members.append({ + "value": user["data"]["userId"], + "display": user["name"] + }) + else: + # possibly remove by parsing the path? + pass + return JSONResponse( + status_code=200, + content=GroupResponse( + schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], + id=group_id, + displayName=group["name"], + members=modified_members, + ).model_dump(mode='json')) + + +@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) +def delete_group(group_id: str): + group = roles.get_role_by_group_id(1, group_id) + if not group: + raise HTTPException(status_code=404, detail="Group not found") + roles.delete_scim_group(1, group["data"]["groupId"]) + + return Response(status_code=200, content="") From e13008c006fbb3d1d91413acc0f68a794fb4b464 Mon Sep 17 00:00:00 2001 From: Pavel Kim Date: Mon, 17 Feb 2025 17:46:41 +0100 Subject: [PATCH 2/4] Support reactivation of users --- ee/api/chalicelib/core/users.py | 151 +++++++++++++++++++++++++++++++- ee/api/routers/scim.py | 42 +++++---- 2 files changed, 172 insertions(+), 21 deletions(-) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index da6e96531..35bd2f732 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -318,6 +318,41 @@ def get_by_uuid(user_uuid, tenant_id): ) r = cur.fetchone() return helper.dict_to_camel_case(r) + +def get_deleted_by_uuid(user_uuid, tenant_id): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + f"""SELECT + users.user_id, + users.tenant_id, + email, + role, + users.name, + users.data, + users.internal_id, + (CASE WHEN role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin, + role_id, + roles.name AS role_name, + roles.permissions, + roles.all_projects, + basic_authentication.password IS NOT NULL AS has_password, + users.service_account + FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id + LEFT JOIN public.roles USING (role_id) + WHERE + users.data->>'user_id' = %(user_uuid)s + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NOT NULL + AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) + LIMIT 1;""", + {"user_uuid": user_uuid, "tenant_id": tenant_id}) + ) + r = cur.fetchone() + return helper.dict_to_camel_case(r) @@ -631,7 +666,7 @@ def delete_member_as_admin(tenant_id, id_to_delete): users.user_id AS user_id, users.tenant_id, email, - role, + role, users.name, origin, role_id, @@ -1030,7 +1065,7 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= def create_scim_user( tenant_id, user_uuid, - username, + email, admin, display_name, full_name: dict, @@ -1065,7 +1100,7 @@ def create_scim_user( (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, origin FROM u;""", - {"tenant_id": tenant_id, "email": username, "internal_id": internal_id, + {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, "role": "admin" if admin else "member", "name": display_name, "origin": origin, "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})}) cur.execute( @@ -1214,6 +1249,116 @@ def restore_sso_user(user_id, tenant_id, email, admin, name, origin, role_id, in return helper.dict_to_camel_case(cur.fetchone()) +def restore_scim_user( + user_id, + tenant_id, + user_uuid, + email, + admin, + display_name, + full_name: dict, + emails, + origin, + locale, + role_id, + internal_id=None): + with pg_client.PostgresClient() as cur: + query = cur.mogrify(f"""\ + WITH u AS ( + UPDATE public.users + SET tenant_id= %(tenant_id)s, + role= %(role)s, + name= %(name)s, + data= %(data)s, + origin= %(origin)s, + internal_id= %(internal_id)s, + role_id= (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))), + deleted_at= NULL, + created_at= default, + api_key= default, + jwt_iat= NULL, + weekly_report= default + WHERE user_id = %(user_id)s + RETURNING * + ), + au AS ( + UPDATE public.basic_authentication + SET password= default, + invitation_token= default, + invited_at= default, + change_pwd_token= default, + change_pwd_expire_at= default, + changed_at= NULL + WHERE user_id = %(user_id)s + RETURNING user_id + ) + SELECT u.user_id AS id, + u.email, + u.role, + u.name, + u.data, + (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin + FROM u;""", + {"tenant_id": tenant_id, "email": email, "internal_id": internal_id, + "role": "admin" if admin else "member", "name": display_name, "origin": origin, + "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails}), + "user_id": user_id}) + cur.execute( + query + ) + return helper.dict_to_camel_case(cur.fetchone()) + +def create_scim_user2( + tenant_id, + user_uuid, + username, + admin, + display_name, + full_name: dict, + emails, + origin, + locale, + role_id, + internal_id=None, +): + + with pg_client.PostgresClient() as cur: + query = cur.mogrify(f"""\ + WITH u AS ( + INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) + VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, + (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), + (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) + RETURNING * + ), + au AS ( + INSERT INTO public.basic_authentication(user_id) + VALUES ((SELECT user_id FROM u)) + ) + SELECT u.user_id AS id, + u.email, + u.role, + u.name, + u.data, + (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, + (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, + (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, + origin + FROM u;""", + {"tenant_id": tenant_id, "email": username, "internal_id": internal_id, + "role": "admin" if admin else "member", "name": display_name, "origin": origin, + "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})}) + cur.execute( + query + ) + return helper.dict_to_camel_case(cur.fetchone()) + def get_user_settings(user_id): # read user settings from users.settings:jsonb column with pg_client.PostgresClient() as cur: diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 54c8934a9..0ce029621 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -25,7 +25,7 @@ name: givenName -> from Okta middleName -> from Okta familyName -> from Okta -emails: +emails: primary -> from Okta value -> from Okta type -> from Okta @@ -69,7 +69,6 @@ class UserRequest(BaseModel): password: str = Field(default=None) active: bool - class UserResponse(BaseModel): schemas: list[str] id: str @@ -176,8 +175,8 @@ def get_user(user_id: str): async def create_user(r: UserRequest): ## This needs to manage addition of previously deactivated users """Create SCIM User""" - logger.info(r) existing_user = users.get_by_email_only(r.userName) + deleted_user = users.get_deleted_user_by_email(r.userName) if existing_user: return JSONResponse( @@ -188,28 +187,35 @@ async def create_user(r: UserRequest): "status": 409, } ) + elif deleted_user: + user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], 1) + user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, + display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), + origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) else: try: # Need to handle groups later, for now ignore them - user = users.create_scim_user(tenant_id=1, user_uuid=uuid.uuid4().hex, username=r.emails[0].value, admin=False, + user = users.create_scim_user(tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) # role_id is set to 2 by default... - res = UserResponse( - schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], # Transformed to camel case - userName = r.userName, - name = r.name, - emails = r.emails, - displayName = r.displayName, - locale = r.locale, - externalId = r.externalId, - active = r.active, # ignore for now, since, can't insert actual timestamp - groups = [], # ignore - ) - return JSONResponse(status_code=201, content=res.model_dump(mode='json')) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - + + res = UserResponse( + schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], + id = user["data"]["userId"], # Transformed to camel case + userName = r.userName, + name = r.name, + emails = r.emails, + displayName = r.displayName, + locale = r.locale, + externalId = r.externalId, + active = r.active, # ignore for now, since, can't insert actual timestamp + groups = [], # ignore + ) + return JSONResponse(status_code=201, content=res.model_dump(mode='json')) + + @public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) # insert your header later def update_user(user_id: str, r: UserRequest): From 937e4d244c82eac38bf45dedc13d36a98ed38058 Mon Sep 17 00:00:00 2001 From: Pavel Kim Date: Tue, 18 Feb 2025 16:40:52 +0100 Subject: [PATCH 3/4] Fix pagination and implement all patch group methods --- ee/api/chalicelib/core/roles.py | 78 +++++++++- ee/api/chalicelib/core/users.py | 90 +----------- ee/api/routers/scim.py | 252 ++++++++++++++++---------------- 3 files changed, 209 insertions(+), 211 deletions(-) diff --git a/ee/api/chalicelib/core/roles.py b/ee/api/chalicelib/core/roles.py index a879ff613..955c76af0 100644 --- a/ee/api/chalicelib/core/roles.py +++ b/ee/api/chalicelib/core/roles.py @@ -199,8 +199,31 @@ def get_roles_with_uuid(tenant_id): r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"]) return helper.list_to_camel_case(rows) +def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects + FROM public.roles + LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects + FROM roles_projects + INNER JOIN projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE) + WHERE tenant_id =%(tenant_id)s + AND data ? 'group_id' + AND deleted_at IS NULL + AND not service_role + AND name = COALESCE(%(name)s, name) + ORDER BY role_id + LIMIT %(count)s + OFFSET %(startIndex)s;""", + {"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count}) + cur.execute(query=query) + rows = cur.fetchall() + return helper.list_to_camel_case(rows) + def get_role_by_name(tenant_id, name): + ### "name" isn't unique in database with pg_client.PostgresClient() as cur: query = cur.mogrify("""SELECT * FROM public.roles @@ -303,4 +326,57 @@ def get_role_by_group_id(tenant_id, group_id): row = cur.fetchone() if row is not None: row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"]) - return helper.dict_to_camel_case(row) \ No newline at end of file + return helper.dict_to_camel_case(row) + +def get_users_by_group_uuid(tenant_id, group_id): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT + u.user_id, + u.name, + u.data + FROM public.roles r + LEFT JOIN public.users u USING (role_id, tenant_id) + WHERE u.tenant_id = %(tenant_id)s + AND u.deleted_at IS NULL + AND r.data->>'group_id' = %(group_id)s + """, + {"tenant_id": tenant_id, "group_id": group_id}) + cur.execute(query=query) + rows = cur.fetchall() + return helper.list_to_camel_case(rows) + +def get_member_permissions(tenant_id): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""SELECT + r.permissions + FROM public.roles r + WHERE r.tenant_id = %(tenant_id)s + AND r.name = 'Member' + AND r.deleted_at IS NULL + """, + {"tenant_id": tenant_id}) + cur.execute(query=query) + row = cur.fetchone() + return helper.dict_to_camel_case(row) + +def remove_group_membership(tenant_id, group_id, user_id): + with pg_client.PostgresClient() as cur: + query = cur.mogrify("""WITH r AS ( + SELECT role_id + FROM public.roles + WHERE data->>'group_id' = %(group_id)s + LIMIT 1 + ) + UPDATE public.users u + SET role_id= NULL + FROM r + WHERE u.data->>'user_id' = %(user_id)s + AND u.role_id = r.role_id + AND u.tenant_id = %(tenant_id)s + AND u.deleted_at IS NULL + RETURNING *;""", + {"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id}) + cur.execute(query=query) + row = cur.fetchone() + + return helper.dict_to_camel_case(row) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 35bd2f732..94d1e8d41 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -277,7 +277,7 @@ def get(user_id, tenant_id): users.user_id = %(userId)s AND users.tenant_id = %(tenant_id)s AND users.deleted_at IS NULL - AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) + AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s) LIMIT 1;""", {"userId": user_id, "tenant_id": tenant_id}) ) @@ -318,7 +318,7 @@ def get_by_uuid(user_uuid, tenant_id): ) r = cur.fetchone() return helper.dict_to_camel_case(r) - + def get_deleted_by_uuid(user_uuid, tenant_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -481,37 +481,7 @@ def get_by_email_only(email): r = cur.fetchone() return helper.dict_to_camel_case(r) -def get_by_email_with_uuid(email): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - f"""SELECT - users.user_id, - users.tenant_id, - users.email, - users.role, - users.name, - users.data, - (CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin, - basic_authentication.password IS NOT NULL AS has_password, - role_id, - internal_id, - roles.name AS role_name - FROM public.users LEFT JOIN public.basic_authentication USING(user_id) - INNER JOIN public.roles USING(role_id) - WHERE users.email = %(email)s - AND users.deleted_at IS NULL - LIMIT 1;""", - {"email": email}) - ) - r = cur.fetchone() - return helper.dict_to_camel_case(r) - - -def get_users_paginated(start_index, count): +def get_users_paginated(start_index, count=None, email=None): with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -532,10 +502,12 @@ def get_users_paginated(start_index, count): roles.name AS role_name FROM public.users LEFT JOIN public.basic_authentication USING(user_id) INNER JOIN public.roles USING(role_id) - WHERE users.deleted_at IS NULL AND users.data ? 'user_id' + WHERE users.deleted_at IS NULL + AND users.data ? 'user_id' + AND email = COALESCE(%(email)s, email) LIMIT %(count)s - OFFSET %(startIndex)s;""", - {"startIndex": start_index - 1, "count": count}) + OFFSET %(startIndex)s;;""", + {"startIndex": start_index - 1, "count": count, "email": email}) ) r = cur.fetchall() if len(r): @@ -1313,52 +1285,6 @@ def restore_scim_user( ) return helper.dict_to_camel_case(cur.fetchone()) -def create_scim_user2( - tenant_id, - user_uuid, - username, - admin, - display_name, - full_name: dict, - emails, - origin, - locale, - role_id, - internal_id=None, -): - - with pg_client.PostgresClient() as cur: - query = cur.mogrify(f"""\ - WITH u AS ( - INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id) - VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s, - (SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1), - (SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1)))) - RETURNING * - ), - au AS ( - INSERT INTO public.basic_authentication(user_id) - VALUES ((SELECT user_id FROM u)) - ) - SELECT u.user_id AS id, - u.email, - u.role, - u.name, - u.data, - (CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin, - (CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, - (CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member, - origin - FROM u;""", - {"tenant_id": tenant_id, "email": username, "internal_id": internal_id, - "role": "admin" if admin else "member", "name": display_name, "origin": origin, - "role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})}) - cur.execute( - query - ) - return helper.dict_to_camel_case(cur.fetchone()) - def get_user_settings(user_id): # read user settings from users.settings:jsonb column with pg_client.PostgresClient() as cur: diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 0ce029621..85a777c90 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -1,4 +1,5 @@ import logging +import re import uuid from typing import Optional @@ -13,39 +14,20 @@ from routers.base import get_routers logger = logging.getLogger(__name__) + +public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") + +# Authentication Dependency +def auth_required(authorization: str = Header(..., alias="Authorization")): + """Dependency to check Authorization header.""" + token = authorization.replace("Bearer ", "") + if token != config("OCTA_TOKEN"): + raise HTTPException(status_code=403, detail="Unauthorized") + return token + + """ -Models: - -USER - -schemas -> hardcoded -id -> from db -userName -> email, comes from Okta -name: - givenName -> from Okta - middleName -> from Okta - familyName -> from Okta -emails: - primary -> from Okta - value -> from Okta - type -> from Okta -displayName -> from Okta (potentially, givenName+" "+familyName) -locale -> from Okta (e.g. en-US) -externalId -> from Okta -active -> ! doesn't exist, but represent deleted users -groups -> users: {"display": group.displayName, "value": group.id} -meta -> hardcoded - - -GROUP - -schemas -> hardcoded -id -> from db -meta -> hardcoded -displayName -> from db -members -> users: {"display": user.userName, "value": user.id} - - +User endpoints """ class Name(BaseModel): @@ -61,7 +43,7 @@ class UserRequest(BaseModel): schemas: list[str] userName: str name: Name - emails: list[Email] # ignore for now + emails: list[Email] displayName: str locale: str externalId: str @@ -87,36 +69,19 @@ class PatchUserRequest(BaseModel): Operations: list[dict] -# Authentication Dependency -def auth_required(authorization: str = Header(..., alias="Authorization")): - """Dependency to check Authorization header.""" - token = authorization.replace("Bearer ", "") - if token != config("OCTA_TOKEN"): - raise HTTPException(status_code=403, detail="Unauthorized") - return token - - -public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") - @public_app.get("/Users", dependencies=[Depends(auth_required)]) async def get_users( start_index: int = Query(1, alias="startIndex"), - count: Optional[int] = Query(1, alias="count"), - filter: Optional[str] = Query(None, alias="filter"), + count: Optional[int] = Query(None, alias="count"), + email: Optional[str] = Query(None, alias="filter"), ): """Get SCIM Users""" - if filter: - single_filter = filter.split(" ") - filter_value = single_filter[2].strip('"') - - filtered_users = users.get_by_email_with_uuid(filter_value) - filtered_users = [filtered_users] if filtered_users else [] - else: - filtered_users = users.get_users_paginated(start_index, count) + if email: + email = email.split(" ")[2].strip('"') + result_users = users.get_users_paginated(start_index, count, email) serialized_users = [] - for user in filtered_users: - logger.info(user) + for user in result_users: serialized_users.append( UserResponse( schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], @@ -145,7 +110,8 @@ async def get_users( @public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)]) def get_user(user_id: str): """Get SCIM User""" - user = users.get_by_uuid(user_id, 1) + tenant_id = 1 + user = users.get_by_uuid(user_id, tenant_id) if not user: return JSONResponse( status_code=404, @@ -173,8 +139,8 @@ def get_user(user_id: str): @public_app.post("/Users", dependencies=[Depends(auth_required)]) async def create_user(r: UserRequest): - ## This needs to manage addition of previously deactivated users """Create SCIM User""" + tenant_id = 1 existing_user = users.get_by_email_only(r.userName) deleted_user = users.get_deleted_user_by_email(r.userName) @@ -188,22 +154,21 @@ async def create_user(r: UserRequest): } ) elif deleted_user: - user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], 1) - user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, + user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id) + user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) else: try: - # Need to handle groups later, for now ignore them - user = users.create_scim_user(tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, + user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False, display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'), - origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) # role_id is set to 2 by default... + origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) res = UserResponse( schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], - id = user["data"]["userId"], # Transformed to camel case + id = user["data"]["userId"], userName = r.userName, name = r.name, emails = r.emails, @@ -217,11 +182,11 @@ async def create_user(r: UserRequest): -@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) # insert your header later +@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) def update_user(user_id: str, r: UserRequest): """Update SCIM User""" - logger.info(r) - user = users.get_by_uuid(user_id, 1) + tenant_id = 1 + user = users.get_by_uuid(user_id, tenant_id) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -236,8 +201,7 @@ def update_user(user_id: str, r: UserRequest): value_to_insert = v[0] if k == "emails" else v changes["data"][k] = value_to_insert try: - # Need to handle groups later, for now ignore them - users.update(1, user["userId"], changes) + users.update(tenant_id, user["userId"], changes) res = UserResponse( schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"], id = user["data"]["userId"], @@ -258,22 +222,23 @@ def update_user(user_id: str, r: UserRequest): @public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)]) def deactivate_user(user_id: str, r: PatchUserRequest): - logger.info(r) + """Deactivate user, soft-delete""" + tenant_id = 1 active = r.model_dump(mode='json')["Operations"][0]["value"]["active"] - logger.info(active) if active: raise HTTPException(status_code=404, detail="Activating user is not supported") - user = users.get_by_uuid(user_id, 1) + user = users.get_by_uuid(user_id, tenant_id) if not user: raise HTTPException(status_code=404, detail="User not found") - logger.info(user) - users.delete_member_as_admin(1, user["userId"]) + users.delete_member_as_admin(tenant_id, user["userId"]) return Response(status_code=204, content="") @public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)]) def delete_user(user_uuid: str): - user = users.get_by_uuid(user_uuid, 1) + """Delete user from database, hard-delete""" + tenant_id = 1 + user = users.get_by_uuid(user_uuid, tenant_id) if not user: raise HTTPException(status_code=404, detail="User not found") @@ -281,15 +246,9 @@ def delete_user(user_uuid: str): return Response(status_code=204, content="") + """ Group endpoints - -Potential issues: -1. Every user can be assigned only to single role -2. Deleting the group might be constrained by existing users linked to the role, - since those can't be left orphans -3. - """ class Operation(BaseModel): @@ -297,6 +256,13 @@ class Operation(BaseModel): path: str = Field(default=None) value: list[dict] | dict = Field(default=None) +class GroupGetResponse(BaseModel): + schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"]) + totalResults: int + startIndex: int + itemsPerPage: int + resources: list = Field(alias="Resources") + class GroupRequest(BaseModel): schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"]) displayName: str = Field(default=None) @@ -308,80 +274,111 @@ class GroupPatchRequest(BaseModel): operations: list[Operation] = Field(alias="Operations") class GroupResponse(BaseModel): - schemas: list[str] + schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"]) id: str - meta: dict = Field(default={"resourceType": "Group"}) displayName: str members: list + meta: dict = Field(default={"resourceType": "Group"}) + @public_app.get("/Groups", dependencies=[Depends(auth_required)]) -def get_groups(): # Might need to add query params later - groups = roles.get_roles_with_uuid(1) +def get_groups( + start_index: int = Query(1, alias="startIndex"), + count: Optional[int] = Query(None, alias="count"), + group_name: Optional[str] = Query(None, alias="filter"), + ): + """Get groups""" + tenant_id = 1 res = [] - for group in groups: - res.append(GroupResponse( - schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], - id=group["data"]["groupId"], - displayName=group["name"], - members=[], # add later - ).model_dump(mode='json')) + if group_name: + group_name = group_name.split(" ")[2].strip('"') + + groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name) + res = [{ + "id": group["data"]["groupId"], + "meta": { + "created": group["createdAt"], + "lastModified": "", # not currently a field + "version": "v1.0" + }, + "displayName": group["name"] + } for group in groups + ] return JSONResponse( status_code=200, - content=res - ) + content=GroupGetResponse( + totalResults=len(groups), + startIndex=start_index, + itemsPerPage=len(groups), + Resources=res + ).model_dump(mode='json')) @public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def get_group(group_id: str): - group = roles.get_role_by_group_id(1, group_id) + """Get a group by id""" + tenant_id = 1 + group = roles.get_role_by_group_id(tenant_id, group_id) if not group: raise HTTPException(status_code=404, detail="Group not found") - + members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"]) + members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members] + return JSONResponse( status_code=200, content=GroupResponse( - schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], id=group["data"]["groupId"], displayName=group["name"], - members=[], # add later + members=members, ).model_dump(mode='json')) @public_app.post("/Groups", dependencies=[Depends(auth_required)]) def create_group(r: GroupRequest): - logger.info(r) + """Create a group""" + tenant_id = 1 + member_role = roles.get_member_permissions(tenant_id) try: - data = schemas.RolePayloadSchema(name=r.displayName, permissions=[schemas.Permissions.SESSION_REPLAY]) # one permission for now - group = roles.create_as_admin(1, uuid.uuid4().hex, data) + data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"]) # permissions by default are same as for member role + group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + added_members = [] + for member in r.members: + user = users.get_by_uuid(member["value"], tenant_id) + if user: + users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) + added_members.append({ + "value": user["data"]["userId"], + "display": user["name"] + }) + return JSONResponse( status_code=200, content=GroupResponse( - schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], id=group["data"]["groupId"], displayName=group["name"], - members=[], # add later + members=added_members, ).model_dump(mode='json')) @public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def update_put_group(group_id: str, r: GroupRequest): - # Possibly need to change GroupRequest object to accept a different structure - logger.info(r) - group = roles.get_role_by_group_id(1, group_id) + """Update a group or members of the group (not used by anything yet)""" + tenant_id = 1 + group = roles.get_role_by_group_id(tenant_id, group_id) if not group: raise HTTPException(status_code=404, detail="Group not found") if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None: - roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"]) + roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"]) return Response(status_code=200, content="") members = r.members modified_members = [] for member in members: - user = users.get_by_uuid(member["value"], 1) + user = users.get_by_uuid(member["value"], tenant_id) if user: - users.update(1, user["userId"], {"role_id": group["roleId"]}) + users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) modified_members.append({ "value": user["data"]["userId"], "display": user["name"] @@ -390,44 +387,41 @@ def update_put_group(group_id: str, r: GroupRequest): return JSONResponse( status_code=200, content=GroupResponse( - schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], id=group_id, displayName=group["name"], members=modified_members, ).model_dump(mode='json')) - + @public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def update_patch_group(group_id: str, r: GroupPatchRequest): - logger.info(r) - group = roles.get_role_by_group_id(1, group_id) + """Update a group or members of the group, used by AIW""" + tenant_id = 1 + group = roles.get_role_by_group_id(tenant_id, group_id) if not group: raise HTTPException(status_code=404, detail="Group not found") if r.operations[0].op == "replace" and r.operations[0].path is None: - roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"]) + roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"]) return Response(status_code=200, content="") - - if r.operations[0].op == "replace": - # find all members of that role, and for those that don't intersect with the list, set them to default role and return - pass + modified_members = [] for op in r.operations: - if op.op == "add": + if op.op == "add" or op.op == "replace": + # Both methods work as "replace" for u in op.value: - user = users.get_by_uuid(u["value"], 1) + user = users.get_by_uuid(u["value"], tenant_id) if user: - users.update(1, user["userId"], {"role_id": group["roleId"]}) + users.update(tenant_id, user["userId"], {"role_id": group["roleId"]}) modified_members.append({ "value": user["data"]["userId"], "display": user["name"] }) - else: - # possibly remove by parsing the path? - pass + elif op.op == "remove": + user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1) + roles.remove_group_membership(tenant_id, group_id, user_id) return JSONResponse( status_code=200, content=GroupResponse( - schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"], id=group_id, displayName=group["name"], members=modified_members, @@ -436,9 +430,11 @@ def update_patch_group(group_id: str, r: GroupPatchRequest): @public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def delete_group(group_id: str): - group = roles.get_role_by_group_id(1, group_id) + """Delete a group, hard-delete""" + tenant_id = 1 + group = roles.get_role_by_group_id(tenant_id, group_id) if not group: raise HTTPException(status_code=404, detail="Group not found") - roles.delete_scim_group(1, group["data"]["groupId"]) + roles.delete_scim_group(tenant_id, group["data"]["groupId"]) return Response(status_code=200, content="") From cd70633d1fe7bd5daab23e6fdcd2c16c9f44e403 Mon Sep 17 00:00:00 2001 From: Pavel Kim Date: Thu, 20 Feb 2025 14:31:32 +0100 Subject: [PATCH 4/4] [Draft] add auth flow with JWT --- ee/api/chalicelib/core/tenants.py | 14 +++++ ee/api/chalicelib/utils/scim_auth.py | 77 ++++++++++++++++++++++++++++ ee/api/routers/scim.py | 44 ++++++++++++---- 3 files changed, 126 insertions(+), 9 deletions(-) create mode 100644 ee/api/chalicelib/utils/scim_auth.py diff --git a/ee/api/chalicelib/core/tenants.py b/ee/api/chalicelib/core/tenants.py index ca2d59dde..84f2c6d3a 100644 --- a/ee/api/chalicelib/core/tenants.py +++ b/ee/api/chalicelib/core/tenants.py @@ -56,6 +56,20 @@ def get_by_api_key(api_key): return helper.dict_to_camel_case(cur.fetchone()) +def get_by_name(name): + with pg_client.PostgresClient() as cur: + query = cur.mogrify(f"""SELECT tenants.tenant_id, + tenants.name, + tenants.created_at + FROM public.tenants + WHERE tenants.name = %(name)s + AND tenants.deleted_at ISNULL + LIMIT 1;""", + {"name": name}) + cur.execute(query=query) + return helper.dict_to_camel_case(cur.fetchone()) + + def generate_new_api_key(tenant_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.tenants diff --git a/ee/api/chalicelib/utils/scim_auth.py b/ee/api/chalicelib/utils/scim_auth.py new file mode 100644 index 000000000..83e779c40 --- /dev/null +++ b/ee/api/chalicelib/utils/scim_auth.py @@ -0,0 +1,77 @@ +import logging +import time +import jwt + +from decouple import config +from fastapi import HTTPException, Depends +from fastapi.security import OAuth2PasswordBearer + +logger = logging.getLogger(__name__) + +ACCESS_SECRET_KEY = config("SCIM_ACCESS_SECRET_KEY") +REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY") +ALGORITHM = config("SCIM_JWT_ALGORITHM") +ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS")) +REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS")) +AUDIENCE="okta_client" +ISSUER=config("JWT_ISSUER"), + +# Simulated Okta Client Credentials +# OKTA_CLIENT_ID = "okta-client" +# OKTA_CLIENT_SECRET = "okta-secret" + +# class TokenRequest(BaseModel): +# client_id: str +# client_secret: str + +# async def authenticate_client(token_request: TokenRequest): +# """Validate Okta Client Credentials and issue JWT""" +# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET: +# raise HTTPException(status_code=401, detail="Invalid client credentials") + +# return {"access_token": create_jwt(), "token_type": "bearer"} + +def create_tokens(tenant_id): + curr_time = time.time() + access_payload = { + "tenant_id": tenant_id, + "sub": "scim_server", + "aud": AUDIENCE, + "iss": ISSUER, + "exp": "" + } + access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS}) + access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM) + + refresh_payload = access_payload.copy() + refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS}) + refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM) + + return access_token, refresh_token + +def verify_access_token(token: str): + try: + payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + return payload + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + +def verify_refresh_token(token: str): + try: + payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + return payload + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +# Authentication Dependency +def auth_required(token: str = Depends(oauth2_scheme)): + """Dependency to check Authorization header.""" + if config("SCIM_AUTH_TYPE") == "OAuth2": + payload = verify_access_token(token) + return payload["tenant_id"] diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 85a777c90..8a0975492 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -6,25 +6,50 @@ from typing import Optional from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response from fastapi.responses import JSONResponse +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel, Field import schemas -from chalicelib.core import users, roles +from chalicelib.core import users, roles, tenants +from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token from routers.base import get_routers + logger = logging.getLogger(__name__) - public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") -# Authentication Dependency -def auth_required(authorization: str = Header(..., alias="Authorization")): - """Dependency to check Authorization header.""" - token = authorization.replace("Bearer ", "") - if token != config("OCTA_TOKEN"): - raise HTTPException(status_code=403, detail="Unauthorized") - return token +"""Authentication endpoints""" + +class RefreshRequest(BaseModel): + refresh_token: str + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +# Login endpoint to generate tokens +@public_app.post("/token") +async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()): + subdomain = host.split(".")[0] + + # Missing authentication part, to add + if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"): + raise HTTPException(status_code=401, detail="Invalid credentials") + + subdomain = "Openreplay EE" + tenant = tenants.get_by_name(subdomain) + access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) + + return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} + +# Refresh token endpoint +@public_app.post("/refresh") +async def refresh_token(r: RefreshRequest): + + payload = verify_refresh_token(r.refresh_token) + new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) + + return {"access_token": new_access_token, "token_type": "Bearer"} """ User endpoints @@ -431,6 +456,7 @@ def update_patch_group(group_id: str, r: GroupPatchRequest): @public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def delete_group(group_id: str): """Delete a group, hard-delete""" + # possibly need to set the user's roles to default member role, instead of null tenant_id = 1 group = roles.get_role_by_group_id(tenant_id, group_id) if not group: