"""SCIM 2.0 service-provider endpoints mounted under ``/sso/scim/v2``.

Exposes the token/authorize endpoints used by identity providers to obtain
bearer tokens, the SCIM discovery endpoints (ServiceProviderConfig,
ResourceTypes, Schemas) and generic CRUD endpoints for the ``Users`` and
``Groups`` resource types. Resource-specific behavior is delegated to the
``ResourceConfig`` instances built from ``routers.scim.users`` and
``routers.scim.groups``.
"""

import logging
from copy import deepcopy
from enum import Enum
from urllib.parse import urlencode, urlsplit, urlunsplit

from chalicelib.utils import pg_client
from fastapi import Depends, HTTPException, Query, Response, Request
from fastapi.responses import JSONResponse, RedirectResponse
from psycopg2 import errors

from chalicelib.core import roles
from chalicelib.utils.scim_auth import (
    auth_optional,
    auth_required,
    create_tokens,
    verify_refresh_token,
)
from routers.base import get_routers
from routers.scim.constants import (
    SERVICE_PROVIDER_CONFIG,
    RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS,
    SCHEMA_IDS_TO_SCHEMA_DETAILS,
)
from routers.scim import helpers, groups, users
from routers.scim.resource_config import ResourceConfig
from routers.scim import resource_config as api_helper

logger = logging.getLogger(__name__)

public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")

_SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
_SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
# Generic detail returned for unexpected failures; the real exception is logged
# server-side instead of being echoed to the client.
_UNEXPECTED_ERROR_DETAIL = "An unexpected error occurred while processing the request."


@public_app.post("/token")
async def post_token(r: Request):
    """Exchange tenant credentials plus an auth code or refresh token for tokens.

    Reads a form-encoded body with ``client_id`` (the numeric tenant id),
    ``client_secret`` (the tenant key), ``grant_type`` and either ``code`` or
    ``refresh_token``.

    Returns:
        A dict with ``access_token``, ``token_type``, ``expires_in`` and
        ``refresh_token``.

    Raises:
        HTTPException: 401 when the client credentials are invalid or the
            code/client_id pair cannot be redeemed.
    """
    form = await r.form()
    client_secret = form.get("client_secret")
    try:
        # int(None) raises TypeError, so a missing client_id must be caught
        # here as well; otherwise it would surface as a 500.
        tenant_id = int(form.get("client_id"))
    except (TypeError, ValueError):
        raise HTTPException(status_code=401, detail="Invalid credentials") from None

    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                SELECT tenant_id
                FROM public.tenants
                WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
                """,
                {"tenant_id": tenant_id, "tenant_key": client_secret},
            )
        )
        tenant = cur.fetchone()
    if not tenant:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    grant_type = form.get("grant_type")
    if grant_type == "refresh_token":
        # NOTE(review): the return value is ignored, so this presumably raises
        # on an invalid token. Nothing here checks that the refresh token was
        # issued to this tenant — confirm in scim_auth.
        verify_refresh_token(form.get("refresh_token"))
    else:
        # Any other grant_type is treated as an authorization-code exchange.
        # Consume the code with one conditional UPDATE ... RETURNING so two
        # concurrent requests cannot both redeem the same code (a SELECT
        # followed by a separate UPDATE is racy).
        with pg_client.PostgresClient() as cur:
            cur.execute(
                cur.mogrify(
                    """
                    UPDATE public.scim_auth_codes
                    SET used=TRUE
                    WHERE auth_code=%(auth_code)s
                      AND tenant_id=%(tenant_id)s
                      AND used IS FALSE
                    RETURNING auth_code
                    """,
                    {"auth_code": form.get("code"), "tenant_id": tenant_id},
                )
            )
            if cur.fetchone() is None:
                raise HTTPException(
                    status_code=401, detail="Invalid code/client_id pair"
                )

    access_token, refresh_token, expires_in = create_tokens(
        tenant_id=tenant["tenant_id"]
    )
    return {
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": expires_in,
        "refresh_token": refresh_token,
    }


def _append_query_params(url: str, params: dict) -> str:
    """Return *url* with *params* URL-encoded and appended to its query string.

    Unlike naive ``f"{url}?..."`` concatenation, this preserves an existing
    query string and keeps any ``#fragment`` at the end of the URL.
    """
    parts = urlsplit(url)
    query = urlencode(params)
    if parts.query:
        query = f"{parts.query}&{query}"
    return urlunsplit(parts._replace(query=query))


# note(jon): this might be specific to okta. if so, we should probably specify that in the endpoint
@public_app.get("/authorize")
async def get_authorize(
    r: Request,
    response_type: str,
    client_id: str,
    redirect_uri: str,
    state: str | None = None,
):
    """Issue a fresh single-use auth code for a tenant and redirect back with it.

    All previously issued codes for the tenant are marked used first, so only
    the newly inserted code remains redeemable via ``/token``.

    Args:
        r: The incoming request (unused).
        response_type: Accepted but not inspected.
        client_id: The tenant id, as a decimal string.
        redirect_uri: Where to redirect; ``code`` (and ``state`` when given)
            are appended to its query string.
        state: Opaque value echoed back to the caller.

    Raises:
        HTTPException: 400 when ``client_id`` is not an integer.
    """
    try:
        tenant_id = int(client_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid client_id") from None

    # NOTE(review): redirect_uri is not validated against an allow-list, so this
    # endpoint will redirect to any URL (open redirect). Confirm that this is
    # acceptable for the supported identity providers.
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                UPDATE public.scim_auth_codes
                SET used=TRUE
                WHERE tenant_id=%(tenant_id)s
                """,
                {"tenant_id": tenant_id},
            )
        )
        cur.execute(
            cur.mogrify(
                """
                INSERT INTO public.scim_auth_codes (tenant_id)
                VALUES (%(tenant_id)s)
                RETURNING auth_code
                """,
                {"tenant_id": tenant_id},
            )
        )
        code = cur.fetchone()["auth_code"]

    params = {"code": code}
    if state:
        params["state"] = state
    return RedirectResponse(_append_query_params(redirect_uri, params))


