diff --git a/ee/api/app.py b/ee/api/app.py index a9d9c59cd..672285032 100644 --- a/ee/api/app.py +++ b/ee/api/app.py @@ -26,7 +26,7 @@ from routers.subs import v1_api_ee if config("ENABLE_SSO", cast=bool, default=True): from routers import saml - from routers import scim + from routers.scim import api as scim loglevel = config("LOGLEVEL", default=logging.WARNING) print(f">Loglevel set to: {loglevel}") diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 2c63faab1..d80907bad 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -2,11 +2,9 @@ import json import logging import secrets from typing import Optional -from datetime import datetime from decouple import config from fastapi import BackgroundTasks, HTTPException -from psycopg2.extensions import AsIs from psycopg2.extras import Json from pydantic import BaseModel, model_validator from starlette import status @@ -352,239 +350,6 @@ def get(user_id, tenant_id): return helper.dict_to_camel_case(r) -def count_total_scim_users(tenant_id: int) -> int: - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT COUNT(*) - FROM public.users - WHERE - users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - """, - {"tenant_id": tenant_id}, - ) - ) - return cur.fetchone()["count"] - - -def get_scim_users_paginated(start_index, tenant_id, count=None): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT - users.*, - roles.name AS role_name - FROM public.users - LEFT JOIN public.roles USING (role_id) - WHERE - users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - LIMIT %(limit)s - OFFSET %(offset)s; - """, - {"offset": start_index - 1, "limit": count, "tenant_id": tenant_id}, - ) - ) - r = cur.fetchall() - return helper.list_to_camel_case(r) - - -def get_scim_user_by_id(user_id, tenant_id): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT - users.*, - roles.name AS role_name - FROM public.users - LEFT JOIN public.roles USING (role_id) - WHERE - users.user_id = %(user_id)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - LIMIT 1; - """, - { - "user_id": user_id, - "tenant_id": tenant_id, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - -def create_scim_user( - email: str, - tenant_id: int, - name: str = "", - internal_id: str | None = None, - role_id: int | None = None, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - INSERT INTO public.users ( - tenant_id, - email, - name, - internal_id, - role_id - ) - VALUES ( - %(tenant_id)s, - %(email)s, - %(name)s, - %(internal_id)s, - %(role_id)s - ) - RETURNING * - ) - SELECT - u.*, - roles.name as role_name - FROM u LEFT JOIN public.roles USING (role_id); - """, - { - "tenant_id": tenant_id, - "email": email, - "name": name, - "internal_id": internal_id, - "role_id": role_id, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - -def restore_scim_user( - tenant_id: int, - email: str, - name: str = "", - internal_id: str | None = None, - role_id: int | None = None, - **kwargs, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - UPDATE public.users - SET - tenant_id = %(tenant_id)s, - email = %(email)s, - name = %(name)s, - internal_id = %(internal_id)s, - role_id = %(role_id)s, - deleted_at = NULL, - created_at = now(), - updated_at = now(), - api_key = default, - jwt_iat = NULL, - weekly_report = default - WHERE users.email = %(email)s - RETURNING * - ) - SELECT - u.*, - roles.name as role_name - FROM u LEFT JOIN public.roles USING (role_id); - """, - { - "tenant_id": tenant_id, - "email": email, - "name": name, - "internal_id": internal_id, - "role_id": role_id, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - -def update_scim_user( - user_id: int, - tenant_id: int, - email: str, - name: str = "", - internal_id: str | None = None, - role_id: int | None = None, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - UPDATE public.users - SET - email = %(email)s, - name = %(name)s, - internal_id = %(internal_id)s, - role_id = %(role_id)s, - updated_at = now() - WHERE - users.user_id = %(user_id)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - RETURNING * - ) - SELECT - u.*, - roles.name as role_name - FROM u LEFT JOIN public.roles USING (role_id); - """, - { - "tenant_id": tenant_id, - "user_id": user_id, - "email": email, - "name": name, - "internal_id": internal_id, - "role_id": role_id, - }, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - -def patch_scim_user( - user_id: int, - tenant_id: int, - **kwargs, -): - with pg_client.PostgresClient() as cur: - set_fragments = [] - kwargs["updated_at"] = datetime.now() - for k, v in kwargs.items(): - fragment = cur.mogrify( - "%s = %s", - (AsIs(k), v), - ).decode("utf-8") - set_fragments.append(fragment) - set_clause = ", ".join(set_fragments) - query = f""" - WITH u AS ( - UPDATE public.users - SET {set_clause} - WHERE - users.user_id = {user_id} - AND users.tenant_id = {tenant_id} - AND users.deleted_at IS NULL - RETURNING * - ) - SELECT - u.*, - roles.name as role_name - FROM u LEFT JOIN public.roles USING (role_id);""" - cur.execute(query) - return helper.dict_to_camel_case(cur.fetchone()) - - def generate_new_api_key(user_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -695,21 +460,6 @@ def edit_member( return {"data": user} -def get_existing_scim_user_by_unique_values_from_all_users(email: str, **kwargs): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT * - FROM public.users - WHERE users.email = %(email)s - """, - {"email": email}, - ) - ) - return helper.dict_to_camel_case(cur.fetchone()) - - def get_by_email_only(email): with pg_client.PostgresClient() as cur: cur.execute( @@ -1255,24 +1005,6 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id= return helper.dict_to_camel_case(cur.fetchone()) -def soft_delete_scim_user_by_id(user_id, tenant_id): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - UPDATE public.users - SET - deleted_at = NULL, - updated_at = default - WHERE - users.user_id = %(user_id)s - AND users.tenant_id = %(tenant_id)s - """, - {"user_id": user_id, "tenant_id": tenant_id}, - ) - ) - - def __hard_delete_user(user_id): with pg_client.PostgresClient() as cur: query = cur.mogrify( diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py deleted file mode 100644 index 024fde933..000000000 --- a/ee/api/routers/scim.py +++ /dev/null @@ -1,644 +0,0 @@ -from copy import deepcopy -import logging -from typing import Any, Callable -from enum import Enum -from datetime import datetime - -from decouple import config -from fastapi import Depends, HTTPException, Header, Query, Response, Request -from fastapi.responses import JSONResponse -from fastapi.security import OAuth2PasswordRequestForm -from pydantic import BaseModel -from psycopg2 import errors - -from chalicelib.core import users, roles, tenants -from chalicelib.utils.scim_auth import ( - auth_optional, - auth_required, - create_tokens, - verify_refresh_token, -) -from routers.base import get_routers -from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG -from routers import scim_helpers, scim_groups - - -logger = logging.getLogger(__name__) - -public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") - - -@public_app.post("/token") -async def post_token( - host: str = Header(..., alias="Host"), - form_data: OAuth2PasswordRequestForm = Depends(), -): - subdomain = host.split(".")[0] - - # Missing authentication part, to add - if form_data.username != config("SCIM_USER") or form_data.password != config( - "SCIM_PASSWORD" - ): - raise HTTPException(status_code=401, detail="Invalid credentials") - - tenant = tenants.get_by_name(subdomain) - access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) - - return { - "access_token": access_token, - "refresh_token": refresh_token, - "token_type": "Bearer", - } - - -class RefreshRequest(BaseModel): - refresh_token: str - - -@public_app.post("/refresh") -async def post_refresh(r: RefreshRequest): - payload = verify_refresh_token(r.refresh_token) - new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) - return {"access_token": new_access_token, "token_type": "Bearer"} - - -RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS = { - resource_type_detail["id"]: resource_type_detail - for resource_type_detail in RESOURCE_TYPES -} - - -def _not_found_error_response(resource_id: int): - return JSONResponse( - status_code=404, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": f"Resource {resource_id} not found", - "status": "404", - }, - ) - - -def _uniqueness_error_response(): - return JSONResponse( - status_code=409, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "One or more of the attribute values are already in use or are reserved.", - "status": "409", - "scimType": "uniqueness", - }, - ) - - -def _mutability_error_response(): - return JSONResponse( - status_code=400, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "The attempted modification is not compatible with the target attribute's mutability or current state.", - "status": "400", - "scimType": "mutability", - }, - ) - - -def _operation_not_permitted_error_response(): - return JSONResponse( - status_code=403, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "Operation is not permitted based on the supplied authorization", - "status": "403", - }, - ) - - -def _invalid_value_error_response(): - return JSONResponse( - status_code=400, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": "A required value was missing, or the value specified was not compatible with the operation or attribtue type, or resource schema.", - "status": "400", - "scimType": "invalidValue", - }, - ) - - -def _internal_server_error_response(detail: str): - return JSONResponse( - status_code=500, - content={ - "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], - "detail": detail, - "status": "500", - }, - ) - - -@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)]) -async def get_resource_types(filter_param: str | None = Query(None, alias="filter")): - if filter_param is not None: - return _operation_not_permitted_error_response() - return JSONResponse( - status_code=200, - content={ - "totalResults": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS), - "itemsPerPage": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS), - "startIndex": 1, - "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], - "Resources": list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values()), - }, - ) - - -@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)]) -async def get_resource_type(resource_id: str): - if resource_id not in RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS: - return _not_found_error_response(resource_id) - return JSONResponse( - status_code=200, - content=RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id], - ) - - -SCHEMA_IDS_TO_SCHEMA_DETAILS = { - schema_detail["id"]: schema_detail for schema_detail in SCHEMAS -} - - -@public_app.get("/Schemas", dependencies=[Depends(auth_required)]) -async def get_schemas(filter_param: str | None = Query(None, alias="filter")): - if filter_param is not None: - return _operation_not_permitted_error_response() - return JSONResponse( - status_code=200, - content={ - "totalResults": len(SCHEMA_IDS_TO_SCHEMA_DETAILS), - "itemsPerPage": len(SCHEMA_IDS_TO_SCHEMA_DETAILS), - "startIndex": 1, - "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], - "Resources": [ - value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items()) - ], - }, - ) - - -@public_app.get("/Schemas/{schema_id}") -async def get_schema(schema_id: str, tenant_id=Depends(auth_required)): - if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS: - return _not_found_error_response(schema_id) - schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]) - if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User": - db_roles = roles.get_roles(tenant_id) - role_names = [role["name"] for role in db_roles] - user_type_attribute = next( - filter(lambda x: x["name"] == "userType", schema["attributes"]) - ) - user_type_attribute["canonicalValues"] = role_names - return JSONResponse( - status_code=200, - content=schema, - ) - - -# note(jon): it was recommended to make this endpoint partially open -# so that clients can view the `authenticationSchemes` prior to being authenticated. -@public_app.get("/ServiceProviderConfig") -async def get_service_provider_config( - r: Request, tenant_id: str | None = Depends(auth_optional) -): - is_authenticated = tenant_id is not None - if not is_authenticated: - return JSONResponse( - status_code=200, - content={ - "schemas": SERVICE_PROVIDER_CONFIG["schemas"], - "authenticationSchemes": SERVICE_PROVIDER_CONFIG[ - "authenticationSchemes" - ], - "meta": SERVICE_PROVIDER_CONFIG["meta"], - }, - ) - return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG) - - -def _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource: dict[str, Any], - schema_id: str, - serialize_db_resource_to_scim_resource: Callable[[dict[str, Any]], dict[str, Any]], - attributes: list[str] | None = None, - excluded_attributes: list[str] | None = None, -) -> dict[str, Any]: - schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id] - all_attributes = scim_helpers.get_all_attribute_names(schema) - attributes = attributes or all_attributes - always_returned_attributes = ( - scim_helpers.get_all_attribute_names_where_returned_is_always(schema) - ) - included_attributes = list(set(attributes).union(set(always_returned_attributes))) - excluded_attributes = excluded_attributes or [] - excluded_attributes = list( - set(excluded_attributes).difference(set(always_returned_attributes)) - ) - scim_resource = serialize_db_resource_to_scim_resource(db_resource) - scim_resource = scim_helpers.filter_attributes(scim_resource, included_attributes) - scim_resource = scim_helpers.exclude_attributes(scim_resource, excluded_attributes) - return scim_resource - - -def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, Any]: - role_id = None - if "userType" in data: - role = roles.get_role_by_name(tenant_id, data["userType"]) - role_id = role["roleId"] if role else None - name = data.get("name", {}).get("formatted") - if not name: - name = " ".join( - [ - x - for x in [ - data.get("name", {}).get("honorificPrefix"), - data.get("name", {}).get("givenName"), - data.get("name", {}).get("middleName"), - data.get("name", {}).get("familyName"), - data.get("name", {}).get("honorificSuffix"), - ] - if x - ] - ) - result = { - "email": data["userName"], - "internal_id": data.get("externalId"), - "name": name, - "role_id": role_id, - } - result = {k: v for k, v in result.items() if v is not None} - return result - - -def _parse_user_patch_payload(data: dict[str, Any], tenant_id: str) -> dict[str, Any]: - result = {} - if "userType" in data: - role = roles.get_role_by_name(tenant_id, data["userType"]) - result["role_id"] = role["roleId"] if role else None - if "name" in data: - # note(jon): we're currently not handling the case where the client - # send patches of individual name components (e.g. name.middleName) - name = data.get("name", {}).get("formatted") - if name: - result["name"] = name - if "userName" in data: - result["email"] = data["userName"] - if "externalId" in data: - result["internal_id"] = data["externalId"] - if "active" in data: - result["deleted_at"] = None if data["active"] else datetime.now() - return result - - -def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]: - return { - "id": str(db_user["userId"]), - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], - "meta": { - "resourceType": "User", - "created": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), - "lastModified": db_user["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), - "location": f"Users/{db_user['userId']}", - }, - "userName": db_user["email"], - "externalId": db_user["internalId"], - "name": { - "formatted": db_user["name"], - }, - "displayName": db_user["name"] or db_user["email"], - "userType": db_user.get("roleName"), - "active": db_user["deletedAt"] is None, - } - - -def _serialize_db_group_to_scim_group(db_resource: dict[str, Any]) -> dict[str, Any]: - members = db_resource["users"] or [] - return { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], - "id": str(db_resource["groupId"]), - "externalId": db_resource["externalId"], - "meta": { - "resourceType": "Group", - "created": db_resource["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), - "lastModified": db_resource["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), - "location": f"Groups/{db_resource['groupId']}", - }, - "displayName": db_resource["name"], - "members": [ - { - "value": str(member["userId"]), - "$ref": f"Users/{member['userId']}", - "type": "User", - } - for member in members - ], - } - - -def _parse_scim_group_input(data: dict[str, Any], tenant_id: int) -> dict[str, Any]: - return { - "name": data["displayName"], - "external_id": data.get("externalId"), - "user_ids": [int(member["value"]) for member in data.get("members", [])], - } - - -def _parse_scim_group_patch(data: dict[str, Any], tenant_id: int) -> dict[str, Any]: - result = {} - if "displayName" in data: - result["name"] = data["displayName"] - if "externalId" in data: - result["external_id"] = data["externalId"] - if "members" in data: - members = data["members"] or [] - result["user_ids"] = [int(member["value"]) for member in members] - return result - - -RESOURCE_TYPE_TO_RESOURCE_CONFIG = { - "Users": { - "max_items_per_page": 10, - "schema_id": "urn:ietf:params:scim:schemas:core:2.0:User", - "db_to_scim_serializer": _serialize_db_user_to_scim_user, - "count_total_resources": users.count_total_scim_users, - "get_paginated_resources": users.get_scim_users_paginated, - "get_unique_resource": users.get_scim_user_by_id, - "parse_post_payload": _parse_scim_user_input, - "get_resource_by_unique_values": users.get_existing_scim_user_by_unique_values_from_all_users, - "restore_resource": users.restore_scim_user, - "create_resource": users.create_scim_user, - "delete_resource": users.soft_delete_scim_user_by_id, - "parse_put_payload": _parse_scim_user_input, - "update_resource": users.update_scim_user, - "parse_patch_payload": _parse_user_patch_payload, - "patch_resource": users.patch_scim_user, - }, - "Groups": { - "max_items_per_page": 10, - "schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group", - "db_to_scim_serializer": _serialize_db_group_to_scim_group, - "count_total_resources": scim_groups.count_total_resources, - "get_paginated_resources": scim_groups.get_resources_paginated, - "get_unique_resource": scim_groups.get_resource_by_id, - "parse_post_payload": _parse_scim_group_input, - "get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources, - # note(jon): we're not soft deleting groups, so we don't need this - "restore_resource": None, - "create_resource": scim_groups.create_resource, - "delete_resource": scim_groups.delete_resource, - "parse_put_payload": _parse_scim_group_input, - "update_resource": scim_groups.update_resource, - "parse_patch_payload": _parse_scim_group_patch, - "patch_resource": scim_groups.patch_resource, - }, -} - - -class ListResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.get("/{resource_type}") -async def get_resources( - resource_type: ListResourceType, - tenant_id=Depends(auth_required), - requested_start_index: int = Query(1, alias="startIndex"), - requested_items_per_page: int | None = Query(None, alias="count"), - attributes: str | None = Query(None), - excluded_attributes: str | None = Query(None, alias="excludedAttributes"), -): - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - start_index = max(1, requested_start_index) - max_items_per_page = resource_config["max_items_per_page"] - items_per_page = min( - max(0, requested_items_per_page or max_items_per_page), max_items_per_page - ) - total_resources = resource_config["count_total_resources"](tenant_id) - db_resources = resource_config["get_paginated_resources"]( - start_index, tenant_id, items_per_page - ) - attributes = scim_helpers.convert_query_str_to_list(attributes) - excluded_attributes = scim_helpers.convert_query_str_to_list(excluded_attributes) - scim_resources = [ - _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - attributes, - excluded_attributes, - ) - for db_resource in db_resources - ] - return JSONResponse( - status_code=200, - content={ - "totalResults": total_resources, - "startIndex": start_index, - "itemsPerPage": len(scim_resources), - "Resources": scim_resources, - }, - ) - - -class GetResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.get("/{resource_type}/{resource_id}") -async def get_resource( - resource_type: GetResourceType, - resource_id: int, - tenant_id=Depends(auth_required), - attributes: list[str] | None = Query(None), - excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), -): - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) - if not db_resource: - return _not_found_error_response(resource_id) - scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - attributes, - excluded_attributes, - ) - return JSONResponse(status_code=200, content=scim_resource) - - -class PostResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.post("/{resource_type}") -async def create_resource( - resource_type: PostResourceType, - r: Request, - tenant_id=Depends(auth_required), -): - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - scim_payload = await r.json() - try: - db_payload = resource_config["parse_post_payload"](scim_payload, tenant_id) - except KeyError: - return _invalid_value_error_response() - existing_db_resource = resource_config["get_resource_by_unique_values"]( - **db_payload - ) - if existing_db_resource and existing_db_resource.get("deletedAt") is None: - return _uniqueness_error_response() - if existing_db_resource and existing_db_resource.get("deletedAt") is not None: - db_resource = resource_config["restore_resource"]( - tenant_id=tenant_id, **db_payload - ) - else: - db_resource = resource_config["create_resource"]( - tenant_id=tenant_id, - **db_payload, - ) - scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - ) - response = JSONResponse(status_code=201, content=scim_resource) - response.headers["Location"] = scim_resource["meta"]["location"] - return response - - -class DeleteResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.delete("/{resource_type}/{resource_id}") -async def delete_resource( - resource_type: DeleteResourceType, - resource_id: str, - tenant_id=Depends(auth_required), -): - # note(jon): this can be a soft or a hard delete - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) - if not db_resource: - return _not_found_error_response(resource_id) - resource_config["delete_resource"](resource_id, tenant_id) - return Response(status_code=204, content="") - - -class PutResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.put("/{resource_type}/{resource_id}") -async def put_resource( - resource_type: PutResourceType, - resource_id: str, - r: Request, - tenant_id=Depends(auth_required), -): - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) - if not db_resource: - return _not_found_error_response(resource_id) - current_scim_resource = ( - _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - ) - ) - requested_scim_changes = await r.json() - schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]] - try: - valid_mutable_scim_changes = scim_helpers.filter_mutable_attributes( - schema, requested_scim_changes, current_scim_resource - ) - except ValueError: - return _mutability_error_response() - valid_mutable_db_changes = resource_config["parse_put_payload"]( - valid_mutable_scim_changes, - tenant_id, - ) - try: - updated_db_resource = resource_config["update_resource"]( - resource_id, - tenant_id, - **valid_mutable_db_changes, - ) - updated_scim_resource = ( - _serialize_db_resource_to_scim_resource_with_attribute_awareness( - updated_db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - ) - ) - return JSONResponse(status_code=200, content=updated_scim_resource) - except errors.UniqueViolation: - return _uniqueness_error_response() - except Exception as e: - return _internal_server_error_response(str(e)) - - -class PatchResourceType(str, Enum): - USERS = "Users" - GROUPS = "Groups" - - -@public_app.patch("/{resource_type}/{resource_id}") -async def patch_resource( - resource_type: PatchResourceType, - resource_id: str, - r: Request, - tenant_id=Depends(auth_required), -): - resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] - db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) - if not db_resource: - return _not_found_error_response(resource_id) - current_scim_resource = ( - _serialize_db_resource_to_scim_resource_with_attribute_awareness( - db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - ) - ) - payload = await r.json() - _, changes = scim_helpers.apply_scim_patch( - payload["Operations"], - current_scim_resource, - SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]], - ) - reformatted_scim_changes = { - k: new_value for k, (old_value, new_value) in changes.items() - } - db_changes = resource_config["parse_patch_payload"]( - reformatted_scim_changes, - tenant_id, - ) - updated_db_resource = resource_config["patch_resource"]( - resource_id, - tenant_id, - **db_changes, - ) - updated_scim_resource = ( - _serialize_db_resource_to_scim_resource_with_attribute_awareness( - updated_db_resource, - resource_config["schema_id"], - resource_config["db_to_scim_serializer"], - ) - ) - return JSONResponse(status_code=200, content=updated_scim_resource) diff --git a/ee/api/routers/scim/__init__.py b/ee/api/routers/scim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ee/api/routers/scim/api.py b/ee/api/routers/scim/api.py new file mode 100644 index 000000000..2607566cb --- /dev/null +++ b/ee/api/routers/scim/api.py @@ -0,0 +1,461 @@ +import logging +from copy import deepcopy +from enum import Enum + +from decouple import config +from fastapi import Depends, HTTPException, Header, Query, Response, Request +from fastapi.responses import JSONResponse +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel +from psycopg2 import errors + +from chalicelib.core import roles, tenants +from chalicelib.utils.scim_auth import ( + auth_optional, + auth_required, + create_tokens, + verify_refresh_token, +) +from routers.base import get_routers +from routers.scim.constants import ( + SERVICE_PROVIDER_CONFIG, + RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS, + SCHEMA_IDS_TO_SCHEMA_DETAILS, +) +from routers.scim import helpers, groups, users +from routers.scim.resource_config import ResourceConfig +from routers.scim import resource_config as api_helper + + +logger = logging.getLogger(__name__) + +public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") + + +@public_app.post("/token") +async def post_token( + host: str = Header(..., alias="Host"), + form_data: OAuth2PasswordRequestForm = Depends(), +): + subdomain = host.split(".")[0] + + # Missing authentication part, to add + if form_data.username != config("SCIM_USER") or form_data.password != config( + "SCIM_PASSWORD" + ): + raise HTTPException(status_code=401, detail="Invalid credentials") + + tenant = tenants.get_by_name(subdomain) + access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "Bearer", + } + + +class RefreshRequest(BaseModel): + refresh_token: str + + +@public_app.post("/refresh") +async def post_refresh(r: RefreshRequest): + payload = verify_refresh_token(r.refresh_token) + new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) + return {"access_token": new_access_token, "token_type": "Bearer"} + + +def _not_found_error_response(resource_id: int): + return JSONResponse( + status_code=404, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": f"Resource {resource_id} not found", + "status": "404", + }, + ) + + +def _uniqueness_error_response(): + return JSONResponse( + status_code=409, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "One or more of the attribute values are already in use or are reserved.", + "status": "409", + "scimType": "uniqueness", + }, + ) + + +def _mutability_error_response(): + return JSONResponse( + status_code=400, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "The attempted modification is not compatible with the target attribute's mutability or current state.", + "status": "400", + "scimType": "mutability", + }, + ) + + +def _operation_not_permitted_error_response(): + return JSONResponse( + status_code=403, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "Operation is not permitted based on the supplied authorization", + "status": "403", + }, + ) + + +def _invalid_value_error_response(): + return JSONResponse( + status_code=400, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": "A required value was missing, or the value specified was not compatible with the operation or attribtue type, or resource schema.", + "status": "400", + "scimType": "invalidValue", + }, + ) + + +def _internal_server_error_response(detail: str): + return JSONResponse( + status_code=500, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], + "detail": detail, + "status": "500", + }, + ) + + +# note(jon): it was recommended to make this endpoint partially open +# so that clients can view the `authenticationSchemes` prior to being authenticated. +@public_app.get("/ServiceProviderConfig") +async def get_service_provider_config( + r: Request, tenant_id: str | None = Depends(auth_optional) +): + is_authenticated = tenant_id is not None + if not is_authenticated: + return JSONResponse( + status_code=200, + content={ + "schemas": SERVICE_PROVIDER_CONFIG["schemas"], + "authenticationSchemes": SERVICE_PROVIDER_CONFIG[ + "authenticationSchemes" + ], + "meta": SERVICE_PROVIDER_CONFIG["meta"], + }, + ) + return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG) + + +@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)]) +async def get_resource_types(filter_param: str | None = Query(None, alias="filter")): + if filter_param is not None: + return _operation_not_permitted_error_response() + return JSONResponse( + status_code=200, + content={ + "totalResults": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS), + "itemsPerPage": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS), + "startIndex": 1, + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "Resources": list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values()), + }, + ) + + +@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)]) +async def get_resource_type(resource_id: str): + if resource_id not in RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS: + return _not_found_error_response(resource_id) + return JSONResponse( + status_code=200, + content=RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id], + ) + + +@public_app.get("/Schemas", dependencies=[Depends(auth_required)]) +async def get_schemas(filter_param: str | None = Query(None, alias="filter")): + if filter_param is not None: + return _operation_not_permitted_error_response() + return JSONResponse( + status_code=200, + content={ + "totalResults": len(SCHEMA_IDS_TO_SCHEMA_DETAILS), + "itemsPerPage": len(SCHEMA_IDS_TO_SCHEMA_DETAILS), + "startIndex": 1, + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "Resources": [ + value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items()) + ], + }, + ) + + +@public_app.get("/Schemas/{schema_id}") +async def get_schema(schema_id: str, tenant_id=Depends(auth_required)): + if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS: + return _not_found_error_response(schema_id) + schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]) + if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User": + db_roles = roles.get_roles(tenant_id) + role_names = [role["name"] for role in db_roles] + user_type_attribute = next( + filter(lambda x: x["name"] == "userType", schema["attributes"]) + ) + user_type_attribute["canonicalValues"] = role_names + return JSONResponse( + status_code=200, + content=schema, + ) + + +user_config = ResourceConfig( + schema_id="urn:ietf:params:scim:schemas:core:2.0:User", + max_chunk_size=10, + get_active_resource_count=users.get_active_resource_count, + convert_provider_resource_to_client_resource=users.convert_provider_resource_to_client_resource, + get_provider_resource_chunk=users.get_provider_resource_chunk, + get_provider_resource=users.get_provider_resource, + convert_client_resource_creation_input_to_provider_resource_creation_input=users.convert_client_resource_creation_input_to_provider_resource_creation_input, + get_provider_resource_from_unique_fields=users.get_provider_resource_from_unique_fields, + restore_provider_resource=users.restore_provider_resource, + create_provider_resource=users.create_provider_resource, + delete_provider_resource=users.delete_provider_resource, + convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=users.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input, + rewrite_provider_resource=users.rewrite_provider_resource, + convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input, + update_provider_resource=users.update_provider_resource, +) +group_config = ResourceConfig( + schema_id="urn:ietf:params:scim:schemas:core:2.0:Group", + max_chunk_size=10, + get_active_resource_count=groups.get_active_resource_count, + convert_provider_resource_to_client_resource=groups.convert_provider_resource_to_client_resource, + get_provider_resource_chunk=groups.get_provider_resource_chunk, + get_provider_resource=groups.get_provider_resource, + convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input, + get_provider_resource_from_unique_fields=groups.get_provider_resource_from_unique_fields, + restore_provider_resource=None, + create_provider_resource=groups.create_provider_resource, + delete_provider_resource=groups.delete_provider_resource, + convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=groups.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input, + rewrite_provider_resource=groups.rewrite_provider_resource, + convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input, + update_provider_resource=groups.update_provider_resource, +) + +RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = { + "Users": user_config, + "Groups": group_config, +} + + +class SCIMResource(str, Enum): + USERS = "Users" + GROUPS = "Groups" + + +@public_app.get("/{resource_type}") +async def get_resources( + resource_type: SCIMResource, + tenant_id=Depends(auth_required), + requested_start_index_one_indexed: int = Query(1, alias="startIndex"), + requested_items_per_page: int | None = Query(None, alias="count"), + attributes: str | None = Query(None), + excluded_attributes: str | None = Query(None, alias="excludedAttributes"), +): + config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + total_resources = config.get_active_resource_count(tenant_id) + start_index_one_indexed = max(1, requested_start_index_one_indexed) + offset = start_index_one_indexed - 1 + limit = min( + max(0, requested_items_per_page or config.max_chunk_size), config.max_chunk_size + ) + provider_resources = config.get_provider_resource_chunk(offset, tenant_id, limit) + client_resources = [ + api_helper.convert_provider_resource_to_client_resource( + config, provider_resource, attributes, excluded_attributes + ) + for provider_resource in provider_resources + ] + return JSONResponse( + status_code=200, + content={ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], + "totalResults": total_resources, + "startIndex": start_index_one_indexed, + "itemsPerPage": len(client_resources), + "Resources": client_resources, + }, + ) + + +@public_app.get("/{resource_type}/{resource_id}") +async def get_resource( + resource_type: SCIMResource, + resource_id: int | str, + tenant_id=Depends(auth_required), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + resource = api_helper.get_resource( + resource_config, + resource_id, + tenant_id, + attributes, + excluded_attributes, + ) + if not resource: + return _not_found_error_response(resource_id) + return JSONResponse(status_code=200, content=resource) + + +@public_app.post("/{resource_type}") +async def create_resource( + resource_type: SCIMResource, + r: Request, + tenant_id=Depends(auth_required), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + payload = await r.json() + try: + provider_resource_input = config.convert_client_resource_creation_input_to_provider_resource_creation_input( + tenant_id, + payload, + ) + except KeyError: + return _invalid_value_error_response() + existing_provider_resource = config.get_provider_resource_from_unique_fields( + **provider_resource_input + ) + if ( + existing_provider_resource + and existing_provider_resource.get("deleted_at") is None + ): + return _uniqueness_error_response() + if ( + existing_provider_resource + and existing_provider_resource.get("deleted_at") is not None + ): + provider_resource = config.restore_provider_resource( + tenant_id=tenant_id, **provider_resource_input + ) + else: + provider_resource = config.create_provider_resource( + tenant_id=tenant_id, **provider_resource_input + ) + client_resource = api_helper.convert_provider_resource_to_client_resource( + config, provider_resource, attributes, excluded_attributes + ) + response = JSONResponse(status_code=201, content=client_resource) + response.headers["Location"] = client_resource["meta"]["location"] + return response + + +@public_app.delete("/{resource_type}/{resource_id}") +async def delete_resource( + resource_type: SCIMResource, + resource_id: str, + tenant_id=Depends(auth_required), +): + # note(jon): this can be a soft or a hard delete + config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + resource = api_helper.get_resource(config, resource_id, tenant_id) + if not resource: + return _not_found_error_response(resource_id) + config.delete_provider_resource(resource_id, tenant_id) + return Response(status_code=204, content="") + + +@public_app.put("/{resource_type}/{resource_id}") +async def put_resource( + resource_type: SCIMResource, + resource_id: str, + r: Request, + tenant_id=Depends(auth_required), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + client_resource = api_helper.get_resource(config, resource_id, tenant_id) + if not client_resource: + return _not_found_error_response(resource_id) + schema = api_helper.get_schema(config) + payload = await r.json() + try: + client_resource_input = helpers.filter_mutable_attributes( + schema, payload, client_resource + ) + except ValueError: + return _mutability_error_response() + provider_resource_input = ( + config.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( + tenant_id, client_resource_input + ) + ) + try: + provider_resource = config.rewrite_provider_resource( + resource_id, + tenant_id, + **provider_resource_input, + ) + except errors.UniqueViolation: + return _uniqueness_error_response() + except Exception as e: + return _internal_server_error_response(str(e)) + client_resource = api_helper.convert_provider_resource_to_client_resource( + config, provider_resource, attributes, excluded_attributes + ) + return JSONResponse(status_code=200, content=client_resource) + + +@public_app.patch("/{resource_type}/{resource_id}") +async def patch_resource( + resource_type: SCIMResource, + resource_id: str, + r: Request, + tenant_id=Depends(auth_required), + attributes: list[str] | None = Query(None), + excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), +): + config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + client_resource = api_helper.get_resource(config, resource_id, tenant_id) + if not client_resource: + return _not_found_error_response(resource_id) + schema = api_helper.get_schema(config) + payload = await r.json() + _, changes = helpers.apply_scim_patch( + payload["Operations"], client_resource, schema + ) + client_resource_input = { + k: new_value for k, (old_value, new_value) in changes.items() + } + provider_resource_input = ( + config.convert_client_resource_update_input_to_provider_resource_update_input( + tenant_id, client_resource_input + ) + ) + try: + provider_resource = config.update_provider_resource( + resource_id, tenant_id, **provider_resource_input + ) + except errors.UniqueViolation: + return _uniqueness_error_response() + except Exception as e: + return _internal_server_error_response(str(e)) + client_resource = api_helper.convert_provider_resource_to_client_resource( + config, provider_resource, attributes, excluded_attributes + ) + return JSONResponse(status_code=200, content=client_resource) diff --git a/ee/api/routers/scim/constants.py b/ee/api/routers/scim/constants.py new file mode 100644 index 000000000..74dd19705 --- /dev/null +++ b/ee/api/routers/scim/constants.py @@ -0,0 +1,33 @@ +# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants +import json + +SCHEMAS = sorted( + [ + json.load( + open("routers/scim/fixtures/service_provider_config_schema.json", "r") + ), + json.load(open("routers/scim/fixtures/resource_type_schema.json", "r")), + json.load(open("routers/scim/fixtures/schema_schema.json", "r")), + json.load(open("routers/scim/fixtures/user_schema.json", "r")), + json.load(open("routers/scim/fixtures/group_schema.json", "r")), + ], + key=lambda x: x["id"], +) + +SCHEMA_IDS_TO_SCHEMA_DETAILS = { + schema_detail["id"]: schema_detail for schema_detail in SCHEMAS +} + +SERVICE_PROVIDER_CONFIG = json.load( + open("routers/scim/fixtures/service_provider_config.json", "r") +) + +RESOURCE_TYPES = sorted( + json.load(open("routers/scim/fixtures/resource_type.json", "r")), + key=lambda x: x["id"], +) + +RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS = { + resource_type_detail["id"]: resource_type_detail + for resource_type_detail in RESOURCE_TYPES +} diff --git a/ee/api/routers/fixtures/group_schema.json b/ee/api/routers/scim/fixtures/group_schema.json similarity index 97% rename from ee/api/routers/fixtures/group_schema.json rename to ee/api/routers/scim/fixtures/group_schema.json index 1a56a3cdf..ddb030b92 100644 --- a/ee/api/routers/fixtures/group_schema.json +++ b/ee/api/routers/scim/fixtures/group_schema.json @@ -2,6 +2,7 @@ "id": "urn:ietf:params:scim:schemas:core:2.0:Group", "name": "Group", "description": "Group", + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "attributes": [ { "name": "schemas", @@ -97,7 +98,7 @@ "uniqueness": "none" } ] - }, + }, { "name": "displayName", "type": "string", @@ -159,7 +160,7 @@ "meta": { "resourceType": "Schema", "location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:Group", - "created": "2025-04-17T15:48:00", - "lastModified": "2025-04-17T15:48:00" + "created": "2025-04-17T15:48:00Z", + "lastModified": "2025-04-17T15:48:00Z" } } diff --git a/ee/api/routers/fixtures/resource_type.json b/ee/api/routers/scim/fixtures/resource_type.json similarity index 100% rename from ee/api/routers/fixtures/resource_type.json rename to ee/api/routers/scim/fixtures/resource_type.json diff --git a/ee/api/routers/fixtures/resource_type_schema.json b/ee/api/routers/scim/fixtures/resource_type_schema.json similarity index 50% rename from ee/api/routers/fixtures/resource_type_schema.json rename to ee/api/routers/scim/fixtures/resource_type_schema.json index 040fd071c..ac53aefea 100644 --- a/ee/api/routers/fixtures/resource_type_schema.json +++ b/ee/api/routers/scim/fixtures/resource_type_schema.json @@ -2,7 +2,19 @@ "id": "urn:ietf:params:scim:schemas:core:2.0:ResourceType", "name": "ResourceType", "description": "Specifies the schema that describes a SCIM Resource Type", + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "attributes": [ + { + "name": "schemas", + "type": "string", + "multiValued": true, + "description": "An array of Strings containing URI that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.", + "required": true, + "caseExact": false, + "mutability": "immutable", + "returned": "always", + "uniqueness": "none" + }, { "name": "id", "type": "string", @@ -14,6 +26,79 @@ "returned": "default", "uniqueness": "none" }, + { + "name": "externalId", + "type": "string", + "multiValued": false, + "description": "Identifier for the resource as defined by the provisioning client. OPTIONAL; clients MAY include a non-empty value.", + "required": false, + "caseExact": true, + "mutability": "readWrite", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "meta", + "type": "complex", + "multiValued": false, + "description": "Resource metadata. MUST be ignored when provided by clients.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "subAttributes": [ + { + "name": "resourceType", + "type": "string", + "multiValued": false, + "description": "The resource type name.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "created", + "type": "dateTime", + "multiValued": false, + "description": "The date and time the resource was added.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "lastModified", + "type": "dateTime", + "multiValued": false, + "description": "The most recent date and time the resource was modified.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "location", + "type": "reference", + "referenceTypes": ["external"], + "multiValued": false, + "description": "The URI of the resource being returned.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "version", + "type": "string", + "multiValued": false, + "description": "The version (ETag) of the resource being returned.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + } + ] + }, { "name": "name", "type": "string", @@ -96,7 +181,7 @@ "meta": { "resourceType": "Schema", "location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:ResourceType", - "created": "2025-04-17T15:48:00", - "lastModified": "2025-04-17T15:48:00" + "created": "2025-04-17T15:48:00Z", + "lastModified": "2025-04-17T15:48:00Z" } } diff --git a/ee/api/routers/fixtures/schema_schema.json b/ee/api/routers/scim/fixtures/schema_schema.json similarity index 77% rename from ee/api/routers/fixtures/schema_schema.json rename to ee/api/routers/scim/fixtures/schema_schema.json index 4099700f3..231cbde54 100644 --- a/ee/api/routers/fixtures/schema_schema.json +++ b/ee/api/routers/scim/fixtures/schema_schema.json @@ -2,7 +2,19 @@ "id": "urn:ietf:params:scim:schemas:core:2.0:Schema", "name": "Schema", "description": "Specifies the schema that describes a SCIM Schema", + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "attributes": [ + { + "name": "schemas", + "type": "string", + "multiValued": true, + "description": "An array of Strings containing URI that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.", + "required": true, + "caseExact": false, + "mutability": "immutable", + "returned": "always", + "uniqueness": "none" + }, { "name": "id", "type": "string", @@ -14,6 +26,79 @@ "returned": "default", "uniqueness": "none" }, + { + "name": "externalId", + "type": "string", + "multiValued": false, + "description": "Identifier for the resource as defined by the provisioning client. OPTIONAL; clients MAY include a non-empty value.", + "required": false, + "caseExact": true, + "mutability": "readWrite", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "meta", + "type": "complex", + "multiValued": false, + "description": "Resource metadata. MUST be ignored when provided by clients.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "subAttributes": [ + { + "name": "resourceType", + "type": "string", + "multiValued": false, + "description": "The resource type name.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "created", + "type": "dateTime", + "multiValued": false, + "description": "The date and time the resource was added.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "lastModified", + "type": "dateTime", + "multiValued": false, + "description": "The most recent date and time the resource was modified.", + "required": false, + "mutability": "readOnly", + "returned": "default" + }, + { + "name": "location", + "type": "reference", + "referenceTypes": ["external"], + "multiValued": false, + "description": "The URI of the resource being returned.", + "required": false, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + }, + { + "name": "version", + "type": "string", + "multiValued": false, + "description": "The version (ETag) of the resource being returned.", + "required": false, + "caseExact": true, + "mutability": "readOnly", + "returned": "default", + "uniqueness": "none" + } + ] + }, { "name": "name", "type": "string", @@ -172,7 +257,7 @@ "required": false, "mutability": "readOnly", "returned": "default", - "subAttributes": [ + "subAttribtes": [ { "name": "name", "type": "string", @@ -298,7 +383,7 @@ "meta": { "resourceType": "Schema", "location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:Schema", - "created": "2025-04-17T15:48:00", - "lastModified": "2025-04-17T15:48:00" + "created": "2025-04-17T15:48:00Z", + "lastModified": "2025-04-17T15:48:00Z" } } diff --git a/ee/api/routers/fixtures/service_provider_config.json b/ee/api/routers/scim/fixtures/service_provider_config.json similarity index 100% rename from ee/api/routers/fixtures/service_provider_config.json rename to ee/api/routers/scim/fixtures/service_provider_config.json diff --git a/ee/api/routers/fixtures/service_provider_config_schema.json b/ee/api/routers/scim/fixtures/service_provider_config_schema.json similarity index 97% rename from ee/api/routers/fixtures/service_provider_config_schema.json rename to ee/api/routers/scim/fixtures/service_provider_config_schema.json index c17ca5e18..2a90e8de4 100644 --- a/ee/api/routers/fixtures/service_provider_config_schema.json +++ b/ee/api/routers/scim/fixtures/service_provider_config_schema.json @@ -2,6 +2,7 @@ "id": "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", "name": "Service Provider Configuration", "description": "Schema for representing the service provider's configuration", + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "attributes": [ { "name": "documentationUri", @@ -206,7 +207,7 @@ "meta": { "resourceType": "Schema", "location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig", - "created": "2025-04-17T15:48:00", - "lastModified": "2025-04-17T15:48:00" + "created": "2025-04-17T15:48:00Z", + "lastModified": "2025-04-17T15:48:00Z" } } diff --git a/ee/api/routers/fixtures/user_schema.json b/ee/api/routers/scim/fixtures/user_schema.json similarity index 99% rename from ee/api/routers/fixtures/user_schema.json rename to ee/api/routers/scim/fixtures/user_schema.json index 736694c91..c80a084c5 100644 --- a/ee/api/routers/fixtures/user_schema.json +++ b/ee/api/routers/scim/fixtures/user_schema.json @@ -2,6 +2,7 @@ "id": "urn:ietf:params:scim:schemas:core:2.0:User", "name": "User", "description": "User Account", + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], "attributes": [ { "name": "schemas", @@ -380,7 +381,7 @@ "meta": { "resourceType": "Schema", "location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:User", - "created": "2025-04-17T15:48:00", - "lastModified": "2025-04-17T15:48:00" + "created": "2025-04-17T15:48:00Z", + "lastModified": "2025-04-17T15:48:00Z" } } diff --git a/ee/api/routers/scim_groups.py b/ee/api/routers/scim/groups.py similarity index 66% rename from ee/api/routers/scim_groups.py rename to ee/api/routers/scim/groups.py index a9ef352b8..433dbc75b 100644 --- a/ee/api/routers/scim_groups.py +++ b/ee/api/routers/scim/groups.py @@ -2,10 +2,59 @@ from typing import Any from datetime import datetime from psycopg2.extensions import AsIs -from chalicelib.utils import helper, pg_client +from chalicelib.utils import pg_client +from routers.scim.resource_config import ( + ProviderResource, + ClientResource, + ResourceId, + ClientInput, + ProviderInput, +) -def count_total_resources(tenant_id: int) -> int: +def convert_client_resource_update_input_to_provider_resource_update_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + result = {} + if "displayName" in client_input: + result["name"] = client_input["displayName"] + if "externalId" in client_input: + result["external_id"] = client_input["externalId"] + if "members" in client_input: + members = client_input["members"] or [] + result["user_ids"] = [int(member["value"]) for member in members] + return result + + +def convert_provider_resource_to_client_resource( + provider_resource: ProviderResource, +) -> ClientResource: + members = provider_resource["users"] or [] + return { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], + "id": str(provider_resource["group_id"]), + "externalId": provider_resource["external_id"], + "meta": { + "resourceType": "Group", + "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": provider_resource["updated_at"].strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "location": f"Groups/{provider_resource['group_id']}", + }, + "displayName": provider_resource["name"], + "members": [ + { + "value": str(member["user_id"]), + "$ref": f"Users/{member['user_id']}", + "type": "User", + } + for member in members + ], + } + + +def get_active_resource_count(tenant_id: int) -> int: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -20,9 +69,9 @@ def count_total_resources(tenant_id: int) -> int: return cur.fetchone()["count"] -def get_resources_paginated( - offset_one_indexed: int, tenant_id: int, limit: int | None = None -) -> list[dict[str, Any]]: +def get_provider_resource_chunk( + offset: int, tenant_id: int, limit: int +) -> list[ProviderResource]: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -41,16 +90,18 @@ def get_resources_paginated( OFFSET %(offset)s; """, { - "offset": offset_one_indexed - 1, + "offset": offset, "limit": limit, "tenant_id": tenant_id, }, ) ) - return helper.list_to_camel_case(cur.fetchall()) + return cur.fetchall() -def get_resource_by_id(group_id: int, tenant_id: int) -> dict[str, Any]: +def get_provider_resource( + resource_id: ResourceId, tenant_id: int +) -> ProviderResource | None: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -69,26 +120,50 @@ def get_resource_by_id(group_id: int, tenant_id: int) -> dict[str, Any]: AND groups.group_id = %(group_id)s LIMIT 1; """, - {"group_id": group_id, "tenant_id": tenant_id}, + {"group_id": resource_id, "tenant_id": tenant_id}, ) ) - return helper.dict_to_camel_case(cur.fetchone()) + return cur.fetchone() -def get_existing_resource_by_unique_values_from_all_resources( - **kwargs, -) -> dict[str, Any] | None: +def get_provider_resource_from_unique_fields( + **kwargs: dict[str, Any], +) -> ProviderResource | None: # note(jon): we do not really use this for groups as we don't have unique values outside # of the primary key return None -def create_resource( +def convert_client_resource_creation_input_to_provider_resource_creation_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + return { + "name": client_input["displayName"], + "external_id": client_input.get("externalId"), + "user_ids": [ + int(member["value"]) for member in client_input.get("members", []) + ], + } + + +def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + return { + "name": client_input["displayName"], + "external_id": client_input.get("externalId"), + "user_ids": [ + int(member["value"]) for member in client_input.get("members", []) + ], + } + + +def create_provider_resource( name: str, tenant_id: int, user_ids: list[str] | None = None, **kwargs: dict[str, Any], -) -> dict[str, Any]: +) -> ProviderResource: with pg_client.PostgresClient() as cur: kwargs["name"] = name kwargs["tenant_id"] = tenant_id @@ -136,10 +211,10 @@ def create_resource( LIMIT 1; """ ) - return helper.dict_to_camel_case(cur.fetchone()) + return cur.fetchone() -def delete_resource(group_id: int, tenant_id: int) -> None: +def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None: with pg_client.PostgresClient() as cur: cur.execute( cur.mogrify( @@ -148,7 +223,7 @@ def delete_resource(group_id: int, tenant_id: int) -> None: WHERE groups.group_id = %(group_id)s AND groups.tenant_id = %(tenant_id)s; """ ), - {"tenant_id": tenant_id, "group_id": group_id}, + {"tenant_id": tenant_id, "group_id": resource_id}, ) @@ -214,30 +289,30 @@ def _update_resource_sql( LIMIT 1; """ ) - return helper.dict_to_camel_case(cur.fetchone()) + return cur.fetchone() -def update_resource( - group_id: int, +def rewrite_provider_resource( + resource_id: int, tenant_id: int, name: str, **kwargs: dict[str, Any], ) -> dict[str, Any]: return _update_resource_sql( - group_id=group_id, + group_id=resource_id, tenant_id=tenant_id, name=name, **kwargs, ) -def patch_resource( - group_id: int, +def update_provider_resource( + resource_id: int, tenant_id: int, **kwargs: dict[str, Any], ): return _update_resource_sql( - group_id=group_id, + group_id=resource_id, tenant_id=tenant_id, **kwargs, ) diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim/helpers.py similarity index 81% rename from ee/api/routers/scim_helpers.py rename to ee/api/routers/scim/helpers.py index b57ec1356..a94806d14 100644 --- a/ee/api/routers/scim_helpers.py +++ b/ee/api/routers/scim/helpers.py @@ -45,83 +45,50 @@ def get_all_attribute_names_where_returned_is_always( def filter_attributes( - resource: dict[str, Any], include_list: list[str] + obj: dict[str, Any], + attributes_query_str: str | None, + excluded_attributes_query_str: str | None, + schema: dict[str, Any], ) -> dict[str, Any]: - result = {} + all_attributes = get_all_attribute_names(schema) + always_returned_attributes = get_all_attribute_names_where_returned_is_always( + schema + ) + included_attributes = convert_query_str_to_list(attributes_query_str) + included_attributes = included_attributes or all_attributes + included_attributes_set = set(included_attributes).union( + set(always_returned_attributes) + ) + excluded_attributes = convert_query_str_to_list(excluded_attributes_query_str) + excluded_attributes = excluded_attributes or [] + excluded_attributes_set = set(excluded_attributes).difference( + set(always_returned_attributes) + ) + include_paths = included_attributes_set.difference(excluded_attributes_set) - # Group include paths by top-level key - includes_by_key = {} - for path in include_list: - parts = path.split(".", 1) - key = parts[0] - rest = parts[1] if len(parts) == 2 else None - includes_by_key.setdefault(key, []).append(rest) + include_tree = {} + for path in include_paths: + parts = path.split(".") + node = include_tree + for part in parts: + node = node.setdefault(part, {}) - for key, subpaths in includes_by_key.items(): - if key not in resource: - continue + def _recurse(o, tree, parent_key=None): + if isinstance(o, dict): + out = {} + for key, subtree in tree.items(): + if key in o: + out[key] = _recurse(o[key], subtree, key) + return out + if isinstance(o, list): + out = [_recurse(item, tree, parent_key) for item in o] + return out + return o - value = resource[key] - if all(p is None for p in subpaths): - result[key] = value - else: - nested_paths = [p for p in subpaths if p is not None] - if isinstance(value, dict): - filtered = filter_attributes(value, nested_paths) - if filtered: - result[key] = filtered - elif isinstance(value, list): - new_list = [] - for item in value: - if isinstance(item, dict): - filtered_item = filter_attributes(item, nested_paths) - if filtered_item: - new_list.append(filtered_item) - if new_list: - result[key] = new_list + result = _recurse(obj, include_tree) return result -def exclude_attributes( - resource: dict[str, Any], exclude_list: list[str] -) -> dict[str, Any]: - exclude_map = {} - for attr in exclude_list: - parts = attr.split(".", 1) - key = parts[0] - # rest is empty string for top-level exclusion - rest = parts[1] if len(parts) == 2 else "" - exclude_map.setdefault(key, []).append(rest) - - new_resource = {} - for key, value in resource.items(): - if key in exclude_map: - subs = exclude_map[key] - # If any attr has no rest, exclude entire key - if "" in subs: - continue - # Exclude nested attributes - if isinstance(value, dict): - new_sub = exclude_attributes(value, subs) - if not new_sub: - continue - new_resource[key] = new_sub - elif isinstance(value, list): - new_list = [] - for item in value: - # note(jon): `item` should always be a dict here - new_item = exclude_attributes(item, subs) - new_list.append(new_item) - new_resource[key] = new_list - else: - # No exclusion for this key: copy safely - if isinstance(value, (dict, list)): - new_resource[key] = deepcopy(value) - else: - new_resource[key] = value - return new_resource - - def filter_mutable_attributes( schema: dict[str, Any], requested_changes: dict[str, Any], diff --git a/ee/api/routers/scim/resource_config.py b/ee/api/routers/scim/resource_config.py new file mode 100644 index 000000000..afae5eed6 --- /dev/null +++ b/ee/api/routers/scim/resource_config.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import Any, Callable + +from routers.scim.constants import ( + SCHEMA_IDS_TO_SCHEMA_DETAILS, +) +from routers.scim import helpers + + +Schema = dict[str, Any] +ProviderResource = dict[str, Any] +ClientResource = dict[str, Any] +ResourceId = int | str +ClientInput = dict[str, Any] +ProviderInput = dict[str, Any] + + +@dataclass +class ResourceConfig: + schema_id: str + max_chunk_size: int + get_active_resource_count: Callable[[int], int] + convert_provider_resource_to_client_resource: Callable[ + [ProviderResource], ClientResource + ] + get_provider_resource_chunk: Callable[[int, int, int], list[ProviderResource]] + get_provider_resource: Callable[[ResourceId, int], ProviderResource | None] + convert_client_resource_creation_input_to_provider_resource_creation_input: ( + Callable[[int, ClientInput], ProviderInput] + ) + get_provider_resource_from_unique_fields: Callable[..., ProviderResource | None] + restore_provider_resource: Callable[..., ProviderResource] | None + create_provider_resource: Callable[..., ProviderResource] + delete_provider_resource: Callable[[ResourceId, int], None] + convert_client_resource_rewrite_input_to_provider_resource_rewrite_input: Callable[ + [int, ClientInput], ProviderInput + ] + rewrite_provider_resource: Callable[..., ProviderResource] + convert_client_resource_update_input_to_provider_resource_update_input: Callable[ + [int, ClientInput], ProviderInput + ] + update_provider_resource: Callable[..., ProviderResource] + + +def get_schema(config: ResourceConfig) -> Schema: + return SCHEMA_IDS_TO_SCHEMA_DETAILS[config.schema_id] + + +def convert_provider_resource_to_client_resource( + config: ResourceConfig, + provider_resource: ProviderResource, + attributes_query_str: str | None, + excluded_attributes_query_str: str | None, +) -> ClientResource: + client_resource = config.convert_provider_resource_to_client_resource( + provider_resource + ) + schema = get_schema(config) + client_resource = helpers.filter_attributes( + client_resource, attributes_query_str, excluded_attributes_query_str, schema + ) + return client_resource + + +def get_resource( + config: ResourceConfig, + resource_id: ResourceId, + tenant_id: int, + attributes: str | None = None, + excluded_attributes: str | None = None, +) -> ClientResource | None: + provider_resource = config.get_provider_resource(resource_id, tenant_id) + if provider_resource is None: + return None + client_resource = convert_provider_resource_to_client_resource( + config, provider_resource, attributes, excluded_attributes + ) + return client_resource diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py new file mode 100644 index 000000000..98e9b3ded --- /dev/null +++ b/ee/api/routers/scim/users.py @@ -0,0 +1,395 @@ +from typing import Any +from datetime import datetime +from psycopg2.extensions import AsIs + +from chalicelib.utils import pg_client +from chalicelib.core import roles +from routers.scim.resource_config import ( + ProviderResource, + ClientResource, + ResourceId, + ClientInput, + ProviderInput, +) + + +def convert_client_resource_update_input_to_provider_resource_update_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + result = {} + if "userType" in client_input: + role = roles.get_role_by_name(tenant_id, client_input["userType"]) + result["role_id"] = role["roleId"] if role else None + if "name" in client_input: + # note(jon): we're currently not handling the case where the client + # send patches of individual name components (e.g. name.middleName) + name = client_input.get("name", {}).get("formatted") + if name: + result["name"] = name + if "userName" in client_input: + result["email"] = client_input["userName"] + if "externalId" in client_input: + result["internal_id"] = client_input["externalId"] + if "active" in client_input: + result["deleted_at"] = None if client_input["active"] else datetime.now() + return result + + +def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + role_id = None + if "userType" in client_input: + role = roles.get_role_by_name(tenant_id, client_input["userType"]) + role_id = role["roleId"] if role else None + name = client_input.get("name", {}).get("formatted") + if not name: + name = " ".join( + [ + x + for x in [ + client_input.get("name", {}).get("honorificPrefix"), + client_input.get("name", {}).get("givenName"), + client_input.get("name", {}).get("middleName"), + client_input.get("name", {}).get("familyName"), + client_input.get("name", {}).get("honorificSuffix"), + ] + if x + ] + ) + result = { + "email": client_input["userName"], + "internal_id": client_input.get("externalId"), + "name": name, + "role_id": role_id, + } + result = {k: v for k, v in result.items() if v is not None} + return result + + +def convert_client_resource_creation_input_to_provider_resource_creation_input( + tenant_id: int, client_input: ClientInput +) -> ProviderInput: + role_id = None + if "userType" in client_input: + role = roles.get_role_by_name(tenant_id, client_input["userType"]) + role_id = role["roleId"] if role else None + name = client_input.get("name", {}).get("formatted") + if not name: + name = " ".join( + [ + x + for x in [ + client_input.get("name", {}).get("honorificPrefix"), + client_input.get("name", {}).get("givenName"), + client_input.get("name", {}).get("middleName"), + client_input.get("name", {}).get("familyName"), + client_input.get("name", {}).get("honorificSuffix"), + ] + if x + ] + ) + result = { + "email": client_input["userName"], + "internal_id": client_input.get("externalId"), + "name": name, + "role_id": role_id, + } + result = {k: v for k, v in result.items() if v is not None} + return result + + +def get_provider_resource_from_unique_fields( + email: str, **kwargs: dict[str, Any] +) -> ProviderResource | None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT * + FROM public.users + WHERE users.email = %(email)s + """, + {"email": email}, + ) + ) + return cur.fetchone() + + +def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + UPDATE public.users + SET + deleted_at = NULL, + updated_at = default + WHERE + users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + """, + {"user_id": resource_id, "tenant_id": tenant_id}, + ) + ) + + +def convert_provider_resource_to_client_resource( + provider_resource: ProviderResource, +) -> ClientResource: + return { + "id": str(provider_resource["user_id"]), + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "meta": { + "resourceType": "User", + "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": provider_resource["updated_at"].strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "location": f"Users/{provider_resource['user_id']}", + }, + "userName": provider_resource["email"], + "externalId": provider_resource["internal_id"], + "name": { + "formatted": provider_resource["name"], + }, + "displayName": provider_resource["name"] or provider_resource["email"], + "userType": provider_resource.get("role_name"), + "active": provider_resource["deleted_at"] is None, + } + + +def get_active_resource_count(tenant_id: int) -> int: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT COUNT(*) + FROM public.users + WHERE + users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + """, + {"tenant_id": tenant_id}, + ) + ) + return cur.fetchone()["count"] + + +def get_provider_resource_chunk( + offset: int, tenant_id: int, limit: int +) -> list[ProviderResource]: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT + users.*, + roles.name AS role_name + FROM public.users + LEFT JOIN public.roles USING (role_id) + WHERE + users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + LIMIT %(limit)s + OFFSET %(offset)s; + """, + {"offset": offset, "limit": limit, "tenant_id": tenant_id}, + ) + ) + return cur.fetchall() + + +def get_provider_resource( + resource_id: ResourceId, tenant_id: int +) -> ProviderResource | None: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT + users.*, + roles.name AS role_name + FROM public.users + LEFT JOIN public.roles USING (role_id) + WHERE + users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + LIMIT 1; + """, + { + "user_id": resource_id, + "tenant_id": tenant_id, + }, + ) + ) + return cur.fetchone() + + +def create_provider_resource( + email: str, + tenant_id: int, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, +) -> ProviderResource: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + INSERT INTO public.users ( + tenant_id, + email, + name, + internal_id, + role_id + ) + VALUES ( + %(tenant_id)s, + %(email)s, + %(name)s, + %(internal_id)s, + %(role_id)s + ) + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return cur.fetchone() + + +def restore_provider_resource( + tenant_id: int, + email: str, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, + **kwargs: dict[str, Any], +) -> ProviderResource: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + tenant_id = %(tenant_id)s, + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + role_id = %(role_id)s, + deleted_at = NULL, + created_at = now(), + updated_at = now(), + api_key = default, + jwt_iat = NULL, + weekly_report = default + WHERE users.email = %(email)s + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return cur.fetchone() + + +def rewrite_provider_resource( + resource_id: int, + tenant_id: int, + email: str, + name: str = "", + internal_id: str | None = None, + role_id: int | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + WITH u AS ( + UPDATE public.users + SET + email = %(email)s, + name = %(name)s, + internal_id = %(internal_id)s, + role_id = %(role_id)s, + updated_at = now() + WHERE + users.user_id = %(user_id)s + AND users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id); + """, + { + "tenant_id": tenant_id, + "user_id": resource_id, + "email": email, + "name": name, + "internal_id": internal_id, + "role_id": role_id, + }, + ) + ) + return cur.fetchone() + + +def update_provider_resource( + resource_id: int, + tenant_id: int, + **kwargs, +): + with pg_client.PostgresClient() as cur: + set_fragments = [] + kwargs["updated_at"] = datetime.now() + for k, v in kwargs.items(): + fragment = cur.mogrify( + "%s = %s", + (AsIs(k), v), + ).decode("utf-8") + set_fragments.append(fragment) + set_clause = ", ".join(set_fragments) + query = f""" + WITH u AS ( + UPDATE public.users + SET {set_clause} + WHERE + users.user_id = {resource_id} + AND users.tenant_id = {tenant_id} + AND users.deleted_at IS NULL + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id);""" + cur.execute(query) + return cur.fetchone() diff --git a/ee/api/routers/scim_constants.py b/ee/api/routers/scim_constants.py deleted file mode 100644 index 4b52b1a8e..000000000 --- a/ee/api/routers/scim_constants.py +++ /dev/null @@ -1,22 +0,0 @@ -# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants -import json - -SCHEMAS = sorted( - [ - json.load(open("routers/fixtures/service_provider_config_schema.json", "r")), - json.load(open("routers/fixtures/resource_type_schema.json", "r")), - json.load(open("routers/fixtures/schema_schema.json", "r")), - json.load(open("routers/fixtures/user_schema.json", "r")), - json.load(open("routers/fixtures/group_schema.json", "r")), - ], - key=lambda x: x["id"], -) - -SERVICE_PROVIDER_CONFIG = json.load( - open("routers/fixtures/service_provider_config.json", "r") -) - -RESOURCE_TYPES = sorted( - json.load(open("routers/fixtures/resource_type.json", "r")), - key=lambda x: x["id"], -)