reformat files and remove unnecessary imports

Jonathan Griffin 2025-04-18 10:37:31 +02:00
parent a8d36d40b5
commit 464b9b1b47
6 changed files with 588 additions and 300 deletions


@@ -9,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC
def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""SELECT EXISTS(SELECT 1
query = cur.mogrify(
f"""SELECT EXISTS(SELECT 1
FROM public.roles
WHERE tenant_id = %(tenant_id)s
AND name ILIKE %(name)s
AND deleted_at ISNULL
{"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id})
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id},
)
cur.execute(query=query)
row = cur.fetchone()
return row["exists"]
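A note on the pattern black is rewrapping here: psycopg2's cur.mogrify() returns the query with the named %(...)s placeholders already bound, and the surrounding f-string only decides whether the optional role_id filter is present; the exclude_id value itself still travels as a bound parameter. A minimal sketch of what that produces, assuming a plain psycopg2 connection instead of the project's pg_client wrapper (DSN and values are placeholders):

import psycopg2

# Placeholder DSN; the real code gets its cursor from pg_client.PostgresClient().
with psycopg2.connect("dbname=openreplay") as conn, conn.cursor() as cur:
    exclude_id = 7  # e.g. updating role 7, so its own name must not count as a duplicate
    sql = f"""SELECT EXISTS(SELECT 1
                            FROM public.roles
                            WHERE tenant_id = %(tenant_id)s
                              AND name ILIKE %(name)s
                              AND deleted_at ISNULL
                              {"AND role_id != %(exclude_id)s" if exclude_id else ""}) AS exists;"""
    # mogrify returns the final SQL (bytes) with the named parameters safely bound,
    # which is exactly what cur.execute(query=query) receives in the code above.
    print(cur.mogrify(sql, {"tenant_id": 1, "name": "Admin", "exclude_id": exclude_id}))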
@@ -27,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
if not admin["admin"] and not admin["superAdmin"]:
return {"errors": ["unauthorized"]}
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
)
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
return {"errors": ["must specify a project or all projects"]}
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
data.projects = projects.is_authorized_batch(
project_ids=data.projects, tenant_id=tenant_id
)
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.roles
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = TRUE
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is protected"]}
query = cur.mogrify("""UPDATE public.roles
query = cur.mogrify(
"""UPDATE public.roles
SET name= %(name)s,
description= %(description)s,
permissions= %(permissions)s,
@@ -56,23 +65,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
FROM roles_projects
WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()})
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()},
)
cur.execute(query=query)
row = cur.fetchone()
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
if not data.all_projects:
d_projects = [i for i in row["projects"] if i not in data.projects]
if len(d_projects) > 0:
query = cur.mogrify("""DELETE FROM roles_projects
query = cur.mogrify(
"""DELETE FROM roles_projects
WHERE role_id=%(role_id)s
AND project_id IN %(project_ids)s""",
{"role_id": role_id, "project_ids": tuple(d_projects)})
{"role_id": role_id, "project_ids": tuple(d_projects)},
)
cur.execute(query=query)
n_projects = [i for i in data.projects if i not in row["projects"]]
if len(n_projects) > 0:
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
query = cur.mogrify(
f"""INSERT INTO roles_projects(role_id, project_id)
VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}})
{
"role_id": role_id,
**{f"project_id_{i}": p for i, p in enumerate(n_projects)},
},
)
cur.execute(query=query)
row["projects"] = data.projects
@@ -86,28 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
return {"errors": ["unauthorized"]}
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
)
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
return {"errors": ["must specify a project or all projects"]}
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
data.projects = projects.is_authorized_batch(
project_ids=data.projects, tenant_id=tenant_id
)
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
query = cur.mogrify(
"""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
RETURNING *;""",
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
"permissions": data.permissions, "all_projects": data.all_projects})
{
"tenant_id": tenant_id,
"name": data.name,
"description": data.description,
"permissions": data.permissions,
"all_projects": data.all_projects,
},
)
cur.execute(query=query)
row = cur.fetchone()
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
row["projects"] = []
if not data.all_projects:
role_id = row["role_id"]
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
query = cur.mogrify(
f"""INSERT INTO roles_projects(role_id, project_id)
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
RETURNING project_id;""",
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
{
"role_id": role_id,
**{f"project_id_{i}": p for i, p in enumerate(data.projects)},
},
)
cur.execute(query=query)
row["projects"] = [r["project_id"] for r in cur.fetchall()]
return helper.dict_to_camel_case(row)
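One detail hiding in the INSERT above: permissions is handed over as a plain Python list, which psycopg2 adapts to a PostgreSQL array, and the explicit ::text[] cast pins the column type. A hedged sketch of the binding, assuming the same psycopg2-style cursor (names and values are made up):

permissions = ["SESSION_REPLAY", "DEV_TOOLS"]  # hypothetical permission names
sql = """INSERT INTO roles(tenant_id, name, permissions)
         VALUES (%(tenant_id)s, %(name)s, %(permissions)s::text[]);"""
params = {"tenant_id": 1, "name": "Developer", "permissions": permissions}
# With a live cursor, cur.mogrify(sql, params) would render the list as
# ARRAY['SESSION_REPLAY','DEV_TOOLS']::text[], so no manual string building is needed.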
@@ -115,7 +148,8 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
def get_roles(tenant_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
query = cur.mogrify(
"""SELECT roles.*, COALESCE(projects, '{}') AS projects
FROM public.roles
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
FROM roles_projects
@@ -126,7 +160,8 @@ def get_roles(tenant_id):
AND deleted_at IS NULL
AND not service_role
ORDER BY role_id;""",
{"tenant_id": tenant_id})
{"tenant_id": tenant_id},
)
cur.execute(query=query)
rows = cur.fetchall()
for r in rows:
@@ -136,12 +171,14 @@ def get_roles(tenant_id):
def get_role_by_name(tenant_id, name):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT *
query = cur.mogrify(
"""SELECT *
FROM public.roles
WHERE tenant_id =%(tenant_id)s
AND deleted_at IS NULL
AND name ILIKE %(name)s;""",
{"tenant_id": tenant_id, "name": name})
{"tenant_id": tenant_id, "name": name},
)
cur.execute(query=query)
row = cur.fetchone()
if row is not None:
@@ -155,45 +192,53 @@ def delete(tenant_id, user_id, role_id):
if not admin["admin"] and not admin["superAdmin"]:
return {"errors": ["unauthorized"]}
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.roles
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = TRUE
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is protected"]}
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.users
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is already attached to other user(s)"]}
query = cur.mogrify("""UPDATE public.roles
query = cur.mogrify(
"""UPDATE public.roles
SET deleted_at = timezone('utc'::text, now())
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = FALSE;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
return get_roles(tenant_id=tenant_id)
def get_role(tenant_id, role_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*
query = cur.mogrify(
"""SELECT roles.*
FROM public.roles
WHERE tenant_id =%(tenant_id)s
AND deleted_at IS NULL
AND not service_role
AND role_id = %(role_id)s
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
row = cur.fetchone()
if row is not None:

File diff suppressed because it is too large.


@@ -13,8 +13,8 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
ALGORITHM = config("SCIM_JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
AUDIENCE="okta_client"
ISSUER=config("JWT_ISSUER"),
AUDIENCE = "okta_client"
ISSUER = (config("JWT_ISSUER"),)
# Simulated Okta Client Credentials
# OKTA_CLIENT_ID = "okta-client"
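Worth flagging while this hunk is only being reformatted: the original ISSUER=config("JWT_ISSUER"), ends in a stray trailing comma, so black faithfully rewrote it as ISSUER = (config("JWT_ISSUER"),), i.e. a one-element tuple, and that tuple is what later lands in the token's iss claim. If a plain string was intended, the assignment would presumably be:

# Presumed intent: drop the trailing comma so ISSUER is a str rather than a 1-tuple.
ISSUER = config("JWT_ISSUER")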
@@ -23,7 +23,7 @@ ISSUER=config("JWT_ISSUER"),
# class TokenRequest(BaseModel):
# client_id: str
# client_secret: str
# async def authenticate_client(token_request: TokenRequest):
# """Validate Okta Client Credentials and issue JWT"""
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
@@ -31,6 +31,7 @@ ISSUER=config("JWT_ISSUER"),
# return {"access_token": create_jwt(), "token_type": "bearer"}
def create_tokens(tenant_id):
curr_time = time.time()
access_payload = {
@@ -38,7 +39,7 @@ def create_tokens(tenant_id):
"sub": "scim_server",
"aud": AUDIENCE,
"iss": ISSUER,
"exp": ""
"exp": "",
}
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
@@ -49,18 +50,24 @@ def create_tokens(tenant_id):
return access_token, refresh_token
def verify_access_token(token: str):
try:
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
payload = jwt.decode(
token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")
def verify_refresh_token(token: str):
try:
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
payload = jwt.decode(
token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
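For reference, the encode/verify pair above is the standard PyJWT round trip: sign a payload carrying exp, aud and iss, then decode with the expected audience so expired or mistargeted tokens raise. A self-contained sketch with a made-up secret and HS256 (the real key, algorithm and lifetimes come from config()):

import time
import jwt  # PyJWT

SECRET = "example-secret"  # placeholder; not the project's SCIM_ACCESS_SECRET_KEY

payload = {
    "tenant_id": 1,
    "sub": "scim_server",
    "aud": "okta_client",
    "exp": time.time() + 3600,  # valid for one hour
}
token = jwt.encode(payload, SECRET, algorithm="HS256")

try:
    decoded = jwt.decode(token, SECRET, algorithms=["HS256"], audience="okta_client")
    print(decoded["tenant_id"])        # 1
except jwt.ExpiredSignatureError:      # raised once exp has passed
    print("Token expired")
except jwt.InvalidTokenError:          # wrong signature, audience, etc.
    print("Invalid token")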
@@ -69,6 +76,8 @@ def verify_refresh_token(token: str):
required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Authentication Dependency
def auth_required(token: str = Depends(required_oauth2_scheme)):
"""Dependency to check Authorization header."""
@@ -78,6 +87,8 @@ def auth_required(token: str = Depends(required_oauth2_scheme)):
optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def auth_optional(token: str | None = Depends(optional_oauth2_scheme)):
if token is None:
return None
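The two schemes above differ only in auto_error: the required one makes FastAPI return 401 on a missing Authorization header before the endpoint runs, while auto_error=False injects None so the handler can degrade gracefully. A condensed, standalone sketch of that split (simplified relative to the real auth_required/auth_optional, which also verify the JWT):

from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer

app = FastAPI()
required_scheme = OAuth2PasswordBearer(tokenUrl="token")                    # missing header -> 401
optional_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)  # missing header -> None

@app.get("/private")
def private(token: str = Depends(required_scheme)):
    return {"token_received": True}

@app.get("/maybe-private")
def maybe_private(token: str | None = Depends(optional_scheme)):
    return {"authenticated": token is not None}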


@@ -1,7 +1,5 @@
import logging
import re
import uuid
from typing import Any, Literal, Optional
from typing import Any, Literal
import copy
from datetime import datetime
@@ -9,11 +7,15 @@ from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field, field_serializer
from pydantic import BaseModel, field_serializer
import schemas
from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token
from chalicelib.core import users, tenants
from chalicelib.utils.scim_auth import (
auth_optional,
auth_required,
create_tokens,
verify_refresh_token,
)
from routers.base import get_routers
from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG
from routers import scim_helpers
@@ -26,29 +28,41 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
"""Authentication endpoints"""
class RefreshRequest(BaseModel):
refresh_token: str
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Login endpoint to generate tokens
@public_app.post("/token")
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
async def login(
host: str = Header(..., alias="Host"),
form_data: OAuth2PasswordRequestForm = Depends(),
):
subdomain = host.split(".")[0]
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
if form_data.username != config("SCIM_USER") or form_data.password != config(
"SCIM_PASSWORD"
):
raise HTTPException(status_code=401, detail="Invalid credentials")
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
# Refresh token endpoint
@public_app.post("/refresh")
async def refresh_token(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
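Since the login route above depends on OAuth2PasswordRequestForm, clients must send username/password as form fields rather than JSON, while /refresh takes the small JSON body defined by RefreshRequest. A hedged usage sketch; the host, credentials and subdomain below are placeholders (the subdomain is what resolves the tenant):

import requests  # any HTTP client would do

base = "https://acme.openreplay.example/sso/scim/v2"  # placeholder tenant subdomain

# Form-encoded, as required by OAuth2PasswordRequestForm.
tokens = requests.post(
    f"{base}/token",
    data={"username": "scim-user", "password": "scim-password"},
).json()

# JSON body matching the RefreshRequest model.
refreshed = requests.post(
    f"{base}/refresh",
    json={"refresh_token": tokens["refresh_token"]},
).json()

# Subsequent SCIM calls present the access token as a Bearer header.
users_page = requests.get(
    f"{base}/Users",
    headers={"Authorization": f"Bearer {tokens['access_token']}"},
).json()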
@@ -68,7 +82,7 @@ def _not_found_error_response(resource_id: str):
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": f"Resource {resource_id} not found",
"status": "404",
}
},
)
@@ -80,7 +94,7 @@ def _uniqueness_error_response():
"detail": "One or more of the attribute values are already in use or are reserved.",
"status": "409",
"scimType": "uniqueness",
}
},
)
@@ -92,7 +106,7 @@ def _mutability_error_response():
"detail": "The attempted modification is not compatible with the target attribute's mutability or current state.",
"status": "400",
"scimType": "mutability",
}
},
)
@@ -105,7 +119,7 @@ async def get_resource_types(filter_param: str | None = Query(None, alias="filte
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "Operation is not permitted based on the supplied authorization",
"status": "403",
}
},
)
return JSONResponse(
status_code=200,
@@ -130,8 +144,7 @@ async def get_resource_type(resource_id: str):
SCHEMA_IDS_TO_SCHEMA_DETAILS = {
schema_detail["id"]: schema_detail
for schema_detail in SCHEMAS
schema_detail["id"]: schema_detail for schema_detail in SCHEMAS
}
@@ -144,7 +157,7 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "Operation is not permitted based on the supplied authorization",
"status": "403",
}
},
)
return JSONResponse(
status_code=200,
@@ -154,9 +167,8 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
"startIndex": 1,
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"Resources": [
value
for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
]
value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
],
},
)
@@ -174,7 +186,9 @@ async def get_schema(schema_id: str):
# note(jon): it was recommended to make this endpoint partially open
# so that clients can view the `authenticationSchemes` prior to being authenticated.
@public_app.get("/ServiceProviderConfig")
async def get_service_provider_config(r: Request, tenant_id: str | None = Depends(auth_optional)):
async def get_service_provider_config(
r: Request, tenant_id: str | None = Depends(auth_optional)
):
content = copy.deepcopy(SERVICE_PROVIDER_CONFIG)
content["meta"]["location"] = str(r.url)
is_authenticated = tenant_id is not None
@@ -193,6 +207,8 @@ async def get_service_provider_config(r: Request, tenant_id: str | None = Depend
"""
User endpoints
"""
class UserRequest(BaseModel):
userName: str
@@ -203,7 +219,9 @@ class PatchUserRequest(BaseModel):
class ResourceMetaResponse(BaseModel):
resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None
resourceType: (
Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None
) = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
@@ -231,12 +249,16 @@ class CommonResourceResponse(BaseModel):
class UserResponse(CommonResourceResponse):
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"]
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = [
"urn:ietf:params:scim:schemas:core:2.0:User"
]
userName: str | None = None
class QueryResourceResponse(BaseModel):
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"]
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = [
"urn:ietf:params:scim:api:messages:2.0:ListResponse"
]
totalResults: int
# todo(jon): add the other schemas
Resources: list[UserResponse]
@@ -247,21 +269,33 @@ class QueryResourceResponse(BaseModel):
MAX_USERS_PER_PAGE = 10
def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse:
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
def _convert_db_user_to_scim_user(
db_user: dict[str, Any],
attributes: list[str] | None = None,
excluded_attributes: list[str] | None = None,
) -> UserResponse:
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[
"urn:ietf:params:scim:schemas:core:2.0:User"
]
all_attributes = scim_helpers.get_all_attribute_names(user_schema)
attributes = attributes or all_attributes
always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
always_returned_attributes = (
scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
)
included_attributes = list(set(attributes).union(set(always_returned_attributes)))
excluded_attributes = excluded_attributes or []
excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes)))
excluded_attributes = list(
set(excluded_attributes).difference(set(always_returned_attributes))
)
scim_user = {
"id": str(db_user["userId"]),
"meta": {
"resourceType": "User",
"created": db_user["createdAt"],
"lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db
"location": f"Users/{db_user['userId']}"
"lastModified": db_user[
"createdAt"
], # todo(jon): we currently don't keep track of this in the db
"location": f"Users/{db_user['userId']}",
},
"userName": db_user["email"],
}
@@ -272,14 +306,16 @@ def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str]
@public_app.get("/Users")
async def get_users(
tenant_id = Depends(auth_required),
tenant_id=Depends(auth_required),
requested_start_index: int = Query(1, alias="startIndex"),
requested_items_per_page: int | None = Query(None, alias="count"),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
start_index = max(1, requested_start_index)
items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE)
items_per_page = min(
max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE
)
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
# but this is the fastest thing at the moment just to test that it's working
total_users = users.get_users_paginated(1, tenant_id)
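The clamping at the top of get_users keeps SCIM pagination inside sane bounds: startIndex is 1-based and never drops below 1, and count defaults to and is capped at MAX_USERS_PER_PAGE. Worked through with the constant shown above:

MAX_USERS_PER_PAGE = 10

def clamp(requested_start_index: int, requested_items_per_page: int | None):
    start_index = max(1, requested_start_index)
    items_per_page = min(
        max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE
    )
    return start_index, items_per_page

print(clamp(0, None))  # (1, 10)  startIndex below 1 is bumped up, count defaults
print(clamp(5, 50))    # (5, 10)  an oversized count is capped at the page maximum
print(clamp(1, 0))     # (1, 10)  a falsy count=0 also falls back to the default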
@@ -302,7 +338,7 @@ async def get_users(
@public_app.get("/Users/{user_id}")
def get_user(
user_id: str,
tenant_id = Depends(auth_required),
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
@@ -311,13 +347,12 @@ def get_user(
return _not_found_error_response(user_id)
scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes)
return JSONResponse(
status_code=200,
content=scim_user.model_dump(mode="json", exclude_none=True)
status_code=200, content=scim_user.model_dump(mode="json", exclude_none=True)
)
@public_app.post("/Users")
async def create_user(r: UserRequest, tenant_id = Depends(auth_required)):
async def create_user(r: UserRequest, tenant_id=Depends(auth_required)):
# note(jon): this method will return soft deleted users as well
existing_db_user = users.get_existing_scim_user_by_unique_values(r.userName)
if existing_db_user and existing_db_user["deletedAt"] is None:
@@ -334,23 +369,26 @@ async def create_user(r: UserRequest, tenant_id = Depends(auth_required)):
)
scim_user = _convert_db_user_to_scim_user(db_user)
response = JSONResponse(
status_code=201,
content=scim_user.model_dump(mode="json", exclude_none=True)
status_code=201, content=scim_user.model_dump(mode="json", exclude_none=True)
)
response.headers["Location"] = scim_user.meta.location
return response
@public_app.put("/Users/{user_id}")
def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required)):
def update_user(user_id: str, r: UserRequest, tenant_id=Depends(auth_required)):
db_resource = users.get_scim_user_by_id(user_id, tenant_id)
if not db_resource:
return _not_found_error_response(user_id)
current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump(mode="json", exclude_none=True)
current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump(
mode="json", exclude_none=True
)
changes = r.model_dump(mode="json")
schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
try:
valid_mutable_changes = scim_helpers.filter_mutable_attributes(schema, changes, current_scim_resource)
valid_mutable_changes = scim_helpers.filter_mutable_attributes(
schema, changes, current_scim_resource
)
except ValueError:
# todo(jon): will need to add a test for this once we have an immutable field
return _mutability_error_response()
@@ -371,7 +409,7 @@ def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required)
@public_app.delete("/Users/{user_id}")
def delete_user(user_id: str, tenant_id = Depends(auth_required)):
def delete_user(user_id: str, tenant_id=Depends(auth_required)):
user = users.get_scim_user_by_id(user_id, tenant_id)
if not user:
return _not_found_error_response(user_id)


@@ -1,22 +1,22 @@
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
from typing import Any, Literal
from typing import Any
def _attribute_characteristics(
name: str,
description: str,
type: str="string",
sub_attributes: dict[str, Any] | None=None,
# note(jon): no default for multiValued is defined in the docs and it is marked as optional.
# from our side, we'll default it to False.
multi_valued: bool=False,
required: bool=False,
canonical_values: list[str] | None=None,
case_exact: bool=False,
mutability: str="readWrite",
returned: str="default",
uniqueness: str="none",
reference_types: list[str] | None=None,
name: str,
description: str,
type: str = "string",
sub_attributes: dict[str, Any] | None = None,
# note(jon): no default for multiValued is defined in the docs and it is marked as optional.
# from our side, we'll default it to False.
multi_valued: bool = False,
required: bool = False,
canonical_values: list[str] | None = None,
case_exact: bool = False,
mutability: str = "readWrite",
returned: str = "default",
uniqueness: str = "none",
reference_types: list[str] | None = None,
):
characteristics = {
"name": name,
@@ -33,14 +33,16 @@ def _attribute_characteristics(
"referenceTypes": reference_types,
}
characteristics_without_none = {
key: value
for key, value in characteristics.items()
if value is not None
key: value for key, value in characteristics.items() if value is not None
}
return characteristics_without_none
def _multi_valued_attributes(type_canonical_values: list[str], type_required: bool=False, type_mutability="readWrite"):
def _multi_valued_attributes(
type_canonical_values: list[str],
type_required: bool = False,
type_mutability="readWrite",
):
return [
_attribute_characteristics(
name="type",
@@ -68,7 +70,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo
name="$ref",
type="reference",
reference_types=["uri"],
description="The reference URI of a target resource."
description="The reference URI of a target resource.",
),
]
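To make the builder concrete: _attribute_characteristics assembles one RFC 7643 attribute definition, falls back to the defaults in its signature, and strips every characteristic left at None, so keys like canonicalValues or subAttributes only appear when supplied. The full field set is partly elided by the diff, so the expected output below is inferred and only illustrative; the "$ref" call mirrors the one in the hunk above and assumes the surrounding module context:

ref_attribute = _attribute_characteristics(
    name="$ref",
    type="reference",
    reference_types=["uri"],
    description="The reference URI of a target resource.",
)
# Roughly (None-valued characteristics are dropped by the dict comprehension):
# {
#     "name": "$ref",
#     "description": "The reference URI of a target resource.",
#     "type": "reference",
#     "multiValued": False,
#     "required": False,
#     "caseExact": False,
#     "mutability": "readWrite",
#     "returned": "default",
#     "uniqueness": "none",
#     "referenceTypes": ["uri"],
# }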
@@ -77,7 +79,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo
# in section 3.1 of RFC7643, it is specified that ResourceType and
# ServiceProviderConfig are not included in the common attributes. but
# in other references, they treat them as a resource.
def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none"):
def _common_resource_attributes(id_required: bool = True, id_uniqueness: str = "none"):
return [
_attribute_characteristics(
name="id",
@@ -151,7 +153,6 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none
]
SERVICE_PROVIDER_CONFIG_SCHEMA = {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
"id": "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
@@ -339,7 +340,7 @@ SERVICE_PROVIDER_CONFIG_SCHEMA = {
),
],
),
]
],
}
@@ -409,9 +410,9 @@ RESOURCE_TYPE_SCHEMA = {
required=True,
mutability="readOnly",
),
]
],
),
]
],
}
SCHEMA_SCHEMA = {
@@ -548,7 +549,7 @@ SCHEMA_SCHEMA = {
canonical_values=[
# todo(jon): add "User" and "Group" once those are done.
"external",
"uri"
"uri",
],
case_exact=True,
),
@@ -659,15 +660,15 @@ SCHEMA_SCHEMA = {
canonical_values=[
# todo(jon): add "User" and "Group" once those are done.
"external",
"uri"
"uri",
],
case_exact=True,
),
],
),
]
)
]
],
),
],
}
@@ -749,7 +750,7 @@ SERVICE_PROVIDER_CONFIG = {
# and then updating these timestamps from an api and such. for now, if we update
# the configuration, we should update the timestamp here.
"lastModified": "2025-04-15T15:45:00Z",
"location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint
"location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint
},
}


@@ -4,6 +4,7 @@ from copy import deepcopy
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
result = []
def _walk(attrs, prefix=None):
for attr in attrs:
name = attr["name"]
@@ -12,12 +13,16 @@ def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
if attr["type"] == "complex":
sub = attr.get("subAttributes") or attr.get("attributes") or []
_walk(sub, path)
_walk(schema["attributes"])
return result
def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]:
def get_all_attribute_names_where_returned_is_always(
schema: dict[str, Any],
) -> list[str]:
result = []
def _walk(attrs, prefix=None):
for attr in attrs:
name = attr["name"]
@@ -27,11 +32,14 @@ def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) ->
if attr["type"] == "complex":
sub = attr.get("subAttributes") or attr.get("attributes") or []
_walk(sub, path)
_walk(schema["attributes"])
return result
def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]:
def filter_attributes(
resource: dict[str, Any], include_list: list[str]
) -> dict[str, Any]:
result = {}
for attr in include_list:
parts = attr.split(".", 1)
@@ -63,7 +71,9 @@ def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict
return result
def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]:
def exclude_attributes(
resource: dict[str, Any], exclude_list: list[str]
) -> dict[str, Any]:
exclude_map = {}
for attr in exclude_list:
parts = attr.split(".", 1)
@@ -105,7 +115,11 @@ def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dic
return new_resource
def filter_mutable_attributes(schema: dict[str, Any], requested_changes: dict[str, Any], current: dict[str, Any]) -> dict[str, Any]:
def filter_mutable_attributes(
schema: dict[str, Any],
requested_changes: dict[str, Any],
current_values: dict[str, Any],
) -> dict[str, Any]:
attributes = {attr.get("name"): attr for attr in schema.get("attributes", [])}
valid_changes = {}
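Taken together, the helpers in this last file turn a SCIM schema into dotted attribute paths (for example "meta.created") and then use those paths to include, exclude or mutability-filter fields on a resource. The function bodies are mostly elided by the diff context, so the following is only an illustrative sketch of the dotted-path walk they appear to implement, exercised with a toy schema:

from typing import Any

def walk_attribute_names(schema: dict[str, Any]) -> list[str]:
    """Collect 'parent.child' style names, recursing into complex attributes."""
    result: list[str] = []

    def _walk(attrs, prefix=None):
        for attr in attrs:
            path = f"{prefix}.{attr['name']}" if prefix else attr["name"]
            result.append(path)
            if attr["type"] == "complex":
                _walk(attr.get("subAttributes") or attr.get("attributes") or [], path)

    _walk(schema["attributes"])
    return result

toy_schema = {
    "attributes": [
        {"name": "userName", "type": "string"},
        {"name": "meta", "type": "complex", "subAttributes": [
            {"name": "created", "type": "dateTime"},
            {"name": "location", "type": "string"},
        ]},
    ]
}

print(walk_attribute_names(toy_schema))  # ['userName', 'meta', 'meta.created', 'meta.location']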