def _scim_error_response(
    status_code: int, detail: str, scim_type: str | None = None
) -> JSONResponse:
    """Build a SCIM error response body (RFC 7644 §3.12).

    Args:
        status_code: HTTP status code; it is also echoed as a string in the
            ``status`` field as the SCIM error schema requires.
        detail: Human-readable error description.
        scim_type: Optional SCIM ``scimType`` keyword (e.g. ``"uniqueness"``).
    """
    content = {
        "schemas": [_SCIM_ERROR_SCHEMA],
        "detail": detail,
        "status": str(status_code),
    }
    if scim_type is not None:
        content["scimType"] = scim_type
    return JSONResponse(status_code=status_code, content=content)


def _not_found_error_response(resource_id: int | str):
    """404 for a resource/schema/resource-type id that does not exist."""
    return _scim_error_response(404, f"Resource {resource_id} not found")


def _uniqueness_error_response():
    """409 ``uniqueness``: a unique attribute value is already taken."""
    return _scim_error_response(
        409,
        "One or more of the attribute values are already in use or are reserved.",
        "uniqueness",
    )


def _mutability_error_response():
    """400 ``mutability``: the request tried to modify an immutable attribute."""
    return _scim_error_response(
        400,
        "The attempted modification is not compatible with the target attribute's mutability or current state.",
        "mutability",
    )


def _operation_not_permitted_error_response():
    """403: the requested operation is not supported (e.g. filtering here)."""
    return _scim_error_response(
        403, "Operation is not permitted based on the supplied authorization"
    )


def _invalid_value_error_response():
    """400 ``invalidValue``: a required value is missing or has the wrong type."""
    return _scim_error_response(
        400,
        "A required value was missing, or the value specified was not compatible with the operation or attribute type, or resource schema.",
        "invalidValue",
    )


def _invalid_syntax_error_response():
    """400 ``invalidSyntax``: the request body is not a valid JSON object."""
    return _scim_error_response(
        400,
        "The request body message structure was invalid or did not conform to the request schema.",
        "invalidSyntax",
    )


def _internal_server_error_response(detail: str):
    """500 with the given detail message."""
    return _scim_error_response(500, detail)


async def _read_json_object(r: Request) -> dict | None:
    """Return the request body parsed as a JSON object, or None if it is not one.

    Malformed JSON (``json.JSONDecodeError``), undecodable bytes
    (``UnicodeDecodeError``) and valid JSON that is not an object all yield
    None, so callers can answer with ``invalidSyntax`` instead of a 500.
    """
    try:
        payload = await r.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass this
        return None
    return payload if isinstance(payload, dict) else None


