from typing import Any from copy import deepcopy import re def convert_query_str_to_list(query_str: str | None) -> list[str]: if query_str is None: return None return query_str.split(",") def get_all_attribute_names(schema: dict[str, Any]) -> list[str]: result = [] def _walk(attrs, prefix=None): for attr in attrs: name = attr["name"] path = f"{prefix}.{name}" if prefix else name result.append(path) if attr["type"] == "complex": sub = attr.get("subAttributes") or attr.get("attributes") or [] _walk(sub, path) _walk(schema["attributes"]) return result def get_all_attribute_names_where_returned_is_always( schema: dict[str, Any], ) -> list[str]: result = [] def _walk(attrs, prefix=None): for attr in attrs: name = attr["name"] path = f"{prefix}.{name}" if prefix else name if attr["returned"] == "always": result.append(path) if attr["type"] == "complex": sub = attr.get("subAttributes") or attr.get("attributes") or [] _walk(sub, path) _walk(schema["attributes"]) return result def filter_attributes( obj: dict[str, Any], attributes_query_str: str | None, excluded_attributes_query_str: str | None, schema: dict[str, Any], ) -> dict[str, Any]: all_attributes = get_all_attribute_names(schema) always_returned_attributes = get_all_attribute_names_where_returned_is_always( schema ) included_attributes = convert_query_str_to_list(attributes_query_str) included_attributes = included_attributes or all_attributes included_attributes_set = set(included_attributes).union( set(always_returned_attributes) ) excluded_attributes = convert_query_str_to_list(excluded_attributes_query_str) excluded_attributes = excluded_attributes or [] excluded_attributes_set = set(excluded_attributes).difference( set(always_returned_attributes) ) include_paths = included_attributes_set.difference(excluded_attributes_set) include_tree = {} for path in include_paths: parts = path.split(".") node = include_tree for part in parts: node = node.setdefault(part, {}) def _recurse(o, tree, parent_key=None): if isinstance(o, dict): out = {} for key, subtree in tree.items(): if key in o: out[key] = _recurse(o[key], subtree, key) return out if isinstance(o, list): out = [_recurse(item, tree, parent_key) for item in o] return out return o result = _recurse(obj, include_tree) return result def filter_mutable_attributes( schema: dict[str, Any], requested_changes: dict[str, Any], current_values: dict[str, Any], ) -> dict[str, Any]: attributes = {attr.get("name"): attr for attr in schema.get("attributes", [])} valid_changes = {} for attr_name, new_value in requested_changes.items(): attr_def = attributes.get(attr_name) if not attr_def: # Unknown attribute: ignore per RFC 7644 continue mutability = attr_def.get("mutability", "readWrite") if mutability == "readWrite" or mutability == "writeOnly": valid_changes[attr_name] = new_value elif mutability == "readOnly": # Cannot modify read-only attributes: ignore continue elif mutability == "immutable": # Only valid if the new value matches the current value exactly current_value = current_values.get(attr_name) if new_value != current_value: raise ValueError( f"Attribute '{attr_name}' is immutable (cannot change). " f"Current value: {current_value!r}, attempted change: {new_value!r}" ) # If it matches, no change is needed (already set) return valid_changes def apply_scim_patch( operations: list[dict[str, Any]], resource: dict[str, Any], schema: dict[str, Any] ) -> dict[str, Any]: """ Apply SCIM patch operations to a resource based on schema. Returns (updated_resource, changes) where `updated_resource` is the new SCIM resource dict and `changes` maps attribute or path to (old_value, new_value). Additions have old_value=None if attribute didn't exist; removals have new_value=None. For add/remove on list-valued attributes, changes record the full list before/after. """ # Deep copy to avoid mutating original updated = deepcopy(resource) changes = {} # Allowed attributes from schema allowed_attrs = {attr["name"]: attr for attr in schema.get("attributes", [])} for op in operations: op_type = op.get("op", "").strip().lower() path = op.get("path") value = op.get("value") if not path: # Top-level merge if op_type in ("add", "replace"): if not isinstance(value, dict): raise ValueError( "When path is not provided, value must be a dict of attributes to merge." ) for attr, val in value.items(): if attr not in allowed_attrs: raise ValueError( f"Attribute '{attr}' not defined in SCIM schema" ) old = updated.get(attr) updated[attr] = val if val is not None else updated.pop(attr, None) changes[attr] = (old, val) else: raise ValueError(f"Unsupported operation without path: {op_type}") continue tokens = parse_scim_path(path) # Detect simple top-level list add/remove if ( op_type in ("add", "remove") and len(tokens) == 1 and isinstance(tokens[0], str) ): attr = tokens[0] if attr not in allowed_attrs: raise ValueError(f"Attribute '{attr}' not defined in SCIM schema") current_list = updated.get(attr, []) if isinstance(current_list, list): before = deepcopy(current_list) if op_type == "add": # Ensure list exists updated.setdefault(attr, []) # Append new items items = value if isinstance(value, list) else [value] updated[attr].extend(items) else: # remove # Remove items matching filter if value not provided # For remove on list without filter, remove all values equal to value if value is None: updated.pop(attr, None) else: # filter value items out items = value if isinstance(value, list) else [value] updated[attr] = [ e for e in updated.get(attr, []) if e not in items ] after = deepcopy(updated.get(attr, [])) changes[attr] = (before, after) continue # For other operations, get old value and apply normally old_val = get_by_path(updated, tokens) if op_type == "add": set_by_path(updated, tokens, value) elif op_type == "replace": if value is None: remove_by_path(updated, tokens) else: set_by_path(updated, tokens, value) elif op_type == "remove": remove_by_path(updated, tokens) else: raise ValueError(f"Unsupported operation type: {op_type}") # Record change for non-list or nested paths new_val = None if op_type == "remove" else get_by_path(updated, tokens) changes[path] = (old_val, new_val) return updated, changes def parse_scim_path(path): """ Parse a SCIM-style path (e.g., 'emails[type eq "work"].value') into a list of tokens. Each token is either a string attribute name or a tuple (attr, filter_attr, filter_value) for list-filtering. """ tokens = [] # Regex matches segments like attr or attr[filter] where filter is e.g. type eq "work" segment_re = re.compile(r"([^\.\[]+)(?:\[(.*?)\])?") for match in segment_re.finditer(path): attr = match.group(1) filt = match.group(2) if filt: # Support simple equality filter of form: subAttr eq "value" m = re.match(r"\s*(\w+)\s+eq\s+\"([^\"]+)\"", filt) if not m: raise ValueError(f"Unsupported filter expression: {filt}") filter_attr, filter_val = m.group(1), m.group(2) tokens.append((attr, filter_attr, filter_val)) else: tokens.append(attr) return tokens def get_by_path(doc, tokens): """ Retrieve a value from nested dicts/lists using parsed tokens. Returns None if any step is missing. """ cur = doc for token in tokens: if cur is None: return None if isinstance(token, tuple): attr, fattr, fval = token lst = cur.get(attr) if not isinstance(lst, list): return None # Find first dict element matching filter for elem in lst: if isinstance(elem, dict) and elem.get(fattr) == fval: cur = elem break else: return None else: if isinstance(cur, dict): cur = cur.get(token) elif isinstance(cur, list) and isinstance(token, int): if 0 <= token < len(cur): cur = cur[token] else: return None else: return None return cur def set_by_path(doc, tokens, value): """ Set a value in nested dicts/lists using parsed tokens. Creates intermediate dicts/lists as needed. """ cur = doc for i, token in enumerate(tokens): last = i == len(tokens) - 1 if isinstance(token, tuple): attr, fattr, fval = token lst = cur.setdefault(attr, []) if not isinstance(lst, list): raise ValueError(f"Expected list at attribute '{attr}'") # Find existing entry idx = next( ( j for j, e in enumerate(lst) if isinstance(e, dict) and e.get(fattr) == fval ), None, ) if idx is None: if last: lst.append(value) return else: new = {} lst.append(new) cur = new else: if last: lst[idx] = value return cur = lst[idx] else: if last: if value is None: if isinstance(cur, dict): cur.pop(token, None) else: cur[token] = value else: cur = cur.setdefault(token, {}) def remove_by_path(doc, tokens): """ Remove a value in nested dicts/lists using parsed tokens. Does nothing if path not present. """ cur = doc for i, token in enumerate(tokens): last = i == len(tokens) - 1 if isinstance(token, tuple): attr, fattr, fval = token lst = cur.get(attr) if not isinstance(lst, list): return for j, elem in enumerate(lst): if isinstance(elem, dict) and elem.get(fattr) == fval: if last: lst.pop(j) return cur = elem break else: return else: if last: if isinstance(cur, dict): cur.pop(token, None) elif isinstance(cur, list) and isinstance(token, int): if 0 <= token < len(cur): cur.pop(token) return else: if isinstance(cur, dict): cur = cur.get(token) elif isinstance(cur, list) and isinstance(token, int): cur = cur[token] if 0 <= token < len(cur) else None else: return