add patch endpoint for groups

This commit is contained in:
Jonathan Griffin 2025-04-24 10:41:06 +02:00
parent 54b8ccb39c
commit 23d696b407
4 changed files with 436 additions and 122 deletions

View file

@ -483,8 +483,8 @@ def restore_scim_user(
internal_id = %(internal_id)s, internal_id = %(internal_id)s,
role_id = %(role_id)s, role_id = %(role_id)s,
deleted_at = NULL, deleted_at = NULL,
created_at = default, created_at = now(),
updated_at = default, updated_at = now(),
api_key = default, api_key = default,
jwt_iat = NULL, jwt_iat = NULL,
weekly_report = default weekly_report = default
@ -527,7 +527,7 @@ def update_scim_user(
name = %(name)s, name = %(name)s,
internal_id = %(internal_id)s, internal_id = %(internal_id)s,
role_id = %(role_id)s, role_id = %(role_id)s,
updated_at = default updated_at = now()
WHERE WHERE
users.user_id = %(user_id)s users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s AND users.tenant_id = %(tenant_id)s
@ -582,7 +582,7 @@ def patch_scim_user(
roles.name as role_name roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);""" FROM u LEFT JOIN public.roles USING (role_id);"""
cur.execute(query) cur.execute(query)
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def generate_new_api_key(user_id): def generate_new_api_key(user_id):

View file

@ -254,22 +254,48 @@ def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, An
if "userType" in data: if "userType" in data:
role = roles.get_role_by_name(tenant_id, data["userType"]) role = roles.get_role_by_name(tenant_id, data["userType"])
role_id = role["roleId"] if role else None role_id = role["roleId"] if role else None
name = data.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
for x in [
data.get("name", {}).get("honorificPrefix"),
data.get("name", {}).get("givenName"),
data.get("name", {}).get("middleName"),
data.get("name", {}).get("familyName"),
data.get("name", {}).get("honorificSuffix"),
]
if x
]
)
result = { result = {
"email": data["userName"], "email": data["userName"],
"internal_id": data.get("externalId"), "internal_id": data.get("externalId"),
"name": data.get("name", {}).get("formatted") or data.get("displayName"), "name": name,
"role_id": role_id, "role_id": role_id,
} }
result = {k: v for k, v in result.items() if v is not None} result = {k: v for k, v in result.items() if v is not None}
return result return result
def _parse_user_patch_operations(data: dict[str, Any]) -> dict[str, Any]: def _parse_user_patch_payload(data: dict[str, Any], tenant_id: str) -> dict[str, Any]:
result = {} result = {}
operations = data["Operations"] if "userType" in data:
for operation in operations: role = roles.get_role_by_name(tenant_id, data["userType"])
if operation["op"] == "replace" and "active" in operation["value"]: result["role_id"] = role["roleId"] if role else None
result["deleted_at"] = None if operation["value"]["active"] is True else datetime.now() if "name" in data:
# note(jon): we're currently not handling the case where the client
# sends patches of individual name components (e.g. name.middleName)
name = data.get("name", {}).get("formatted")
if name:
result["name"] = name
if "userName" in data:
result["email"] = data["userName"]
if "externalId" in data:
result["internal_id"] = data["externalId"]
if "active" in data:
result["deleted_at"] = None if data["active"] else datetime.now()
return result return result
@ -326,6 +352,18 @@ def _parse_scim_group_input(data: dict[str, Any], tenant_id: int) -> dict[str, A
} }
def _parse_scim_group_patch(data: dict[str, Any], tenant_id: int) -> dict[str, Any]:
result = {}
if "displayName" in data:
result["name"] = data["displayName"]
if "externalId" in data:
result["external_id"] = data["externalId"]
if "members" in data:
members = data["members"] or []
result["user_ids"] = [int(member["value"]) for member in members]
return result
RESOURCE_TYPE_TO_RESOURCE_CONFIG = { RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"Users": { "Users": {
"max_items_per_page": 10, "max_items_per_page": 10,
@ -341,7 +379,7 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"delete_resource": users.soft_delete_scim_user_by_id, "delete_resource": users.soft_delete_scim_user_by_id,
"parse_put_payload": _parse_scim_user_input, "parse_put_payload": _parse_scim_user_input,
"update_resource": users.update_scim_user, "update_resource": users.update_scim_user,
"parse_patch_operations": _parse_user_patch_operations, "parse_patch_payload": _parse_user_patch_payload,
"patch_resource": users.patch_scim_user, "patch_resource": users.patch_scim_user,
}, },
"Groups": { "Groups": {
@ -359,6 +397,8 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"delete_resource": scim_groups.delete_resource, "delete_resource": scim_groups.delete_resource,
"parse_put_payload": _parse_scim_group_input, "parse_put_payload": _parse_scim_group_input,
"update_resource": scim_groups.update_resource, "update_resource": scim_groups.update_resource,
"parse_patch_payload": _parse_scim_group_patch,
"patch_resource": scim_groups.patch_resource,
}, },
} }
@ -442,7 +482,6 @@ class PostResourceType(str, Enum):
GROUPS = "Groups" GROUPS = "Groups"
@public_app.post("/{resource_type}") @public_app.post("/{resource_type}")
async def create_resource( async def create_resource(
resource_type: PostResourceType, resource_type: PostResourceType,
@ -556,6 +595,7 @@ async def put_resource(
class PatchResourceType(str, Enum): class PatchResourceType(str, Enum):
USERS = "Users" USERS = "Users"
GROUPS = "Groups"
@public_app.patch("/{resource_type}/{resource_id}") @public_app.patch("/{resource_type}/{resource_id}")
@ -577,13 +617,22 @@ async def patch_resource(
) )
) )
payload = await r.json() payload = await r.json()
parsed_payload = resource_config["parse_patch_operations"](payload) _, changes = scim_helpers.apply_scim_patch(
# note(jon): we don't need to handle uniqueness contraints and etc. like in PUT payload["Operations"],
# because we are only covering the User resource and the field `active` current_scim_resource,
SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]],
)
reformatted_scim_changes = {
k: new_value for k, (old_value, new_value) in changes.items()
}
db_changes = resource_config["parse_patch_payload"](
reformatted_scim_changes,
tenant_id,
)
updated_db_resource = resource_config["patch_resource"]( updated_db_resource = resource_config["patch_resource"](
resource_id, resource_id,
tenant_id, tenant_id,
**parsed_payload, **db_changes,
) )
updated_scim_resource = ( updated_scim_resource = (
_serialize_db_resource_to_scim_resource_with_attribute_awareness( _serialize_db_resource_to_scim_resource_with_attribute_awareness(

View file

@ -1,4 +1,6 @@
from typing import Any from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import helper, pg_client from chalicelib.utils import helper, pg_client
@ -82,65 +84,57 @@ def get_existing_resource_by_unique_values_from_all_resources(
def create_resource( def create_resource(
name: str, tenant_id: int, **kwargs: dict[str, Any] name: str,
tenant_id: int,
user_ids: list[str] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
kwargs["name"] = name
kwargs["tenant_id"] = tenant_id
column_fragments = [
cur.mogrify("%s", (AsIs(k),)).decode("utf-8") for k in kwargs.keys()
]
column_clause = ", ".join(column_fragments)
value_fragments = [
cur.mogrify("%s", (v,)).decode("utf-8") for v in kwargs.values()
]
value_clause = ", ".join(value_fragments)
user_ids = user_ids or []
user_id_fragments = [
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
cur.execute( cur.execute(
cur.mogrify( f"""
""" WITH
WITH g AS( g AS (
INSERT INTO public.groups INSERT INTO public.groups ({column_clause})
(tenant_id, name, external_id) VALUES ({value_clause})
VALUES (%(tenant_id)s, %(name)s, %(external_id)s) RETURNING *
),
linked_users AS (
UPDATE public.users
SET
group_id = g.group_id,
updated_at = now()
FROM g
WHERE
users.user_id = ANY({user_id_clause})
AND users.deleted_at IS NULL
AND users.tenant_id = {tenant_id}
RETURNING * RETURNING *
) )
SELECT g.group_id SELECT
FROM g; g.*,
""", COALESCE(users_data.array, '[]') as users
{ FROM g
"tenant_id": tenant_id, LEFT JOIN LATERAL (
"name": name, SELECT json_agg(lu) AS array
"external_id": kwargs.get("external_id"), FROM linked_users AS lu
}, ) users_data ON true
) LIMIT 1;
) """
group_id = cur.fetchone()["group_id"]
user_ids = kwargs.get("user_ids", [])
if user_ids:
cur.execute(
cur.mogrify(
"""
UPDATE public.users
SET group_id = %s
WHERE users.user_id = ANY(%s)
""",
(group_id, user_ids),
)
)
cur.execute(
cur.mogrify(
"""
SELECT
groups.*,
users_data.array as users
FROM public.groups
LEFT JOIN LATERAL (
SELECT json_agg(users) AS array
FROM public.users
WHERE users.group_id = %(group_id)s
) users_data ON true
WHERE
groups.group_id = %(group_id)s
AND groups.tenant_id = %(tenant_id)s
LIMIT 1;
""",
{
"group_id": group_id,
"tenant_id": tenant_id,
"name": name,
"external_id": kwargs.get("external_id"),
},
)
) )
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
@ -158,64 +152,92 @@ def delete_resource(group_id: int, tenant_id: int) -> None:
) )
def update_resource( def _update_resource_sql(
group_id: int, tenant_id: int, name: str, **kwargs: dict[str, Any] group_id: int,
tenant_id: int,
user_ids: list[int] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
kwargs["updated_at"] = datetime.now()
set_fragments = [
cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8")
for k, v in kwargs.items()
]
set_clause = ", ".join(set_fragments)
user_ids = user_ids or []
user_id_fragments = [
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
cur.execute( cur.execute(
cur.mogrify( f"""
""" WITH
UPDATE public.users g AS (
SET group_id = null
WHERE users.group_id = %(group_id)s;
""",
{"group_id": group_id},
)
)
user_ids = kwargs.get("user_ids", [])
if user_ids:
cur.execute(
cur.mogrify(
"""
UPDATE public.users
SET group_id = %s
WHERE users.user_id = ANY(%s);
""",
(group_id, user_ids),
)
)
cur.execute(
cur.mogrify(
"""
WITH g AS (
UPDATE public.groups UPDATE public.groups
SET SET {set_clause}
tenant_id = %(tenant_id)s,
name = %(name)s,
external_id = %(external_id)s,
updated_at = default
WHERE WHERE
groups.group_id = %(group_id)s groups.group_id = {group_id}
AND groups.tenant_id = %(tenant_id)s AND groups.tenant_id = {tenant_id}
RETURNING *
),
unlinked_users AS (
UPDATE public.users
SET
group_id = null,
updated_at = now()
WHERE
users.group_id = {group_id}
AND users.user_id <> ALL({user_id_clause})
AND users.deleted_at IS NULL
AND users.tenant_id = {tenant_id}
),
linked_users AS (
UPDATE public.users
SET
group_id = {group_id},
updated_at = now()
WHERE
users.user_id = ANY({user_id_clause})
AND users.deleted_at IS NULL
AND users.tenant_id = {tenant_id}
RETURNING * RETURNING *
) )
SELECT SELECT
g.*, g.*,
users_data.array as users COALESCE(users_data.array, '[]') as users
FROM g FROM g
LEFT JOIN LATERAL ( LEFT JOIN LATERAL (
SELECT json_agg(users) AS array SELECT json_agg(lu) AS array
FROM public.users FROM linked_users AS lu
WHERE users.group_id = g.group_id ) users_data ON true
) users_data ON true LIMIT 1;
LIMIT 1; """
""",
{
"group_id": group_id,
"tenant_id": tenant_id,
"name": name,
"external_id": kwargs.get("external_id"),
},
)
) )
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def update_resource(
    group_id: int,
    tenant_id: int,
    name: str,
    **kwargs: dict[str, Any],
) -> dict[str, Any]:
    """
    PUT-style group update (name is mandatory per SCIM replace semantics).

    Delegates to the shared SQL helper and returns the updated group row
    it produces.
    """
    return _update_resource_sql(
        group_id=group_id, tenant_id=tenant_id, name=name, **kwargs
    )
def patch_resource(
    group_id: int,
    tenant_id: int,
    **kwargs: dict[str, Any],
) -> dict[str, Any]:
    """
    PATCH-style group update: only the supplied columns (and, when given,
    the `user_ids` membership list) are changed.

    Delegates to the shared SQL helper and returns the updated group row
    it produces. Unlike `update_resource`, `name` is optional here — a
    PATCH may change any subset of columns.
    """
    # Added the `-> dict[str, Any]` return annotation for consistency with
    # the sibling `update_resource`; the helper's return value is passed
    # through unchanged.
    return _update_resource_sql(
        group_id=group_id,
        tenant_id=tenant_id,
        **kwargs,
    )

View file

@ -1,5 +1,6 @@
from typing import Any from typing import Any
from copy import deepcopy from copy import deepcopy
import re
def convert_query_str_to_list(query_str: str | None) -> list[str]: def convert_query_str_to_list(query_str: str | None) -> list[str]:
@ -156,3 +157,245 @@ def filter_mutable_attributes(
# If it matches, no change is needed (already set) # If it matches, no change is needed (already set)
return valid_changes return valid_changes
def apply_scim_patch(
    operations: list[dict[str, Any]], resource: dict[str, Any], schema: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Apply SCIM patch operations to a resource based on schema.

    Returns (updated_resource, changes) where `updated_resource` is the new SCIM
    resource dict and `changes` maps attribute or path to (old_value, new_value).
    Additions have old_value=None if attribute didn't exist; removals have new_value=None.
    For add/remove on list-valued attributes, changes record the full list before/after.

    Raises ValueError for attributes not declared in `schema`, unsupported
    operation types, and a path-less add/replace whose value is not a dict.
    """
    # Deep copy to avoid mutating original
    updated = deepcopy(resource)
    changes = {}

    # Allowed attributes from schema
    allowed_attrs = {attr["name"]: attr for attr in schema.get("attributes", [])}

    for op in operations:
        op_type = op.get("op", "").strip().lower()
        path = op.get("path")
        value = op.get("value")

        if not path:
            # Top-level merge
            if op_type in ("add", "replace"):
                if not isinstance(value, dict):
                    raise ValueError(
                        "When path is not provided, value must be a dict of attributes to merge."
                    )
                for attr, val in value.items():
                    if attr not in allowed_attrs:
                        raise ValueError(
                            f"Attribute '{attr}' not defined in SCIM schema"
                        )
                    old = updated.get(attr)
                    if val is None:
                        # A None value means "clear the attribute". (Fix: the
                        # previous `updated[attr] = updated.pop(attr, None)`
                        # re-assigned the popped value, so the attribute was
                        # never actually removed.)
                        updated.pop(attr, None)
                    else:
                        updated[attr] = val
                    changes[attr] = (old, val)
            else:
                raise ValueError(f"Unsupported operation without path: {op_type}")
            continue

        tokens = parse_scim_path(path)

        # Detect simple top-level list add/remove
        if (
            op_type in ("add", "remove")
            and len(tokens) == 1
            and isinstance(tokens[0], str)
        ):
            attr = tokens[0]
            if attr not in allowed_attrs:
                raise ValueError(f"Attribute '{attr}' not defined in SCIM schema")
            current_list = updated.get(attr, [])
            if isinstance(current_list, list):
                before = deepcopy(current_list)
                if op_type == "add":
                    # Ensure list exists
                    updated.setdefault(attr, [])
                    # Append new items
                    items = value if isinstance(value, list) else [value]
                    updated[attr].extend(items)
                else:  # remove
                    # Remove items matching filter if value not provided
                    # For remove on list without filter, remove all values equal to value
                    if value is None:
                        updated.pop(attr, None)
                    else:
                        # filter value items out
                        items = value if isinstance(value, list) else [value]
                        updated[attr] = [
                            e for e in updated.get(attr, []) if e not in items
                        ]
                after = deepcopy(updated.get(attr, []))
                changes[attr] = (before, after)
                continue

        # For other operations, get old value and apply normally
        old_val = get_by_path(updated, tokens)
        if op_type == "add":
            set_by_path(updated, tokens, value)
        elif op_type == "replace":
            if value is None:
                remove_by_path(updated, tokens)
            else:
                set_by_path(updated, tokens, value)
        elif op_type == "remove":
            remove_by_path(updated, tokens)
        else:
            raise ValueError(f"Unsupported operation type: {op_type}")

        # Record change for non-list or nested paths
        new_val = None if op_type == "remove" else get_by_path(updated, tokens)
        changes[path] = (old_val, new_val)

    return updated, changes
def parse_scim_path(path):
    """
    Split a SCIM attribute path such as 'emails[type eq "work"].value'
    into tokens. Plain segments stay strings; a filtered segment becomes
    a (attr, filter_attr, filter_value) tuple. Only the simple equality
    filter form `subAttr eq "value"` is supported.
    """
    segment_re = re.compile(r"([^\.\[]+)(?:\[(.*?)\])?")
    filter_re = re.compile(r"\s*(\w+)\s+eq\s+\"([^\"]+)\"")
    tokens = []
    for segment in segment_re.finditer(path):
        name = segment.group(1)
        condition = segment.group(2)
        if not condition:
            tokens.append(name)
            continue
        parsed = filter_re.match(condition)
        if parsed is None:
            raise ValueError(f"Unsupported filter expression: {condition}")
        tokens.append((name, parsed.group(1), parsed.group(2)))
    return tokens
def get_by_path(doc, tokens):
    """
    Walk nested dicts/lists following parsed path tokens and return the
    value at the end of the path, or None when any step is missing.

    String tokens index dicts (or, when integers, lists); tuple tokens
    (attr, filter_attr, filter_value) select the first dict element of
    the list under `attr` whose `filter_attr` equals `filter_value`.
    """
    node = doc
    for step in tokens:
        if node is None:
            return None
        if isinstance(step, tuple):
            name, filter_attr, filter_val = step
            candidates = node.get(name)
            if not isinstance(candidates, list):
                return None
            # First dict element satisfying the equality filter, else None.
            node = next(
                (
                    item
                    for item in candidates
                    if isinstance(item, dict) and item.get(filter_attr) == filter_val
                ),
                None,
            )
            if node is None:
                return None
        elif isinstance(node, dict):
            node = node.get(step)
        elif isinstance(node, list) and isinstance(step, int):
            if not (0 <= step < len(node)):
                return None
            node = node[step]
        else:
            return None
    return node
def set_by_path(doc, tokens, value):
    """
    Set a value in nested dicts/lists using parsed tokens.
    Creates intermediate dicts/lists as needed.

    `tokens` is the output of `parse_scim_path`: plain strings index into
    dicts, and (attr, filter_attr, filter_value) tuples select the first
    dict in the list `doc[attr]` whose `filter_attr` equals `filter_value`.

    Edge behavior (as implemented):
    - A filtered segment with no matching element appends `value` itself
      when it is the last token, otherwise appends an empty placeholder
      dict to keep descending.
    - A None `value` on the final string token deletes the key from a
      dict target; on a non-dict target it is silently ignored.
    """
    cur = doc
    for i, token in enumerate(tokens):
        last = i == len(tokens) - 1
        if isinstance(token, tuple):
            attr, fattr, fval = token
            # Filtered segment: operate on the list under `attr`,
            # creating it on demand so adds can build the structure.
            lst = cur.setdefault(attr, [])
            if not isinstance(lst, list):
                raise ValueError(f"Expected list at attribute '{attr}'")
            # Find existing entry
            idx = next(
                (
                    j
                    for j, e in enumerate(lst)
                    if isinstance(e, dict) and e.get(fattr) == fval
                ),
                None,
            )
            if idx is None:
                if last:
                    lst.append(value)
                    return
                else:
                    # NOTE(review): the placeholder dict does not get `fattr`
                    # set, so a later lookup with the same filter will not
                    # find it — confirm this is intended.
                    new = {}
                    lst.append(new)
                    cur = new
            else:
                if last:
                    lst[idx] = value
                    return
                cur = lst[idx]
        else:
            if last:
                if value is None:
                    # None means "remove" — only meaningful for dict targets.
                    if isinstance(cur, dict):
                        cur.pop(token, None)
                else:
                    cur[token] = value
            else:
                # Descend, creating intermediate dicts as needed.
                cur = cur.setdefault(token, {})
def remove_by_path(doc, tokens):
    """
    Remove a value in nested dicts/lists using parsed tokens.
    Does nothing if path not present.

    `tokens` is the output of `parse_scim_path`: plain strings index into
    dicts (or, when integers, into lists), and (attr, filter_attr,
    filter_value) tuples select the first dict in the list `doc[attr]`
    whose `filter_attr` equals `filter_value`.
    """
    cur = doc
    for i, token in enumerate(tokens):
        last = i == len(tokens) - 1
        if isinstance(token, tuple):
            attr, fattr, fval = token
            lst = cur.get(attr)
            if not isinstance(lst, list):
                # Path does not lead through a list — nothing to remove.
                return
            for j, elem in enumerate(lst):
                if isinstance(elem, dict) and elem.get(fattr) == fval:
                    if last:
                        # Final segment: drop the matching element itself.
                        lst.pop(j)
                        return
                    cur = elem
                    break
            else:
                # No element matched the filter — nothing to remove.
                return
        else:
            if last:
                if isinstance(cur, dict):
                    cur.pop(token, None)
                elif isinstance(cur, list) and isinstance(token, int):
                    # Bounds-checked positional removal from a list.
                    if 0 <= token < len(cur):
                        cur.pop(token)
                return
            else:
                # Descend one level; give up quietly if the shape is wrong.
                if isinstance(cur, dict):
                    cur = cur.get(token)
                elif isinstance(cur, list) and isinstance(token, int):
                    cur = cur[token] if 0 <= token < len(cur) else None
                else:
                    return