# note(jon): it was recommended to make this endpoint partially open
# so that clients can view the `authenticationSchemes` prior to being authenticated.
@public_app.get("/ServiceProviderConfig")
async def get_service_provider_config(
    r: Request, tenant_id: str | None = Depends(auth_optional)
):
    """Return the service provider config.

    The full config is returned to authenticated callers. Unauthenticated
    callers get only ``schemas``, ``authenticationSchemes`` and ``meta``.
    """
    if tenant_id is None:
        return JSONResponse(
            status_code=200,
            content={
                "schemas": SERVICE_PROVIDER_CONFIG["schemas"],
                "authenticationSchemes": SERVICE_PROVIDER_CONFIG[
                    "authenticationSchemes"
                ],
                "meta": SERVICE_PROVIDER_CONFIG["meta"],
            },
        )
    return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG)


@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)])
async def get_resource_types(filter_param: str | None = Query(None, alias="filter")):
    """List every supported resource type. Filtering is rejected with 403."""
    if filter_param is not None:
        return _operation_not_permitted_error_response()
    return JSONResponse(
        status_code=200,
        content={
            "totalResults": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
            "itemsPerPage": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
            "startIndex": 1,
            "schemas": [_SCIM_LIST_RESPONSE_SCHEMA],
            "Resources": list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values()),
        },
    )


@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)])
async def get_resource_type(resource_id: str):
    """Return one resource type by id, or a SCIM 404."""
    if resource_id not in RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS:
        return _not_found_error_response(resource_id)
    return JSONResponse(
        status_code=200,
        content=RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id],
    )


@public_app.get("/Schemas", dependencies=[Depends(auth_required)])
async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
    """List every schema, ordered by schema id. Filtering is rejected with 403."""
    if filter_param is not None:
        return _operation_not_permitted_error_response()
    return JSONResponse(
        status_code=200,
        content={
            "totalResults": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
            "itemsPerPage": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
            "startIndex": 1,
            "schemas": [_SCIM_LIST_RESPONSE_SCHEMA],
            "Resources": [
                value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
            ],
        },
    )


@public_app.get("/Schemas/{schema_id}")
async def get_schema(schema_id: str, tenant_id=Depends(auth_required)):
    """Return one schema by id, or a SCIM 404.

    For the core User schema, the ``userType`` attribute's ``canonicalValues``
    are filled with the tenant's role names. The static schema is deep-copied
    first, so the shared constant is never mutated.
    """
    if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS:
        return _not_found_error_response(schema_id)
    schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id])
    if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User":
        role_names = [role["name"] for role in roles.get_roles(tenant_id)]
        # A default is passed to next() so that a schema without a userType
        # attribute does not raise StopIteration inside this coroutine (which
        # Python turns into a RuntimeError).
        user_type_attribute = next(
            (
                attribute
                for attribute in schema["attributes"]
                if attribute["name"] == "userType"
            ),
            None,
        )
        if user_type_attribute is not None:
            user_type_attribute["canonicalValues"] = role_names
    return JSONResponse(
        status_code=200,
        content=schema,
    )


user_config = ResourceConfig(
    schema_id="urn:ietf:params:scim:schemas:core:2.0:User",
    max_chunk_size=10,
    get_active_resource_count=users.get_active_resource_count,
    convert_provider_resource_to_client_resource=users.convert_provider_resource_to_client_resource,
    get_provider_resource_chunk=users.get_provider_resource_chunk,
    get_provider_resource=users.get_provider_resource,
    convert_client_resource_creation_input_to_provider_resource_creation_input=users.convert_client_resource_creation_input_to_provider_resource_creation_input,
    get_provider_resource_from_unique_fields=users.get_provider_resource_from_unique_fields,
    restore_provider_resource=users.restore_provider_resource,
    create_provider_resource=users.create_provider_resource,
    delete_provider_resource=users.delete_provider_resource,
    convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=users.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
    rewrite_provider_resource=users.rewrite_provider_resource,
    convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input,
    update_provider_resource=users.update_provider_resource,
    filter_attribute_mapping=users.filter_attribute_mapping,
)

