restructure files

This commit is contained in:
Jonathan Griffin 2025-04-24 11:19:21 +02:00
parent a5ddf786d4
commit eccb753c3c
18 changed files with 1290 additions and 1042 deletions

View file

@ -26,7 +26,7 @@ from routers.subs import v1_api_ee
if config("ENABLE_SSO", cast=bool, default=True):
from routers import saml
from routers import scim
from routers.scim import api as scim
loglevel = config("LOGLEVEL", default=logging.WARNING)
print(f">Loglevel set to: {loglevel}")

View file

@ -2,11 +2,9 @@ import json
import logging
import secrets
from typing import Optional
from datetime import datetime
from decouple import config
from fastapi import BackgroundTasks, HTTPException
from psycopg2.extensions import AsIs
from psycopg2.extras import Json
from pydantic import BaseModel, model_validator
from starlette import status
@ -352,239 +350,6 @@ def get(user_id, tenant_id):
return helper.dict_to_camel_case(r)
def count_total_scim_users(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT COUNT(*)
FROM public.users
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
""",
{"tenant_id": tenant_id},
)
)
return cur.fetchone()["count"]
def get_scim_users_paginated(start_index, tenant_id, count=None):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
users.*,
roles.name AS role_name
FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{"offset": start_index - 1, "limit": count, "tenant_id": tenant_id},
)
)
r = cur.fetchall()
return helper.list_to_camel_case(r)
def get_scim_user_by_id(user_id, tenant_id):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
users.*,
roles.name AS role_name
FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT 1;
""",
{
"user_id": user_id,
"tenant_id": tenant_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def create_scim_user(
email: str,
tenant_id: int,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
INSERT INTO public.users (
tenant_id,
email,
name,
internal_id,
role_id
)
VALUES (
%(tenant_id)s,
%(email)s,
%(name)s,
%(internal_id)s,
%(role_id)s
)
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def restore_scim_user(
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
**kwargs,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s,
deleted_at = NULL,
created_at = now(),
updated_at = now(),
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.email = %(email)s
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def update_scim_user(
user_id: int,
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s,
updated_at = now()
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"user_id": user_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def patch_scim_user(
user_id: int,
tenant_id: int,
**kwargs,
):
with pg_client.PostgresClient() as cur:
set_fragments = []
kwargs["updated_at"] = datetime.now()
for k, v in kwargs.items():
fragment = cur.mogrify(
"%s = %s",
(AsIs(k), v),
).decode("utf-8")
set_fragments.append(fragment)
set_clause = ", ".join(set_fragments)
query = f"""
WITH u AS (
UPDATE public.users
SET {set_clause}
WHERE
users.user_id = {user_id}
AND users.tenant_id = {tenant_id}
AND users.deleted_at IS NULL
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);"""
cur.execute(query)
return helper.dict_to_camel_case(cur.fetchone())
def generate_new_api_key(user_id):
with pg_client.PostgresClient() as cur:
cur.execute(
@ -695,21 +460,6 @@ def edit_member(
return {"data": user}
def get_existing_scim_user_by_unique_values_from_all_users(email: str, **kwargs):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.users
WHERE users.email = %(email)s
""",
{"email": email},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def get_by_email_only(email):
with pg_client.PostgresClient() as cur:
cur.execute(
@ -1255,24 +1005,6 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id=
return helper.dict_to_camel_case(cur.fetchone())
def soft_delete_scim_user_by_id(user_id, tenant_id):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
UPDATE public.users
SET
deleted_at = NULL,
updated_at = default
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
""",
{"user_id": user_id, "tenant_id": tenant_id},
)
)
def __hard_delete_user(user_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify(

View file

@ -1,644 +0,0 @@
from copy import deepcopy
import logging
from typing import Any, Callable
from enum import Enum
from datetime import datetime
from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from psycopg2 import errors
from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import (
auth_optional,
auth_required,
create_tokens,
verify_refresh_token,
)
from routers.base import get_routers
from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG
from routers import scim_helpers, scim_groups
logger = logging.getLogger(__name__)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.post("/token")
async def post_token(
host: str = Header(..., alias="Host"),
form_data: OAuth2PasswordRequestForm = Depends(),
):
subdomain = host.split(".")[0]
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config(
"SCIM_PASSWORD"
):
raise HTTPException(status_code=401, detail="Invalid credentials")
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "Bearer",
}
class RefreshRequest(BaseModel):
refresh_token: str
@public_app.post("/refresh")
async def post_refresh(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
return {"access_token": new_access_token, "token_type": "Bearer"}
RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS = {
resource_type_detail["id"]: resource_type_detail
for resource_type_detail in RESOURCE_TYPES
}
def _not_found_error_response(resource_id: int):
return JSONResponse(
status_code=404,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": f"Resource {resource_id} not found",
"status": "404",
},
)
def _uniqueness_error_response():
return JSONResponse(
status_code=409,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "One or more of the attribute values are already in use or are reserved.",
"status": "409",
"scimType": "uniqueness",
},
)
def _mutability_error_response():
return JSONResponse(
status_code=400,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "The attempted modification is not compatible with the target attribute's mutability or current state.",
"status": "400",
"scimType": "mutability",
},
)
def _operation_not_permitted_error_response():
return JSONResponse(
status_code=403,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "Operation is not permitted based on the supplied authorization",
"status": "403",
},
)
def _invalid_value_error_response():
return JSONResponse(
status_code=400,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "A required value was missing, or the value specified was not compatible with the operation or attribtue type, or resource schema.",
"status": "400",
"scimType": "invalidValue",
},
)
def _internal_server_error_response(detail: str):
return JSONResponse(
status_code=500,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": detail,
"status": "500",
},
)
@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)])
async def get_resource_types(filter_param: str | None = Query(None, alias="filter")):
if filter_param is not None:
return _operation_not_permitted_error_response()
return JSONResponse(
status_code=200,
content={
"totalResults": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
"itemsPerPage": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
"startIndex": 1,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"Resources": list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values()),
},
)
@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)])
async def get_resource_type(resource_id: str):
if resource_id not in RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS:
return _not_found_error_response(resource_id)
return JSONResponse(
status_code=200,
content=RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id],
)
SCHEMA_IDS_TO_SCHEMA_DETAILS = {
schema_detail["id"]: schema_detail for schema_detail in SCHEMAS
}
@public_app.get("/Schemas", dependencies=[Depends(auth_required)])
async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
if filter_param is not None:
return _operation_not_permitted_error_response()
return JSONResponse(
status_code=200,
content={
"totalResults": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
"itemsPerPage": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
"startIndex": 1,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"Resources": [
value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
],
},
)
@public_app.get("/Schemas/{schema_id}")
async def get_schema(schema_id: str, tenant_id=Depends(auth_required)):
if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS:
return _not_found_error_response(schema_id)
schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id])
if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User":
db_roles = roles.get_roles(tenant_id)
role_names = [role["name"] for role in db_roles]
user_type_attribute = next(
filter(lambda x: x["name"] == "userType", schema["attributes"])
)
user_type_attribute["canonicalValues"] = role_names
return JSONResponse(
status_code=200,
content=schema,
)
# note(jon): it was recommended to make this endpoint partially open
# so that clients can view the `authenticationSchemes` prior to being authenticated.
@public_app.get("/ServiceProviderConfig")
async def get_service_provider_config(
r: Request, tenant_id: str | None = Depends(auth_optional)
):
is_authenticated = tenant_id is not None
if not is_authenticated:
return JSONResponse(
status_code=200,
content={
"schemas": SERVICE_PROVIDER_CONFIG["schemas"],
"authenticationSchemes": SERVICE_PROVIDER_CONFIG[
"authenticationSchemes"
],
"meta": SERVICE_PROVIDER_CONFIG["meta"],
},
)
return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG)
def _serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource: dict[str, Any],
schema_id: str,
serialize_db_resource_to_scim_resource: Callable[[dict[str, Any]], dict[str, Any]],
attributes: list[str] | None = None,
excluded_attributes: list[str] | None = None,
) -> dict[str, Any]:
schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]
all_attributes = scim_helpers.get_all_attribute_names(schema)
attributes = attributes or all_attributes
always_returned_attributes = (
scim_helpers.get_all_attribute_names_where_returned_is_always(schema)
)
included_attributes = list(set(attributes).union(set(always_returned_attributes)))
excluded_attributes = excluded_attributes or []
excluded_attributes = list(
set(excluded_attributes).difference(set(always_returned_attributes))
)
scim_resource = serialize_db_resource_to_scim_resource(db_resource)
scim_resource = scim_helpers.filter_attributes(scim_resource, included_attributes)
scim_resource = scim_helpers.exclude_attributes(scim_resource, excluded_attributes)
return scim_resource
def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, Any]:
role_id = None
if "userType" in data:
role = roles.get_role_by_name(tenant_id, data["userType"])
role_id = role["roleId"] if role else None
name = data.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
for x in [
data.get("name", {}).get("honorificPrefix"),
data.get("name", {}).get("givenName"),
data.get("name", {}).get("middleName"),
data.get("name", {}).get("familyName"),
data.get("name", {}).get("honorificSuffix"),
]
if x
]
)
result = {
"email": data["userName"],
"internal_id": data.get("externalId"),
"name": name,
"role_id": role_id,
}
result = {k: v for k, v in result.items() if v is not None}
return result
def _parse_user_patch_payload(data: dict[str, Any], tenant_id: str) -> dict[str, Any]:
result = {}
if "userType" in data:
role = roles.get_role_by_name(tenant_id, data["userType"])
result["role_id"] = role["roleId"] if role else None
if "name" in data:
# note(jon): we're currently not handling the case where the client
# send patches of individual name components (e.g. name.middleName)
name = data.get("name", {}).get("formatted")
if name:
result["name"] = name
if "userName" in data:
result["email"] = data["userName"]
if "externalId" in data:
result["internal_id"] = data["externalId"]
if "active" in data:
result["deleted_at"] = None if data["active"] else datetime.now()
return result
def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]:
return {
"id": str(db_user["userId"]),
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"meta": {
"resourceType": "User",
"created": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": db_user["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"location": f"Users/{db_user['userId']}",
},
"userName": db_user["email"],
"externalId": db_user["internalId"],
"name": {
"formatted": db_user["name"],
},
"displayName": db_user["name"] or db_user["email"],
"userType": db_user.get("roleName"),
"active": db_user["deletedAt"] is None,
}
def _serialize_db_group_to_scim_group(db_resource: dict[str, Any]) -> dict[str, Any]:
members = db_resource["users"] or []
return {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"id": str(db_resource["groupId"]),
"externalId": db_resource["externalId"],
"meta": {
"resourceType": "Group",
"created": db_resource["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": db_resource["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"location": f"Groups/{db_resource['groupId']}",
},
"displayName": db_resource["name"],
"members": [
{
"value": str(member["userId"]),
"$ref": f"Users/{member['userId']}",
"type": "User",
}
for member in members
],
}
def _parse_scim_group_input(data: dict[str, Any], tenant_id: int) -> dict[str, Any]:
return {
"name": data["displayName"],
"external_id": data.get("externalId"),
"user_ids": [int(member["value"]) for member in data.get("members", [])],
}
def _parse_scim_group_patch(data: dict[str, Any], tenant_id: int) -> dict[str, Any]:
result = {}
if "displayName" in data:
result["name"] = data["displayName"]
if "externalId" in data:
result["external_id"] = data["externalId"]
if "members" in data:
members = data["members"] or []
result["user_ids"] = [int(member["value"]) for member in members]
return result
RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"Users": {
"max_items_per_page": 10,
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:User",
"db_to_scim_serializer": _serialize_db_user_to_scim_user,
"count_total_resources": users.count_total_scim_users,
"get_paginated_resources": users.get_scim_users_paginated,
"get_unique_resource": users.get_scim_user_by_id,
"parse_post_payload": _parse_scim_user_input,
"get_resource_by_unique_values": users.get_existing_scim_user_by_unique_values_from_all_users,
"restore_resource": users.restore_scim_user,
"create_resource": users.create_scim_user,
"delete_resource": users.soft_delete_scim_user_by_id,
"parse_put_payload": _parse_scim_user_input,
"update_resource": users.update_scim_user,
"parse_patch_payload": _parse_user_patch_payload,
"patch_resource": users.patch_scim_user,
},
"Groups": {
"max_items_per_page": 10,
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group",
"db_to_scim_serializer": _serialize_db_group_to_scim_group,
"count_total_resources": scim_groups.count_total_resources,
"get_paginated_resources": scim_groups.get_resources_paginated,
"get_unique_resource": scim_groups.get_resource_by_id,
"parse_post_payload": _parse_scim_group_input,
"get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources,
# note(jon): we're not soft deleting groups, so we don't need this
"restore_resource": None,
"create_resource": scim_groups.create_resource,
"delete_resource": scim_groups.delete_resource,
"parse_put_payload": _parse_scim_group_input,
"update_resource": scim_groups.update_resource,
"parse_patch_payload": _parse_scim_group_patch,
"patch_resource": scim_groups.patch_resource,
},
}
class ListResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.get("/{resource_type}")
async def get_resources(
resource_type: ListResourceType,
tenant_id=Depends(auth_required),
requested_start_index: int = Query(1, alias="startIndex"),
requested_items_per_page: int | None = Query(None, alias="count"),
attributes: str | None = Query(None),
excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
start_index = max(1, requested_start_index)
max_items_per_page = resource_config["max_items_per_page"]
items_per_page = min(
max(0, requested_items_per_page or max_items_per_page), max_items_per_page
)
total_resources = resource_config["count_total_resources"](tenant_id)
db_resources = resource_config["get_paginated_resources"](
start_index, tenant_id, items_per_page
)
attributes = scim_helpers.convert_query_str_to_list(attributes)
excluded_attributes = scim_helpers.convert_query_str_to_list(excluded_attributes)
scim_resources = [
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
attributes,
excluded_attributes,
)
for db_resource in db_resources
]
return JSONResponse(
status_code=200,
content={
"totalResults": total_resources,
"startIndex": start_index,
"itemsPerPage": len(scim_resources),
"Resources": scim_resources,
},
)
class GetResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.get("/{resource_type}/{resource_id}")
async def get_resource(
resource_type: GetResourceType,
resource_id: int,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
db_resource = resource_config["get_unique_resource"](resource_id, tenant_id)
if not db_resource:
return _not_found_error_response(resource_id)
scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
attributes,
excluded_attributes,
)
return JSONResponse(status_code=200, content=scim_resource)
class PostResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.post("/{resource_type}")
async def create_resource(
resource_type: PostResourceType,
r: Request,
tenant_id=Depends(auth_required),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
scim_payload = await r.json()
try:
db_payload = resource_config["parse_post_payload"](scim_payload, tenant_id)
except KeyError:
return _invalid_value_error_response()
existing_db_resource = resource_config["get_resource_by_unique_values"](
**db_payload
)
if existing_db_resource and existing_db_resource.get("deletedAt") is None:
return _uniqueness_error_response()
if existing_db_resource and existing_db_resource.get("deletedAt") is not None:
db_resource = resource_config["restore_resource"](
tenant_id=tenant_id, **db_payload
)
else:
db_resource = resource_config["create_resource"](
tenant_id=tenant_id,
**db_payload,
)
scim_resource = _serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
)
response = JSONResponse(status_code=201, content=scim_resource)
response.headers["Location"] = scim_resource["meta"]["location"]
return response
class DeleteResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.delete("/{resource_type}/{resource_id}")
async def delete_resource(
resource_type: DeleteResourceType,
resource_id: str,
tenant_id=Depends(auth_required),
):
# note(jon): this can be a soft or a hard delete
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
db_resource = resource_config["get_unique_resource"](resource_id, tenant_id)
if not db_resource:
return _not_found_error_response(resource_id)
resource_config["delete_resource"](resource_id, tenant_id)
return Response(status_code=204, content="")
class PutResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.put("/{resource_type}/{resource_id}")
async def put_resource(
resource_type: PutResourceType,
resource_id: str,
r: Request,
tenant_id=Depends(auth_required),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
db_resource = resource_config["get_unique_resource"](resource_id, tenant_id)
if not db_resource:
return _not_found_error_response(resource_id)
current_scim_resource = (
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
)
)
requested_scim_changes = await r.json()
schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]]
try:
valid_mutable_scim_changes = scim_helpers.filter_mutable_attributes(
schema, requested_scim_changes, current_scim_resource
)
except ValueError:
return _mutability_error_response()
valid_mutable_db_changes = resource_config["parse_put_payload"](
valid_mutable_scim_changes,
tenant_id,
)
try:
updated_db_resource = resource_config["update_resource"](
resource_id,
tenant_id,
**valid_mutable_db_changes,
)
updated_scim_resource = (
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
updated_db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
)
)
return JSONResponse(status_code=200, content=updated_scim_resource)
except errors.UniqueViolation:
return _uniqueness_error_response()
except Exception as e:
return _internal_server_error_response(str(e))
class PatchResourceType(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.patch("/{resource_type}/{resource_id}")
async def patch_resource(
resource_type: PatchResourceType,
resource_id: str,
r: Request,
tenant_id=Depends(auth_required),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
db_resource = resource_config["get_unique_resource"](resource_id, tenant_id)
if not db_resource:
return _not_found_error_response(resource_id)
current_scim_resource = (
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
)
)
payload = await r.json()
_, changes = scim_helpers.apply_scim_patch(
payload["Operations"],
current_scim_resource,
SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]],
)
reformatted_scim_changes = {
k: new_value for k, (old_value, new_value) in changes.items()
}
db_changes = resource_config["parse_patch_payload"](
reformatted_scim_changes,
tenant_id,
)
updated_db_resource = resource_config["patch_resource"](
resource_id,
tenant_id,
**db_changes,
)
updated_scim_resource = (
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
updated_db_resource,
resource_config["schema_id"],
resource_config["db_to_scim_serializer"],
)
)
return JSONResponse(status_code=200, content=updated_scim_resource)

