From 23d696b407e29a2e4c58fa2e0aacf8440727be76 Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Thu, 24 Apr 2025 10:41:06 +0200 Subject: [PATCH] add patch endpoint for groups --- ee/api/chalicelib/core/users.py | 8 +- ee/api/routers/scim.py | 73 ++++++++-- ee/api/routers/scim_groups.py | 234 ++++++++++++++++-------------- ee/api/routers/scim_helpers.py | 243 ++++++++++++++++++++++++++++++++ 4 files changed, 436 insertions(+), 122 deletions(-) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index b1b014735..2c63faab1 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -483,8 +483,8 @@ def restore_scim_user( internal_id = %(internal_id)s, role_id = %(role_id)s, deleted_at = NULL, - created_at = default, - updated_at = default, + created_at = now(), + updated_at = now(), api_key = default, jwt_iat = NULL, weekly_report = default @@ -527,7 +527,7 @@ def update_scim_user( name = %(name)s, internal_id = %(internal_id)s, role_id = %(role_id)s, - updated_at = default + updated_at = now() WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s @@ -582,7 +582,7 @@ def patch_scim_user( roles.name as role_name FROM u LEFT JOIN public.roles USING (role_id);""" cur.execute(query) - return helper.dict_to_camel_case(cur.fetchone()) + return helper.dict_to_camel_case(cur.fetchone()) def generate_new_api_key(user_id): diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index a00822969..024fde933 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -254,22 +254,48 @@ def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, An if "userType" in data: role = roles.get_role_by_name(tenant_id, data["userType"]) role_id = role["roleId"] if role else None + name = data.get("name", {}).get("formatted") + if not name: + name = " ".join( + [ + x + for x in [ + data.get("name", {}).get("honorificPrefix"), + data.get("name", {}).get("givenName"), + data.get("name", {}).get("middleName"), + data.get("name", {}).get("familyName"), + data.get("name", {}).get("honorificSuffix"), + ] + if x + ] + ) result = { "email": data["userName"], "internal_id": data.get("externalId"), - "name": data.get("name", {}).get("formatted") or data.get("displayName"), + "name": name, "role_id": role_id, } result = {k: v for k, v in result.items() if v is not None} return result -def _parse_user_patch_operations(data: dict[str, Any]) -> dict[str, Any]: +def _parse_user_patch_payload(data: dict[str, Any], tenant_id: str) -> dict[str, Any]: result = {} - operations = data["Operations"] - for operation in operations: - if operation["op"] == "replace" and "active" in operation["value"]: - result["deleted_at"] = None if operation["value"]["active"] is True else datetime.now() + if "userType" in data: + role = roles.get_role_by_name(tenant_id, data["userType"]) + result["role_id"] = role["roleId"] if role else None + if "name" in data: + # note(jon): we're currently not handling the case where the client + # send patches of individual name components (e.g. name.middleName) + name = data.get("name", {}).get("formatted") + if name: + result["name"] = name + if "userName" in data: + result["email"] = data["userName"] + if "externalId" in data: + result["internal_id"] = data["externalId"] + if "active" in data: + result["deleted_at"] = None if data["active"] else datetime.now() return result @@ -326,6 +352,18 @@ def _parse_scim_group_input(data: dict[str, Any], tenant_id: int) -> dict[str, A } +def _parse_scim_group_patch(data: dict[str, Any], tenant_id: int) -> dict[str, Any]: + result = {} + if "displayName" in data: + result["name"] = data["displayName"] + if "externalId" in data: + result["external_id"] = data["externalId"] + if "members" in data: + members = data["members"] or [] + result["user_ids"] = [int(member["value"]) for member in members] + return result + + RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "Users": { "max_items_per_page": 10, @@ -341,7 +379,7 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "delete_resource": users.soft_delete_scim_user_by_id, "parse_put_payload": _parse_scim_user_input, "update_resource": users.update_scim_user, - "parse_patch_operations": _parse_user_patch_operations, + "parse_patch_payload": _parse_user_patch_payload, "patch_resource": users.patch_scim_user, }, "Groups": { @@ -359,6 +397,8 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "delete_resource": scim_groups.delete_resource, "parse_put_payload": _parse_scim_group_input, "update_resource": scim_groups.update_resource, + "parse_patch_payload": _parse_scim_group_patch, + "patch_resource": scim_groups.patch_resource, }, } @@ -442,7 +482,6 @@ class PostResourceType(str, Enum): GROUPS = "Groups" - @public_app.post("/{resource_type}") async def create_resource( resource_type: PostResourceType, @@ -556,6 +595,7 @@ async def put_resource( class PatchResourceType(str, Enum): USERS = "Users" + GROUPS = "Groups" @public_app.patch("/{resource_type}/{resource_id}") @@ -577,13 +617,22 @@ async def patch_resource( ) ) payload = await r.json() - parsed_payload = resource_config["parse_patch_operations"](payload) - # note(jon): we don't need to handle uniqueness contraints and etc. like in PUT - # because we are only covering the User resource and the field `active` + _, changes = scim_helpers.apply_scim_patch( + payload["Operations"], + current_scim_resource, + SCHEMA_IDS_TO_SCHEMA_DETAILS[resource_config["schema_id"]], + ) + reformatted_scim_changes = { + k: new_value for k, (old_value, new_value) in changes.items() + } + db_changes = resource_config["parse_patch_payload"]( + reformatted_scim_changes, + tenant_id, + ) updated_db_resource = resource_config["patch_resource"]( resource_id, tenant_id, - **parsed_payload, + **db_changes, ) updated_scim_resource = ( _serialize_db_resource_to_scim_resource_with_attribute_awareness( diff --git a/ee/api/routers/scim_groups.py b/ee/api/routers/scim_groups.py index 743d53512..a9ef352b8 100644 --- a/ee/api/routers/scim_groups.py +++ b/ee/api/routers/scim_groups.py @@ -1,4 +1,6 @@ from typing import Any +from datetime import datetime +from psycopg2.extensions import AsIs from chalicelib.utils import helper, pg_client @@ -82,65 +84,57 @@ def get_existing_resource_by_unique_values_from_all_resources( def create_resource( - name: str, tenant_id: int, **kwargs: dict[str, Any] + name: str, + tenant_id: int, + user_ids: list[str] | None = None, + **kwargs: dict[str, Any], ) -> dict[str, Any]: with pg_client.PostgresClient() as cur: + kwargs["name"] = name + kwargs["tenant_id"] = tenant_id + column_fragments = [ + cur.mogrify("%s", (AsIs(k),)).decode("utf-8") for k in kwargs.keys() + ] + column_clause = ", ".join(column_fragments) + value_fragments = [ + cur.mogrify("%s", (v,)).decode("utf-8") for v in kwargs.values() + ] + value_clause = ", ".join(value_fragments) + user_ids = user_ids or [] + user_id_fragments = [ + cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids + ] + user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" cur.execute( - cur.mogrify( - """ - WITH g AS( - INSERT INTO public.groups - (tenant_id, name, external_id) - VALUES (%(tenant_id)s, %(name)s, %(external_id)s) + f""" + WITH + g AS ( + INSERT INTO public.groups ({column_clause}) + VALUES ({value_clause}) + RETURNING * + ), + linked_users AS ( + UPDATE public.users + SET + group_id = g.group_id, + updated_at = now() + FROM g + WHERE + users.user_id = ANY({user_id_clause}) + AND users.deleted_at IS NULL + AND users.tenant_id = {tenant_id} RETURNING * ) - SELECT g.group_id - FROM g; - """, - { - "tenant_id": tenant_id, - "name": name, - "external_id": kwargs.get("external_id"), - }, - ) - ) - group_id = cur.fetchone()["group_id"] - user_ids = kwargs.get("user_ids", []) - if user_ids: - cur.execute( - cur.mogrify( - """ - UPDATE public.users - SET group_id = %s - WHERE users.user_id = ANY(%s) - """, - (group_id, user_ids), - ) - ) - cur.execute( - cur.mogrify( - """ - SELECT - groups.*, - users_data.array as users - FROM public.groups - LEFT JOIN LATERAL ( - SELECT json_agg(users) AS array - FROM public.users - WHERE users.group_id = %(group_id)s - ) users_data ON true - WHERE - groups.group_id = %(group_id)s - AND groups.tenant_id = %(tenant_id)s - LIMIT 1; - """, - { - "group_id": group_id, - "tenant_id": tenant_id, - "name": name, - "external_id": kwargs.get("external_id"), - }, - ) + SELECT + g.*, + COALESCE(users_data.array, '[]') as users + FROM g + LEFT JOIN LATERAL ( + SELECT json_agg(lu) AS array + FROM linked_users AS lu + ) users_data ON true + LIMIT 1; + """ ) return helper.dict_to_camel_case(cur.fetchone()) @@ -158,64 +152,92 @@ def delete_resource(group_id: int, tenant_id: int) -> None: ) -def update_resource( - group_id: int, tenant_id: int, name: str, **kwargs: dict[str, Any] +def _update_resource_sql( + group_id: int, + tenant_id: int, + user_ids: list[int] | None = None, + **kwargs: dict[str, Any], ) -> dict[str, Any]: with pg_client.PostgresClient() as cur: + kwargs["updated_at"] = datetime.now() + set_fragments = [ + cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8") + for k, v in kwargs.items() + ] + set_clause = ", ".join(set_fragments) + user_ids = user_ids or [] + user_id_fragments = [ + cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids + ] + user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" cur.execute( - cur.mogrify( - """ - UPDATE public.users - SET group_id = null - WHERE users.group_id = %(group_id)s; - """, - {"group_id": group_id}, - ) - ) - user_ids = kwargs.get("user_ids", []) - if user_ids: - cur.execute( - cur.mogrify( - """ - UPDATE public.users - SET group_id = %s - WHERE users.user_id = ANY(%s); - """, - (group_id, user_ids), - ) - ) - cur.execute( - cur.mogrify( - """ - WITH g AS ( + f""" + WITH + g AS ( UPDATE public.groups - SET - tenant_id = %(tenant_id)s, - name = %(name)s, - external_id = %(external_id)s, - updated_at = default + SET {set_clause} WHERE - groups.group_id = %(group_id)s - AND groups.tenant_id = %(tenant_id)s + groups.group_id = {group_id} + AND groups.tenant_id = {tenant_id} + RETURNING * + ), + unlinked_users AS ( + UPDATE public.users + SET + group_id = null, + updated_at = now() + WHERE + users.group_id = {group_id} + AND users.user_id <> ALL({user_id_clause}) + AND users.deleted_at IS NULL + AND users.tenant_id = {tenant_id} + ), + linked_users AS ( + UPDATE public.users + SET + group_id = {group_id}, + updated_at = now() + WHERE + users.user_id = ANY({user_id_clause}) + AND users.deleted_at IS NULL + AND users.tenant_id = {tenant_id} RETURNING * ) - SELECT - g.*, - users_data.array as users - FROM g - LEFT JOIN LATERAL ( - SELECT json_agg(users) AS array - FROM public.users - WHERE users.group_id = g.group_id - ) users_data ON true - LIMIT 1; - """, - { - "group_id": group_id, - "tenant_id": tenant_id, - "name": name, - "external_id": kwargs.get("external_id"), - }, - ) + SELECT + g.*, + COALESCE(users_data.array, '[]') as users + FROM g + LEFT JOIN LATERAL ( + SELECT json_agg(lu) AS array + FROM linked_users AS lu + ) users_data ON true + LIMIT 1; + """ ) return helper.dict_to_camel_case(cur.fetchone()) + + +def update_resource( + group_id: int, + tenant_id: int, + name: str, + **kwargs: dict[str, Any], +) -> dict[str, Any]: + return _update_resource_sql( + group_id=group_id, + tenant_id=tenant_id, + name=name, + **kwargs, + ) + + +def patch_resource( + group_id: int, + tenant_id: int, + **kwargs: dict[str, Any], +): + return _update_resource_sql( + group_id=group_id, + tenant_id=tenant_id, + **kwargs, + ) diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim_helpers.py index 477becc6c..b57ec1356 100644 --- a/ee/api/routers/scim_helpers.py +++ b/ee/api/routers/scim_helpers.py @@ -1,5 +1,6 @@ from typing import Any from copy import deepcopy +import re def convert_query_str_to_list(query_str: str | None) -> list[str]: @@ -156,3 +157,245 @@ def filter_mutable_attributes( # If it matches, no change is needed (already set) return valid_changes + + +def apply_scim_patch( + operations: list[dict[str, Any]], resource: dict[str, Any], schema: dict[str, Any] +) -> dict[str, Any]: + """ + Apply SCIM patch operations to a resource based on schema. + Returns (updated_resource, changes) where `updated_resource` is the new SCIM + resource dict and `changes` maps attribute or path to (old_value, new_value). + Additions have old_value=None if attribute didn't exist; removals have new_value=None. + For add/remove on list-valued attributes, changes record the full list before/after. + """ + # Deep copy to avoid mutating original + updated = deepcopy(resource) + changes = {} + + # Allowed attributes from schema + allowed_attrs = {attr["name"]: attr for attr in schema.get("attributes", [])} + + for op in operations: + op_type = op.get("op", "").strip().lower() + path = op.get("path") + value = op.get("value") + + if not path: + # Top-level merge + if op_type in ("add", "replace"): + if not isinstance(value, dict): + raise ValueError( + "When path is not provided, value must be a dict of attributes to merge." + ) + for attr, val in value.items(): + if attr not in allowed_attrs: + raise ValueError( + f"Attribute '{attr}' not defined in SCIM schema" + ) + old = updated.get(attr) + updated[attr] = val if val is not None else updated.pop(attr, None) + changes[attr] = (old, val) + else: + raise ValueError(f"Unsupported operation without path: {op_type}") + continue + + tokens = parse_scim_path(path) + + # Detect simple top-level list add/remove + if ( + op_type in ("add", "remove") + and len(tokens) == 1 + and isinstance(tokens[0], str) + ): + attr = tokens[0] + if attr not in allowed_attrs: + raise ValueError(f"Attribute '{attr}' not defined in SCIM schema") + current_list = updated.get(attr, []) + if isinstance(current_list, list): + before = deepcopy(current_list) + if op_type == "add": + # Ensure list exists + updated.setdefault(attr, []) + # Append new items + items = value if isinstance(value, list) else [value] + updated[attr].extend(items) + else: # remove + # Remove items matching filter if value not provided + # For remove on list without filter, remove all values equal to value + if value is None: + updated.pop(attr, None) + else: + # filter value items out + items = value if isinstance(value, list) else [value] + updated[attr] = [ + e for e in updated.get(attr, []) if e not in items + ] + after = deepcopy(updated.get(attr, [])) + changes[attr] = (before, after) + continue + + # For other operations, get old value and apply normally + old_val = get_by_path(updated, tokens) + + if op_type == "add": + set_by_path(updated, tokens, value) + elif op_type == "replace": + if value is None: + remove_by_path(updated, tokens) + else: + set_by_path(updated, tokens, value) + elif op_type == "remove": + remove_by_path(updated, tokens) + else: + raise ValueError(f"Unsupported operation type: {op_type}") + + # Record change for non-list or nested paths + new_val = None if op_type == "remove" else get_by_path(updated, tokens) + changes[path] = (old_val, new_val) + + return updated, changes + + +def parse_scim_path(path): + """ + Parse a SCIM-style path (e.g., 'emails[type eq "work"].value') into a list + of tokens. Each token is either a string attribute name or a tuple + (attr, filter_attr, filter_value) for list-filtering. + """ + tokens = [] + # Regex matches segments like attr or attr[filter] where filter is e.g. type eq "work" + segment_re = re.compile(r"([^\.\[]+)(?:\[(.*?)\])?") + for match in segment_re.finditer(path): + attr = match.group(1) + filt = match.group(2) + if filt: + # Support simple equality filter of form: subAttr eq "value" + m = re.match(r"\s*(\w+)\s+eq\s+\"([^\"]+)\"", filt) + if not m: + raise ValueError(f"Unsupported filter expression: {filt}") + filter_attr, filter_val = m.group(1), m.group(2) + tokens.append((attr, filter_attr, filter_val)) + else: + tokens.append(attr) + return tokens + + +def get_by_path(doc, tokens): + """ + Retrieve a value from nested dicts/lists using parsed tokens. + Returns None if any step is missing. + """ + cur = doc + for token in tokens: + if cur is None: + return None + if isinstance(token, tuple): + attr, fattr, fval = token + lst = cur.get(attr) + if not isinstance(lst, list): + return None + # Find first dict element matching filter + for elem in lst: + if isinstance(elem, dict) and elem.get(fattr) == fval: + cur = elem + break + else: + return None + else: + if isinstance(cur, dict): + cur = cur.get(token) + elif isinstance(cur, list) and isinstance(token, int): + if 0 <= token < len(cur): + cur = cur[token] + else: + return None + else: + return None + return cur + + +def set_by_path(doc, tokens, value): + """ + Set a value in nested dicts/lists using parsed tokens. + Creates intermediate dicts/lists as needed. + """ + cur = doc + for i, token in enumerate(tokens): + last = i == len(tokens) - 1 + if isinstance(token, tuple): + attr, fattr, fval = token + lst = cur.setdefault(attr, []) + if not isinstance(lst, list): + raise ValueError(f"Expected list at attribute '{attr}'") + # Find existing entry + idx = next( + ( + j + for j, e in enumerate(lst) + if isinstance(e, dict) and e.get(fattr) == fval + ), + None, + ) + if idx is None: + if last: + lst.append(value) + return + else: + new = {} + lst.append(new) + cur = new + else: + if last: + lst[idx] = value + return + cur = lst[idx] + + else: + if last: + if value is None: + if isinstance(cur, dict): + cur.pop(token, None) + else: + cur[token] = value + else: + cur = cur.setdefault(token, {}) + + +def remove_by_path(doc, tokens): + """ + Remove a value in nested dicts/lists using parsed tokens. + Does nothing if path not present. + """ + cur = doc + for i, token in enumerate(tokens): + last = i == len(tokens) - 1 + if isinstance(token, tuple): + attr, fattr, fval = token + lst = cur.get(attr) + if not isinstance(lst, list): + return + for j, elem in enumerate(lst): + if isinstance(elem, dict) and elem.get(fattr) == fval: + if last: + lst.pop(j) + return + cur = elem + break + else: + return + else: + if last: + if isinstance(cur, dict): + cur.pop(token, None) + elif isinstance(cur, list) and isinstance(token, int): + if 0 <= token < len(cur): + cur.pop(token) + return + else: + if isinstance(cur, dict): + cur = cur.get(token) + elif isinstance(cur, list) and isinstance(token, int): + cur = cur[token] if 0 <= token < len(cur) else None + else: + return