# Groups have no unique-field lookup, so the "restore a soft-deleted resource"
# path in create_resource is never taken and restore_provider_resource is None.
group_config = ResourceConfig(
    schema_id="urn:ietf:params:scim:schemas:core:2.0:Group",
    max_chunk_size=10,
    get_active_resource_count=groups.get_active_resource_count,
    convert_provider_resource_to_client_resource=groups.convert_provider_resource_to_client_resource,
    get_provider_resource_chunk=groups.get_provider_resource_chunk,
    get_provider_resource=groups.get_provider_resource,
    convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input,
    get_provider_resource_from_unique_fields=lambda **kwargs: None,
    restore_provider_resource=None,
    create_provider_resource=groups.create_provider_resource,
    delete_provider_resource=groups.delete_provider_resource,
    convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=groups.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
    rewrite_provider_resource=groups.rewrite_provider_resource,
    convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input,
    update_provider_resource=groups.update_provider_resource,
    filter_attribute_mapping=groups.filter_attribute_mapping,
)

RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = {
    "Users": user_config,
    "Groups": group_config,
}


class SCIMResource(str, Enum):
    """Resource types served by the generic ``/{resource_type}`` routes."""

    USERS = "Users"
    GROUPS = "Groups"


@public_app.get("/{resource_type}")
async def get_resources(
    resource_type: SCIMResource,
    tenant_id=Depends(auth_required),
    requested_start_index_one_indexed: int = Query(1, alias="startIndex"),
    requested_items_per_page: int | None = Query(None, alias="count"),
    # NOTE(review): attributes/excludedAttributes are `str` here but
    # `list[str]` on the other resource routes, while both are passed to the
    # same api_helper function. Confirm which form it expects.
    attributes: str | None = Query(None),
    excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
    filter: str | None = Query(None),
):
    """List resources of *resource_type* as a SCIM ListResponse.

    ``startIndex`` values below 1 are treated as 1. ``count`` is capped at the
    resource's ``max_chunk_size``. Per RFC 7644 §3.4.2.4, ``count=0`` (or a
    negative value) returns no resources, only ``totalResults``. An omitted
    ``count`` uses ``max_chunk_size``.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    filter_clause = helpers.scim_to_sql_where(filter, config.filter_attribute_mapping())
    total_resources = config.get_active_resource_count(tenant_id, filter_clause)

    start_index_one_indexed = max(1, requested_start_index_one_indexed)
    offset = start_index_one_indexed - 1
    if requested_items_per_page is None:
        limit = config.max_chunk_size
    else:
        # An explicit 0 must stay 0; `count or default` would turn it into the
        # default page size.
        limit = min(max(0, requested_items_per_page), config.max_chunk_size)

    provider_resources = (
        config.get_provider_resource_chunk(offset, tenant_id, limit, filter_clause)
        if limit > 0
        else []
    )
    client_resources = [
        api_helper.convert_provider_resource_to_client_resource(
            config, provider_resource, attributes, excluded_attributes
        )
        for provider_resource in provider_resources
    ]
    return JSONResponse(
        status_code=200,
        content={
            "schemas": [_SCIM_LIST_RESPONSE_SCHEMA],
            "totalResults": total_resources,
            "startIndex": start_index_one_indexed,
            "itemsPerPage": len(client_resources),
            "Resources": client_resources,
        },
    )


@public_app.get("/{resource_type}/{resource_id}")
async def get_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Return one resource in client (SCIM) form, or a SCIM 404."""
    resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    resource = api_helper.get_resource(
        resource_config,
        resource_id,
        tenant_id,
        attributes,
        excluded_attributes,
    )
    if not resource:
        return _not_found_error_response(resource_id)
    return JSONResponse(status_code=200, content=resource)