View file

461
ee/api/routers/scim/api.py Normal file
View file

@ -0,0 +1,461 @@
import logging
from copy import deepcopy
from enum import Enum
from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from psycopg2 import errors
from chalicelib.core import roles, tenants
from chalicelib.utils.scim_auth import (
auth_optional,
auth_required,
create_tokens,
verify_refresh_token,
)
from routers.base import get_routers
from routers.scim.constants import (
SERVICE_PROVIDER_CONFIG,
RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS,
SCHEMA_IDS_TO_SCHEMA_DETAILS,
)
from routers.scim import helpers, groups, users
from routers.scim.resource_config import ResourceConfig
from routers.scim import resource_config as api_helper
logger = logging.getLogger(__name__)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.post("/token")
async def post_token(
host: str = Header(..., alias="Host"),
form_data: OAuth2PasswordRequestForm = Depends(),
):
subdomain = host.split(".")[0]
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config(
"SCIM_PASSWORD"
):
raise HTTPException(status_code=401, detail="Invalid credentials")
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "Bearer",
}
class RefreshRequest(BaseModel):
refresh_token: str
@public_app.post("/refresh")
async def post_refresh(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
return {"access_token": new_access_token, "token_type": "Bearer"}
def _not_found_error_response(resource_id: int):
return JSONResponse(
status_code=404,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": f"Resource {resource_id} not found",
"status": "404",
},
)
def _uniqueness_error_response():
return JSONResponse(
status_code=409,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "One or more of the attribute values are already in use or are reserved.",
"status": "409",
"scimType": "uniqueness",
},
)
def _mutability_error_response():
return JSONResponse(
status_code=400,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "The attempted modification is not compatible with the target attribute's mutability or current state.",
"status": "400",
"scimType": "mutability",
},
)
def _operation_not_permitted_error_response():
return JSONResponse(
status_code=403,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "Operation is not permitted based on the supplied authorization",
"status": "403",
},
)
def _invalid_value_error_response():
return JSONResponse(
status_code=400,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "A required value was missing, or the value specified was not compatible with the operation or attribtue type, or resource schema.",
"status": "400",
"scimType": "invalidValue",
},
)
def _internal_server_error_response(detail: str):
return JSONResponse(
status_code=500,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": detail,
"status": "500",
},
)
# note(jon): it was recommended to make this endpoint partially open
# so that clients can view the `authenticationSchemes` prior to being authenticated.
@public_app.get("/ServiceProviderConfig")
async def get_service_provider_config(
r: Request, tenant_id: str | None = Depends(auth_optional)
):
is_authenticated = tenant_id is not None
if not is_authenticated:
return JSONResponse(
status_code=200,
content={
"schemas": SERVICE_PROVIDER_CONFIG["schemas"],
"authenticationSchemes": SERVICE_PROVIDER_CONFIG[
"authenticationSchemes"
],
"meta": SERVICE_PROVIDER_CONFIG["meta"],
},
)
return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG)
@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)])
async def get_resource_types(filter_param: str | None = Query(None, alias="filter")):
if filter_param is not None:
return _operation_not_permitted_error_response()
return JSONResponse(
status_code=200,
content={
"totalResults": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
"itemsPerPage": len(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS),
"startIndex": 1,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"Resources": list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values()),
},
)
@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)])
async def get_resource_type(resource_id: str):
if resource_id not in RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS:
return _not_found_error_response(resource_id)
return JSONResponse(
status_code=200,
content=RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id],
)
@public_app.get("/Schemas", dependencies=[Depends(auth_required)])
async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
if filter_param is not None:
return _operation_not_permitted_error_response()
return JSONResponse(
status_code=200,
content={
"totalResults": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
"itemsPerPage": len(SCHEMA_IDS_TO_SCHEMA_DETAILS),
"startIndex": 1,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"Resources": [
value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
],
},
)
@public_app.get("/Schemas/{schema_id}")
async def get_schema(schema_id: str, tenant_id=Depends(auth_required)):
if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS:
return _not_found_error_response(schema_id)
schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id])
if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User":
db_roles = roles.get_roles(tenant_id)
role_names = [role["name"] for role in db_roles]
user_type_attribute = next(
filter(lambda x: x["name"] == "userType", schema["attributes"])
)
user_type_attribute["canonicalValues"] = role_names
return JSONResponse(
status_code=200,
content=schema,
)
user_config = ResourceConfig(
schema_id="urn:ietf:params:scim:schemas:core:2.0:User",
max_chunk_size=10,
get_active_resource_count=users.get_active_resource_count,
convert_provider_resource_to_client_resource=users.convert_provider_resource_to_client_resource,
get_provider_resource_chunk=users.get_provider_resource_chunk,
get_provider_resource=users.get_provider_resource,
convert_client_resource_creation_input_to_provider_resource_creation_input=users.convert_client_resource_creation_input_to_provider_resource_creation_input,
get_provider_resource_from_unique_fields=users.get_provider_resource_from_unique_fields,
restore_provider_resource=users.restore_provider_resource,
create_provider_resource=users.create_provider_resource,
delete_provider_resource=users.delete_provider_resource,
convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=users.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
rewrite_provider_resource=users.rewrite_provider_resource,
convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input,
update_provider_resource=users.update_provider_resource,
)
group_config = ResourceConfig(
schema_id="urn:ietf:params:scim:schemas:core:2.0:Group",
max_chunk_size=10,
get_active_resource_count=groups.get_active_resource_count,
convert_provider_resource_to_client_resource=groups.convert_provider_resource_to_client_resource,
get_provider_resource_chunk=groups.get_provider_resource_chunk,
get_provider_resource=groups.get_provider_resource,
convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input,
get_provider_resource_from_unique_fields=groups.get_provider_resource_from_unique_fields,
restore_provider_resource=None,
create_provider_resource=groups.create_provider_resource,
delete_provider_resource=groups.delete_provider_resource,
convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=groups.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
rewrite_provider_resource=groups.rewrite_provider_resource,
convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input,
update_provider_resource=groups.update_provider_resource,
)
RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = {
"Users": user_config,
"Groups": group_config,
}
class SCIMResource(str, Enum):
USERS = "Users"
GROUPS = "Groups"
@public_app.get("/{resource_type}")
async def get_resources(
resource_type: SCIMResource,
tenant_id=Depends(auth_required),
requested_start_index_one_indexed: int = Query(1, alias="startIndex"),
requested_items_per_page: int | None = Query(None, alias="count"),
attributes: str | None = Query(None),
excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
):
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
total_resources = config.get_active_resource_count(tenant_id)
start_index_one_indexed = max(1, requested_start_index_one_indexed)
offset = start_index_one_indexed - 1
limit = min(
max(0, requested_items_per_page or config.max_chunk_size), config.max_chunk_size
)
provider_resources = config.get_provider_resource_chunk(offset, tenant_id, limit)
client_resources = [
api_helper.convert_provider_resource_to_client_resource(
config, provider_resource, attributes, excluded_attributes
)
for provider_resource in provider_resources
]
return JSONResponse(
status_code=200,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"totalResults": total_resources,
"startIndex": start_index_one_indexed,
"itemsPerPage": len(client_resources),
"Resources": client_resources,
},
)
@public_app.get("/{resource_type}/{resource_id}")
async def get_resource(
resource_type: SCIMResource,
resource_id: int | str,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
resource = api_helper.get_resource(
resource_config,
resource_id,
tenant_id,
attributes,
excluded_attributes,
)
if not resource:
return _not_found_error_response(resource_id)
return JSONResponse(status_code=200, content=resource)
@public_app.post("/{resource_type}")
async def create_resource(
resource_type: SCIMResource,
r: Request,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
payload = await r.json()
try:
provider_resource_input = config.convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id,
payload,
)
except KeyError:
return _invalid_value_error_response()
existing_provider_resource = config.get_provider_resource_from_unique_fields(
**provider_resource_input
)
if (
existing_provider_resource
and existing_provider_resource.get("deleted_at") is None
):
return _uniqueness_error_response()
if (
existing_provider_resource
and existing_provider_resource.get("deleted_at") is not None
):
provider_resource = config.restore_provider_resource(
tenant_id=tenant_id, **provider_resource_input
)
else:
provider_resource = config.create_provider_resource(
tenant_id=tenant_id, **provider_resource_input
)
client_resource = api_helper.convert_provider_resource_to_client_resource(
config, provider_resource, attributes, excluded_attributes
)
response = JSONResponse(status_code=201, content=client_resource)
response.headers["Location"] = client_resource["meta"]["location"]
return response
@public_app.delete("/{resource_type}/{resource_id}")
async def delete_resource(
resource_type: SCIMResource,
resource_id: str,
tenant_id=Depends(auth_required),
):
# note(jon): this can be a soft or a hard delete
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
resource = api_helper.get_resource(config, resource_id, tenant_id)
if not resource:
return _not_found_error_response(resource_id)
config.delete_provider_resource(resource_id, tenant_id)
return Response(status_code=204, content="")
@public_app.put("/{resource_type}/{resource_id}")
async def put_resource(
resource_type: SCIMResource,
resource_id: str,
r: Request,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
client_resource = api_helper.get_resource(config, resource_id, tenant_id)
if not client_resource:
return _not_found_error_response(resource_id)
schema = api_helper.get_schema(config)
payload = await r.json()
try:
client_resource_input = helpers.filter_mutable_attributes(
schema, payload, client_resource
)
except ValueError:
return _mutability_error_response()
provider_resource_input = (
config.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
tenant_id, client_resource_input
)
)
try:
provider_resource = config.rewrite_provider_resource(
resource_id,
tenant_id,
**provider_resource_input,
)
except errors.UniqueViolation:
return _uniqueness_error_response()
except Exception as e:
return _internal_server_error_response(str(e))
client_resource = api_helper.convert_provider_resource_to_client_resource(
config, provider_resource, attributes, excluded_attributes
)
return JSONResponse(status_code=200, content=client_resource)
@public_app.patch("/{resource_type}/{resource_id}")
async def patch_resource(
resource_type: SCIMResource,
resource_id: str,
r: Request,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
client_resource = api_helper.get_resource(config, resource_id, tenant_id)
if not client_resource:
return _not_found_error_response(resource_id)
schema = api_helper.get_schema(config)
payload = await r.json()
_, changes = helpers.apply_scim_patch(
payload["Operations"], client_resource, schema
)
client_resource_input = {
k: new_value for k, (old_value, new_value) in changes.items()
}
provider_resource_input = (
config.convert_client_resource_update_input_to_provider_resource_update_input(
tenant_id, client_resource_input
)
)
try:
provider_resource = config.update_provider_resource(
resource_id, tenant_id, **provider_resource_input
)
except errors.UniqueViolation:
return _uniqueness_error_response()
except Exception as e:
return _internal_server_error_response(str(e))
client_resource = api_helper.convert_provider_resource_to_client_resource(
config, provider_resource, attributes, excluded_attributes
)
return JSONResponse(status_code=200, content=client_resource)

View file

@ -0,0 +1,33 @@
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
import json
SCHEMAS = sorted(
[
json.load(
open("routers/scim/fixtures/service_provider_config_schema.json", "r")
),
json.load(open("routers/scim/fixtures/resource_type_schema.json", "r")),
json.load(open("routers/scim/fixtures/schema_schema.json", "r")),
json.load(open("routers/scim/fixtures/user_schema.json", "r")),
json.load(open("routers/scim/fixtures/group_schema.json", "r")),
],
key=lambda x: x["id"],
)
SCHEMA_IDS_TO_SCHEMA_DETAILS = {
schema_detail["id"]: schema_detail for schema_detail in SCHEMAS
}
SERVICE_PROVIDER_CONFIG = json.load(
open("routers/scim/fixtures/service_provider_config.json", "r")
)
RESOURCE_TYPES = sorted(
json.load(open("routers/scim/fixtures/resource_type.json", "r")),
key=lambda x: x["id"],
)
RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS = {
resource_type_detail["id"]: resource_type_detail
for resource_type_detail in RESOURCE_TYPES
}

View file

@ -2,6 +2,7 @@
"id": "urn:ietf:params:scim:schemas:core:2.0:Group",
"name": "Group",
"description": "Group",
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"attributes": [
{
"name": "schemas",
@ -159,7 +160,7 @@
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:Group",
"created": "2025-04-17T15:48:00",
"lastModified": "2025-04-17T15:48:00"
"created": "2025-04-17T15:48:00Z",
"lastModified": "2025-04-17T15:48:00Z"
}
}

View file

@ -2,7 +2,19 @@
"id": "urn:ietf:params:scim:schemas:core:2.0:ResourceType",
"name": "ResourceType",
"description": "Specifies the schema that describes a SCIM Resource Type",
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"attributes": [
{
"name": "schemas",
"type": "string",
"multiValued": true,
"description": "An array of Strings containing URI that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.",
"required": true,
"caseExact": false,
"mutability": "immutable",
"returned": "always",
"uniqueness": "none"
},
{
"name": "id",
"type": "string",
@ -14,6 +26,79 @@
"returned": "default",
"uniqueness": "none"
},
{
"name": "externalId",
"type": "string",
"multiValued": false,
"description": "Identifier for the resource as defined by the provisioning client. OPTIONAL; clients MAY include a non-empty value.",
"required": false,
"caseExact": true,
"mutability": "readWrite",
"returned": "default",
"uniqueness": "none"
},
{
"name": "meta",
"type": "complex",
"multiValued": false,
"description": "Resource metadata. MUST be ignored when provided by clients.",
"required": false,
"mutability": "readOnly",
"returned": "default",
"subAttributes": [
{
"name": "resourceType",
"type": "string",
"multiValued": false,
"description": "The resource type name.",
"required": false,
"caseExact": true,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
},
{
"name": "created",
"type": "dateTime",
"multiValued": false,
"description": "The date and time the resource was added.",
"required": false,
"mutability": "readOnly",
"returned": "default"
},
{
"name": "lastModified",
"type": "dateTime",
"multiValued": false,
"description": "The most recent date and time the resource was modified.",
"required": false,
"mutability": "readOnly",
"returned": "default"
},
{
"name": "location",
"type": "reference",
"referenceTypes": ["external"],
"multiValued": false,
"description": "The URI of the resource being returned.",
"required": false,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
},
{
"name": "version",
"type": "string",
"multiValued": false,
"description": "The version (ETag) of the resource being returned.",
"required": false,
"caseExact": true,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
}
]
},
{
"name": "name",
"type": "string",
@ -96,7 +181,7 @@
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:ResourceType",
"created": "2025-04-17T15:48:00",
"lastModified": "2025-04-17T15:48:00"
"created": "2025-04-17T15:48:00Z",
"lastModified": "2025-04-17T15:48:00Z"
}
}

View file

@ -2,7 +2,19 @@
"id": "urn:ietf:params:scim:schemas:core:2.0:Schema",
"name": "Schema",
"description": "Specifies the schema that describes a SCIM Schema",
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"attributes": [
{
"name": "schemas",
"type": "string",
"multiValued": true,
"description": "An array of Strings containing URI that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.",
"required": true,
"caseExact": false,
"mutability": "immutable",
"returned": "always",
"uniqueness": "none"
},
{
"name": "id",
"type": "string",
@ -14,6 +26,79 @@
"returned": "default",
"uniqueness": "none"
},
{
"name": "externalId",
"type": "string",
"multiValued": false,
"description": "Identifier for the resource as defined by the provisioning client. OPTIONAL; clients MAY include a non-empty value.",
"required": false,
"caseExact": true,
"mutability": "readWrite",
"returned": "default",
"uniqueness": "none"
},
{
"name": "meta",
"type": "complex",
"multiValued": false,
"description": "Resource metadata. MUST be ignored when provided by clients.",
"required": false,
"mutability": "readOnly",
"returned": "default",
"subAttributes": [
{
"name": "resourceType",
"type": "string",
"multiValued": false,
"description": "The resource type name.",
"required": false,
"caseExact": true,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
},
{
"name": "created",
"type": "dateTime",
"multiValued": false,
"description": "The date and time the resource was added.",
"required": false,
"mutability": "readOnly",
"returned": "default"
},
{
"name": "lastModified",
"type": "dateTime",
"multiValued": false,
"description": "The most recent date and time the resource was modified.",
"required": false,
"mutability": "readOnly",
"returned": "default"
},
{
"name": "location",
"type": "reference",
"referenceTypes": ["external"],
"multiValued": false,
"description": "The URI of the resource being returned.",
"required": false,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
},
{
"name": "version",
"type": "string",
"multiValued": false,
"description": "The version (ETag) of the resource being returned.",
"required": false,
"caseExact": true,
"mutability": "readOnly",
"returned": "default",
"uniqueness": "none"
}
]
},
{
"name": "name",
"type": "string",
@ -172,7 +257,7 @@
"required": false,
"mutability": "readOnly",
"returned": "default",
"subAttributes": [
"subAttribtes": [
{
"name": "name",
"type": "string",
@ -298,7 +383,7 @@
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:Schema",
"created": "2025-04-17T15:48:00",
"lastModified": "2025-04-17T15:48:00"
"created": "2025-04-17T15:48:00Z",
"lastModified": "2025-04-17T15:48:00Z"
}
}

View file

@ -2,6 +2,7 @@
"id": "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
"name": "Service Provider Configuration",
"description": "Schema for representing the service provider's configuration",
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"attributes": [
{
"name": "documentationUri",
@ -206,7 +207,7 @@
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
"created": "2025-04-17T15:48:00",
"lastModified": "2025-04-17T15:48:00"
"created": "2025-04-17T15:48:00Z",
"lastModified": "2025-04-17T15:48:00Z"
}
}

View file

@ -2,6 +2,7 @@
"id": "urn:ietf:params:scim:schemas:core:2.0:User",
"name": "User",
"description": "User Account",
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"attributes": [
{
"name": "schemas",
@ -380,7 +381,7 @@
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:core:2.0:User",
"created": "2025-04-17T15:48:00",
"lastModified": "2025-04-17T15:48:00"
"created": "2025-04-17T15:48:00Z",
"lastModified": "2025-04-17T15:48:00Z"
}
}

View file

@ -2,10 +2,59 @@ from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import helper, pg_client
from chalicelib.utils import pg_client
from routers.scim.resource_config import (
ProviderResource,
ClientResource,
ResourceId,
ClientInput,
ProviderInput,
)
def count_total_resources(tenant_id: int) -> int:
def convert_client_resource_update_input_to_provider_resource_update_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
result = {}
if "displayName" in client_input:
result["name"] = client_input["displayName"]
if "externalId" in client_input:
result["external_id"] = client_input["externalId"]
if "members" in client_input:
members = client_input["members"] or []
result["user_ids"] = [int(member["value"]) for member in members]
return result
def convert_provider_resource_to_client_resource(
provider_resource: ProviderResource,
) -> ClientResource:
members = provider_resource["users"] or []
return {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"id": str(provider_resource["group_id"]),
"externalId": provider_resource["external_id"],
"meta": {
"resourceType": "Group",
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": provider_resource["updated_at"].strftime(
"%Y-%m-%dT%H:%M:%SZ"
),
"location": f"Groups/{provider_resource['group_id']}",
},
"displayName": provider_resource["name"],
"members": [
{
"value": str(member["user_id"]),
"$ref": f"Users/{member['user_id']}",
"type": "User",
}
for member in members
],
}
def get_active_resource_count(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
@ -20,9 +69,9 @@ def count_total_resources(tenant_id: int) -> int:
return cur.fetchone()["count"]
def get_resources_paginated(
offset_one_indexed: int, tenant_id: int, limit: int | None = None
) -> list[dict[str, Any]]:
def get_provider_resource_chunk(
offset: int, tenant_id: int, limit: int
) -> list[ProviderResource]:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
@ -41,16 +90,18 @@ def get_resources_paginated(
OFFSET %(offset)s;
""",
{
"offset": offset_one_indexed - 1,
"offset": offset,
"limit": limit,
"tenant_id": tenant_id,
},
)
)
return helper.list_to_camel_case(cur.fetchall())
return cur.fetchall()
def get_resource_by_id(group_id: int, tenant_id: int) -> dict[str, Any]:
def get_provider_resource(
resource_id: ResourceId, tenant_id: int
) -> ProviderResource | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
@ -69,26 +120,50 @@ def get_resource_by_id(group_id: int, tenant_id: int) -> dict[str, Any]:
AND groups.group_id = %(group_id)s
LIMIT 1;
""",
{"group_id": group_id, "tenant_id": tenant_id},
{"group_id": resource_id, "tenant_id": tenant_id},
)
)
return helper.dict_to_camel_case(cur.fetchone())
return cur.fetchone()
def get_existing_resource_by_unique_values_from_all_resources(
**kwargs,
) -> dict[str, Any] | None:
def get_provider_resource_from_unique_fields(
**kwargs: dict[str, Any],
) -> ProviderResource | None:
# note(jon): we do not really use this for groups as we don't have unique values outside
# of the primary key
return None
def create_resource(
def convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
return {
"name": client_input["displayName"],
"external_id": client_input.get("externalId"),
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
}
def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
return {
"name": client_input["displayName"],
"external_id": client_input.get("externalId"),
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
}
def create_provider_resource(
name: str,
tenant_id: int,
user_ids: list[str] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
) -> ProviderResource:
with pg_client.PostgresClient() as cur:
kwargs["name"] = name
kwargs["tenant_id"] = tenant_id
@ -136,10 +211,10 @@ def create_resource(
LIMIT 1;
"""
)
return helper.dict_to_camel_case(cur.fetchone())
return cur.fetchone()
def delete_resource(group_id: int, tenant_id: int) -> None:
def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
@ -148,7 +223,7 @@ def delete_resource(group_id: int, tenant_id: int) -> None:
WHERE groups.group_id = %(group_id)s AND groups.tenant_id = %(tenant_id)s;
"""
),
{"tenant_id": tenant_id, "group_id": group_id},
{"tenant_id": tenant_id, "group_id": resource_id},
)
@ -214,30 +289,30 @@ def _update_resource_sql(
LIMIT 1;
"""
)
return helper.dict_to_camel_case(cur.fetchone())
return cur.fetchone()
def update_resource(
group_id: int,
def rewrite_provider_resource(
resource_id: int,
tenant_id: int,
name: str,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return _update_resource_sql(
group_id=group_id,
group_id=resource_id,
tenant_id=tenant_id,
name=name,
**kwargs,
)
def patch_resource(
group_id: int,
def update_provider_resource(
resource_id: int,
tenant_id: int,
**kwargs: dict[str, Any],
):
return _update_resource_sql(
group_id=group_id,
group_id=resource_id,
tenant_id=tenant_id,
**kwargs,
)

View file

@ -45,83 +45,50 @@ def get_all_attribute_names_where_returned_is_always(
def filter_attributes(
resource: dict[str, Any], include_list: list[str]
obj: dict[str, Any],
attributes_query_str: str | None,
excluded_attributes_query_str: str | None,
schema: dict[str, Any],
) -> dict[str, Any]:
result = {}
all_attributes = get_all_attribute_names(schema)
always_returned_attributes = get_all_attribute_names_where_returned_is_always(
schema
)
included_attributes = convert_query_str_to_list(attributes_query_str)
included_attributes = included_attributes or all_attributes
included_attributes_set = set(included_attributes).union(
set(always_returned_attributes)
)
excluded_attributes = convert_query_str_to_list(excluded_attributes_query_str)
excluded_attributes = excluded_attributes or []
excluded_attributes_set = set(excluded_attributes).difference(
set(always_returned_attributes)
)
include_paths = included_attributes_set.difference(excluded_attributes_set)
# Group include paths by top-level key
includes_by_key = {}
for path in include_list:
parts = path.split(".", 1)
key = parts[0]
rest = parts[1] if len(parts) == 2 else None
includes_by_key.setdefault(key, []).append(rest)
include_tree = {}
for path in include_paths:
parts = path.split(".")
node = include_tree
for part in parts:
node = node.setdefault(part, {})
for key, subpaths in includes_by_key.items():
if key not in resource:
continue
def _recurse(o, tree, parent_key=None):
if isinstance(o, dict):
out = {}
for key, subtree in tree.items():
if key in o:
out[key] = _recurse(o[key], subtree, key)
return out
if isinstance(o, list):
out = [_recurse(item, tree, parent_key) for item in o]
return out
return o
value = resource[key]
if all(p is None for p in subpaths):
result[key] = value
else:
nested_paths = [p for p in subpaths if p is not None]
if isinstance(value, dict):
filtered = filter_attributes(value, nested_paths)
if filtered:
result[key] = filtered
elif isinstance(value, list):
new_list = []
for item in value:
if isinstance(item, dict):
filtered_item = filter_attributes(item, nested_paths)
if filtered_item:
new_list.append(filtered_item)
if new_list:
result[key] = new_list
result = _recurse(obj, include_tree)
return result
def exclude_attributes(
resource: dict[str, Any], exclude_list: list[str]
) -> dict[str, Any]:
exclude_map = {}
for attr in exclude_list:
parts = attr.split(".", 1)
key = parts[0]
# rest is empty string for top-level exclusion
rest = parts[1] if len(parts) == 2 else ""
exclude_map.setdefault(key, []).append(rest)
new_resource = {}
for key, value in resource.items():
if key in exclude_map:
subs = exclude_map[key]
# If any attr has no rest, exclude entire key
if "" in subs:
continue
# Exclude nested attributes
if isinstance(value, dict):
new_sub = exclude_attributes(value, subs)
if not new_sub:
continue
new_resource[key] = new_sub
elif isinstance(value, list):
new_list = []
for item in value:
# note(jon): `item` should always be a dict here
new_item = exclude_attributes(item, subs)
new_list.append(new_item)
new_resource[key] = new_list
else:
# No exclusion for this key: copy safely
if isinstance(value, (dict, list)):
new_resource[key] = deepcopy(value)
else:
new_resource[key] = value
return new_resource
def filter_mutable_attributes(
schema: dict[str, Any],
requested_changes: dict[str, Any],

View file

@ -0,0 +1,78 @@
from dataclasses import dataclass
from typing import Any, Callable
from routers.scim.constants import (
SCHEMA_IDS_TO_SCHEMA_DETAILS,
)
from routers.scim import helpers
Schema = dict[str, Any]
ProviderResource = dict[str, Any]
ClientResource = dict[str, Any]
ResourceId = int | str
ClientInput = dict[str, Any]
ProviderInput = dict[str, Any]
@dataclass
class ResourceConfig:
schema_id: str
max_chunk_size: int
get_active_resource_count: Callable[[int], int]
convert_provider_resource_to_client_resource: Callable[
[ProviderResource], ClientResource
]
get_provider_resource_chunk: Callable[[int, int, int], list[ProviderResource]]
get_provider_resource: Callable[[ResourceId, int], ProviderResource | None]
convert_client_resource_creation_input_to_provider_resource_creation_input: (
Callable[[int, ClientInput], ProviderInput]
)
get_provider_resource_from_unique_fields: Callable[..., ProviderResource | None]
restore_provider_resource: Callable[..., ProviderResource] | None
create_provider_resource: Callable[..., ProviderResource]
delete_provider_resource: Callable[[ResourceId, int], None]
convert_client_resource_rewrite_input_to_provider_resource_rewrite_input: Callable[
[int, ClientInput], ProviderInput
]
rewrite_provider_resource: Callable[..., ProviderResource]
convert_client_resource_update_input_to_provider_resource_update_input: Callable[
[int, ClientInput], ProviderInput
]
update_provider_resource: Callable[..., ProviderResource]
def get_schema(config: ResourceConfig) -> Schema:
return SCHEMA_IDS_TO_SCHEMA_DETAILS[config.schema_id]
def convert_provider_resource_to_client_resource(
config: ResourceConfig,
provider_resource: ProviderResource,
attributes_query_str: str | None,
excluded_attributes_query_str: str | None,
) -> ClientResource:
client_resource = config.convert_provider_resource_to_client_resource(
provider_resource
)
schema = get_schema(config)
client_resource = helpers.filter_attributes(
client_resource, attributes_query_str, excluded_attributes_query_str, schema
)
return client_resource
def get_resource(
config: ResourceConfig,
resource_id: ResourceId,
tenant_id: int,
attributes: str | None = None,
excluded_attributes: str | None = None,
) -> ClientResource | None:
provider_resource = config.get_provider_resource(resource_id, tenant_id)
if provider_resource is None:
return None
client_resource = convert_provider_resource_to_client_resource(
config, provider_resource, attributes, excluded_attributes
)
return client_resource

View file

@ -0,0 +1,395 @@
from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import pg_client
from chalicelib.core import roles
from routers.scim.resource_config import (
ProviderResource,
ClientResource,
ResourceId,
ClientInput,
ProviderInput,
)
def convert_client_resource_update_input_to_provider_resource_update_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
result = {}
if "userType" in client_input:
role = roles.get_role_by_name(tenant_id, client_input["userType"])
result["role_id"] = role["roleId"] if role else None
if "name" in client_input:
# note(jon): we're currently not handling the case where the client
# send patches of individual name components (e.g. name.middleName)
name = client_input.get("name", {}).get("formatted")
if name:
result["name"] = name
if "userName" in client_input:
result["email"] = client_input["userName"]
if "externalId" in client_input:
result["internal_id"] = client_input["externalId"]
if "active" in client_input:
result["deleted_at"] = None if client_input["active"] else datetime.now()
return result
def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
role_id = None
if "userType" in client_input:
role = roles.get_role_by_name(tenant_id, client_input["userType"])
role_id = role["roleId"] if role else None
name = client_input.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
for x in [
client_input.get("name", {}).get("honorificPrefix"),
client_input.get("name", {}).get("givenName"),
client_input.get("name", {}).get("middleName"),
client_input.get("name", {}).get("familyName"),
client_input.get("name", {}).get("honorificSuffix"),
]
if x
]
)
result = {
"email": client_input["userName"],
"internal_id": client_input.get("externalId"),
"name": name,
"role_id": role_id,
}
result = {k: v for k, v in result.items() if v is not None}
return result
def convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
role_id = None
if "userType" in client_input:
role = roles.get_role_by_name(tenant_id, client_input["userType"])
role_id = role["roleId"] if role else None
name = client_input.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
for x in [
client_input.get("name", {}).get("honorificPrefix"),
client_input.get("name", {}).get("givenName"),
client_input.get("name", {}).get("middleName"),
client_input.get("name", {}).get("familyName"),
client_input.get("name", {}).get("honorificSuffix"),
]
if x
]
)
result = {
"email": client_input["userName"],
"internal_id": client_input.get("externalId"),
"name": name,
"role_id": role_id,
}
result = {k: v for k, v in result.items() if v is not None}
return result
def get_provider_resource_from_unique_fields(
email: str, **kwargs: dict[str, Any]
) -> ProviderResource | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.users
WHERE users.email = %(email)s
""",
{"email": email},
)
)
return cur.fetchone()
def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
UPDATE public.users
SET
deleted_at = NULL,
updated_at = default
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
""",
{"user_id": resource_id, "tenant_id": tenant_id},
)
)
def convert_provider_resource_to_client_resource(
provider_resource: ProviderResource,
) -> ClientResource:
return {
"id": str(provider_resource["user_id"]),
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"meta": {
"resourceType": "User",
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": provider_resource["updated_at"].strftime(
"%Y-%m-%dT%H:%M:%SZ"
),
"location": f"Users/{provider_resource['user_id']}",
},
"userName": provider_resource["email"],
"externalId": provider_resource["internal_id"],
"name": {
"formatted": provider_resource["name"],
},
"displayName": provider_resource["name"] or provider_resource["email"],
"userType": provider_resource.get("role_name"),
"active": provider_resource["deleted_at"] is None,
}
def get_active_resource_count(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT COUNT(*)
FROM public.users
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
""",
{"tenant_id": tenant_id},
)
)
return cur.fetchone()["count"]
def get_provider_resource_chunk(
offset: int, tenant_id: int, limit: int
) -> list[ProviderResource]:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
users.*,
roles.name AS role_name
FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{"offset": offset, "limit": limit, "tenant_id": tenant_id},
)
)
return cur.fetchall()
def get_provider_resource(
resource_id: ResourceId, tenant_id: int
) -> ProviderResource | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
users.*,
roles.name AS role_name
FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT 1;
""",
{
"user_id": resource_id,
"tenant_id": tenant_id,
},
)
)
return cur.fetchone()
def create_provider_resource(
email: str,
tenant_id: int,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
) -> ProviderResource:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
INSERT INTO public.users (
tenant_id,
email,
name,
internal_id,
role_id
)
VALUES (
%(tenant_id)s,
%(email)s,
%(name)s,
%(internal_id)s,
%(role_id)s
)
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return cur.fetchone()
def restore_provider_resource(
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
**kwargs: dict[str, Any],
) -> ProviderResource:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s,
deleted_at = NULL,
created_at = now(),
updated_at = now(),
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.email = %(email)s
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return cur.fetchone()
def rewrite_provider_resource(
resource_id: int,
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s,
updated_at = now()
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"user_id": resource_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return cur.fetchone()
def update_provider_resource(
resource_id: int,
tenant_id: int,
**kwargs,
):
with pg_client.PostgresClient() as cur:
set_fragments = []
kwargs["updated_at"] = datetime.now()
for k, v in kwargs.items():
fragment = cur.mogrify(
"%s = %s",
(AsIs(k), v),
).decode("utf-8")
set_fragments.append(fragment)
set_clause = ", ".join(set_fragments)
query = f"""
WITH u AS (
UPDATE public.users
SET {set_clause}
WHERE
users.user_id = {resource_id}
AND users.tenant_id = {tenant_id}
AND users.deleted_at IS NULL
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);"""
cur.execute(query)
return cur.fetchone()

View file

@ -1,22 +0,0 @@
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
import json
SCHEMAS = sorted(
[
json.load(open("routers/fixtures/service_provider_config_schema.json", "r")),
json.load(open("routers/fixtures/resource_type_schema.json", "r")),
json.load(open("routers/fixtures/schema_schema.json", "r")),
json.load(open("routers/fixtures/user_schema.json", "r")),
json.load(open("routers/fixtures/group_schema.json", "r")),
],
key=lambda x: x["id"],
)
SERVICE_PROVIDER_CONFIG = json.load(
open("routers/fixtures/service_provider_config.json", "r")
)
RESOURCE_TYPES = sorted(
json.load(open("routers/fixtures/resource_type.json", "r")),
key=lambda x: x["id"],
)