@public_app.post("/{resource_type}")
async def create_resource(
    resource_type: SCIMResource,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Create a resource, or restore a soft-deleted one with the same unique fields.

    Returns 201 with a ``Location`` header on success. Returns
    ``invalidSyntax`` for a non-object body, ``invalidValue`` when required
    input is missing, and ``uniqueness`` (409) when an active resource already
    holds the unique values.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    payload = await _read_json_object(r)
    if payload is None:
        return _invalid_syntax_error_response()
    try:
        provider_resource_input = config.convert_client_resource_creation_input_to_provider_resource_creation_input(
            tenant_id,
            payload,
        )
    except KeyError:
        return _invalid_value_error_response()

    existing_provider_resource = config.get_provider_resource_from_unique_fields(
        **provider_resource_input
    )
    if (
        existing_provider_resource
        and existing_provider_resource.get("deleted_at") is None
    ):
        return _uniqueness_error_response()

    try:
        if existing_provider_resource:
            # Only soft-deleted matches reach here (active ones returned 409).
            provider_resource = config.restore_provider_resource(
                tenant_id=tenant_id, **provider_resource_input
            )
        else:
            provider_resource = config.create_provider_resource(
                tenant_id=tenant_id, **provider_resource_input
            )
    except errors.UniqueViolation:
        # A concurrent request may have taken the unique values in between.
        return _uniqueness_error_response()
    except Exception:
        logger.exception("Failed to create SCIM %s", resource_type.value)
        return _internal_server_error_response(_UNEXPECTED_ERROR_DETAIL)

    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    response = JSONResponse(status_code=201, content=client_resource)
    response.headers["Location"] = client_resource["meta"]["location"]
    return response


@public_app.delete("/{resource_type}/{resource_id}")
async def delete_resource(
    resource_type: SCIMResource,
    resource_id: str,
    tenant_id=Depends(auth_required),
):
    """Delete a resource and return 204, or a SCIM 404 if it does not exist."""
    # note(jon): this can be a soft or a hard delete
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    resource = api_helper.get_resource(config, resource_id, tenant_id)
    if not resource:
        return _not_found_error_response(resource_id)
    config.delete_provider_resource(resource_id, tenant_id)
    return Response(status_code=204, content="")


@public_app.put("/{resource_type}/{resource_id}")
async def put_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Replace a resource's mutable attributes with the request body.

    The existence check runs before the body is read, so a missing resource
    yields 404 regardless of the body. Returns ``mutability`` when the body
    changes an immutable attribute (signalled by ``ValueError`` from
    ``filter_mutable_attributes``).
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    client_resource = api_helper.get_resource(config, resource_id, tenant_id)
    if not client_resource:
        return _not_found_error_response(resource_id)
    schema = api_helper.get_schema(config)
    payload = await _read_json_object(r)
    if payload is None:
        return _invalid_syntax_error_response()
    try:
        client_resource_input = helpers.filter_mutable_attributes(
            schema, payload, client_resource
        )
    except ValueError:
        return _mutability_error_response()
    try:
        provider_resource_input = config.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
            tenant_id, client_resource_input
        )
    except KeyError:
        return _invalid_value_error_response()
    try:
        provider_resource = config.rewrite_provider_resource(
            resource_id,
            tenant_id,
            **provider_resource_input,
        )
    except errors.UniqueViolation:
        return _uniqueness_error_response()
    except Exception:
        # Log the real error; do not echo internal (e.g. database) messages to
        # the client.
        logger.exception(
            "Failed to rewrite SCIM %s %s", resource_type.value, resource_id
        )
        return _internal_server_error_response(_UNEXPECTED_ERROR_DETAIL)
    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    return JSONResponse(status_code=200, content=client_resource)


@public_app.patch("/{resource_type}/{resource_id}")
async def patch_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Apply a SCIM PATCH (``Operations`` list) to a resource.

    Only the attributes reported as changed by ``helpers.apply_scim_patch`` are
    forwarded to the provider update. A body without ``Operations`` yields
    ``invalidValue``.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    client_resource = api_helper.get_resource(config, resource_id, tenant_id)
    if not client_resource:
        return _not_found_error_response(resource_id)
    schema = api_helper.get_schema(config)
    payload = await _read_json_object(r)
    if payload is None:
        return _invalid_syntax_error_response()
    operations = payload.get("Operations")
    if operations is None:
        return _invalid_value_error_response()
    _, changes = helpers.apply_scim_patch(operations, client_resource, schema)
    # changes maps attribute -> (old_value, new_value); only new values are sent.
    client_resource_input = {
        key: new_value for key, (_, new_value) in changes.items()
    }
    try:
        provider_resource_input = config.convert_client_resource_update_input_to_provider_resource_update_input(
            tenant_id, client_resource_input
        )
    except KeyError:
        return _invalid_value_error_response()
    try:
        provider_resource = config.update_provider_resource(
            resource_id, tenant_id, **provider_resource_input
        )
    except errors.UniqueViolation:
        return _uniqueness_error_response()
    except Exception:
        # Log the real error; do not echo internal (e.g. database) messages to
        # the client.
        logger.exception(
            "Failed to patch SCIM %s %s", resource_type.value, resource_id
        )
        return _internal_server_error_response(_UNEXPECTED_ERROR_DETAIL)
    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    return JSONResponse(status_code=200, content=client_resource)