495 lines
16 KiB
Python
495 lines
16 KiB
Python
from typing import Any, Literal
|
|
from copy import deepcopy
|
|
import re
|
|
from chalicelib.utils import pg_client
|
|
|
|
|
|
def safe_mogrify_array(
    items: list[Any] | None,
    array_type: Literal["varchar", "int"],
    cursor: pg_client.PostgresClient,
) -> str:
    """Render *items* as a PostgreSQL ``ARRAY[...]::<array_type>[]`` expression.

    Every element is escaped on its own via ``cursor.mogrify("%s", (element,))``
    and the resulting bytes are decoded as UTF-8. ``None`` or an empty list
    produces ``ARRAY[]::<array_type>[]``.

    NOTE(review): ``array_type`` is interpolated into the SQL text unescaped;
    callers must only pass trusted literals (the annotation allows
    ``"varchar"`` / ``"int"``, but this is not enforced at runtime).
    """
    escaped_elements = []
    for element in items or []:
        escaped_elements.append(cursor.mogrify("%s", (element,)).decode("utf-8"))
    joined = ", ".join(escaped_elements)
    return f"ARRAY[{joined}]::{array_type}[]"
|
|
|
|
|
|
def convert_query_str_to_list(query_str: str | None) -> list[str]:
|
|
if query_str is None:
|
|
return None
|
|
return query_str.split(",")
|
|
|
|
|
|
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
    """Return the dotted path of every attribute declared in *schema*.

    The walk is depth-first and pre-order: a complex attribute's own path is
    listed before the paths of its children, which are looked up under
    ``subAttributes`` (falling back to ``attributes``). Example:
    ``name`` followed by ``name.givenName``.
    """

    def _paths(definitions, parent):
        for definition in definitions:
            leaf = definition["name"]
            full_path = leaf if not parent else f"{parent}.{leaf}"
            yield full_path
            if definition["type"] != "complex":
                continue
            children = (
                definition.get("subAttributes") or definition.get("attributes") or []
            )
            yield from _paths(children, full_path)

    return list(_paths(schema["attributes"], None))
|
|
|
|
|
|
def get_all_attribute_names_where_returned_is_always(
    schema: dict[str, Any],
) -> list[str]:
    """Return dotted paths of attributes whose ``returned`` characteristic is ``"always"``.

    Walks ``schema["attributes"]`` depth-first, descending into complex
    attributes via ``subAttributes`` (falling back to ``attributes``).
    Characteristics omitted from an attribute definition fall back to the
    RFC 7643 §2.2 defaults (``returned: "default"``, ``type: "string"``)
    instead of raising ``KeyError``.
    """
    result: list[str] = []

    def _walk(attrs, prefix=None):
        for attr in attrs:
            name = attr["name"]
            path = f"{prefix}.{name}" if prefix else name
            if attr.get("returned", "default") == "always":
                result.append(path)
            # Children are inspected even when the parent is not "always", so an
            # always-returned sub-attribute of an ordinary parent is still found.
            if attr.get("type", "string") == "complex":
                sub = attr.get("subAttributes") or attr.get("attributes") or []
                _walk(sub, path)

    _walk(schema["attributes"])
    return result
|
|
|
|
|
|
def filter_attributes(
    obj: dict[str, Any],
    attributes_query_str: str | None,
    excluded_attributes_query_str: str | None,
    schema: dict[str, Any],
) -> dict[str, Any]:
    """
    Project a SCIM resource onto the `attributes` / `excludedAttributes` query parameters.

    Selection rules:
    - when `attributes_query_str` is absent or empty, every attribute path
      declared in `schema` is a candidate;
    - attributes whose schema `returned` is "always" are always kept and
      cannot be excluded;
    - selecting a complex attribute by its own name (e.g. `name`, `emails`)
      returns it whole, with all of its sub-attributes;
    - excluding an attribute also excludes every sub-attribute path below it
      (excluding `name` drops `name.givenName`, ...), unless that sub-attribute
      is always returned.

    Keys of `obj` not reached by any selected path are dropped; list values are
    filtered element-wise using the same sub-attribute selection.
    """
    all_attributes = get_all_attribute_names(schema)
    always_returned = set(get_all_attribute_names_where_returned_is_always(schema))

    requested = convert_query_str_to_list(attributes_query_str) or all_attributes
    included = set(requested) | always_returned

    excluded = set(convert_query_str_to_list(excluded_attributes_query_str) or [])
    excluded -= always_returned

    def _is_excluded(path: str) -> bool:
        # Dropped when the path itself or any ancestor path was excluded,
        # except for paths that must always be returned.
        if path in always_returned:
            return False
        parts = path.split(".")
        return any(
            ".".join(parts[:depth]) in excluded for depth in range(1, len(parts) + 1)
        )

    # Nested dict of selected path segments; an empty subtree marks an
    # attribute that was selected as a whole.
    include_tree: dict[str, dict] = {}
    for path in included:
        if _is_excluded(path):
            continue
        node = include_tree
        for part in path.split("."):
            node = node.setdefault(part, {})

    def _project(value: Any, tree: dict[str, dict]) -> Any:
        if isinstance(value, dict):
            out = {}
            for key, subtree in tree.items():
                if key in value:
                    # Empty subtree: keep the whole attribute instead of
                    # filtering it down to an empty object.
                    out[key] = _project(value[key], subtree) if subtree else value[key]
            return out
        if isinstance(value, list):
            return [_project(item, tree) for item in value]
        return value

    return _project(obj, include_tree)
|
|
|
|
|
|
def filter_mutable_attributes(
    schema: dict[str, Any],
    requested_changes: dict[str, Any],
    current_values: dict[str, Any],
) -> dict[str, Any]:
    """Keep only the requested changes the schema's mutability rules permit.

    For each requested attribute, its definition is looked up by name in
    ``schema["attributes"]`` and its ``mutability`` is applied (missing
    mutability is treated as ``"readWrite"``):

    - unknown attributes are ignored (per RFC 7644);
    - ``readWrite`` / ``writeOnly`` changes are kept;
    - ``readOnly`` and any other unrecognised mutability values are dropped;
    - ``immutable`` changes are dropped when equal to the current value and
      raise ``ValueError`` otherwise.

    Returns a new dict of accepted ``{attribute_name: new_value}`` pairs.
    """
    definitions_by_name = {}
    for definition in schema.get("attributes", []):
        definitions_by_name[definition.get("name")] = definition

    accepted = {}
    for name, proposed in requested_changes.items():
        definition = definitions_by_name.get(name)
        if not definition:
            continue  # unknown attribute: ignore per RFC 7644

        mode = definition.get("mutability", "readWrite")
        if mode in ("readWrite", "writeOnly"):
            accepted[name] = proposed
            continue
        if mode != "immutable":
            continue  # readOnly (or unrecognised): silently ignored

        # Immutable: only an identical value is tolerated, and it needs no write.
        existing = current_values.get(name)
        if proposed != existing:
            raise ValueError(
                f"Attribute '{name}' is immutable (cannot change). "
                f"Current value: {existing!r}, attempted change: {proposed!r}"
            )

    return accepted
|
|
|
|
|
|
def apply_scim_patch(
    operations: list[dict[str, Any]], resource: dict[str, Any], schema: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Apply SCIM PATCH operations to a resource based on schema.

    Returns (updated_resource, changes) where `updated_resource` is a new SCIM
    resource dict (`resource` is deep-copied and never mutated) and `changes`
    maps attribute or path to (old_value, new_value).
    Additions have old_value=None if attribute didn't exist; removals have new_value=None.
    For add/remove on list-valued attributes, changes record the full list before/after.

    Behaviour details:
    - `op` is matched case-insensitively after stripping whitespace.
    - Without `path`, only add/replace are accepted and `value` must be a dict
      merged into the top level; a None value in that dict removes the attribute.
    - A single-attribute add/remove uses list semantics (append / filter out
      values) only when the attribute currently holds a list, or is absent and
      declared `multiValued` in the schema; otherwise it is a plain set/remove.

    Raises:
        ValueError: unknown top-level attribute (for path-less ops and
            single-attribute add/remove), unsupported operation, a non-dict
            value for a path-less op, or errors raised by the path helpers.
    """
    # Deep copy to avoid mutating original
    updated = deepcopy(resource)
    changes: dict[str, Any] = {}

    # Allowed attributes from schema
    allowed_attrs = {attr["name"]: attr for attr in schema.get("attributes", [])}

    for op in operations:
        op_type = op.get("op", "").strip().lower()
        path = op.get("path")
        value = op.get("value")

        if not path:
            # Top-level merge
            if op_type not in ("add", "replace"):
                raise ValueError(f"Unsupported operation without path: {op_type}")
            if not isinstance(value, dict):
                raise ValueError(
                    "When path is not provided, value must be a dict of attributes to merge."
                )
            for attr, val in value.items():
                if attr not in allowed_attrs:
                    raise ValueError(f"Attribute '{attr}' not defined in SCIM schema")
                old = updated.get(attr)
                if val is None:
                    # A null value clears the attribute.
                    updated.pop(attr, None)
                else:
                    updated[attr] = val
                changes[attr] = (old, val)
            continue

        tokens = parse_scim_path(path)

        # Detect simple top-level add/remove on a single attribute
        if op_type in ("add", "remove") and len(tokens) == 1 and isinstance(tokens[0], str):
            attr = tokens[0]
            if attr not in allowed_attrs:
                raise ValueError(f"Attribute '{attr}' not defined in SCIM schema")

            is_list_attr = isinstance(updated.get(attr), list) or (
                attr not in updated and bool(allowed_attrs[attr].get("multiValued", False))
            )
            if is_list_attr:
                before = deepcopy(updated.get(attr, []))
                if op_type == "add":
                    # Append new items, creating the list if needed
                    items = value if isinstance(value, list) else [value]
                    updated.setdefault(attr, []).extend(items)
                elif value is None:
                    # Remove without value: drop the whole attribute
                    updated.pop(attr, None)
                else:
                    # Remove every element equal to one of the given values
                    items = value if isinstance(value, list) else [value]
                    updated[attr] = [e for e in updated.get(attr, []) if e not in items]
                changes[attr] = (before, deepcopy(updated.get(attr, [])))
                continue
            # Singular attribute: fall through to the generic set/remove below.

        # Copy so later in-place edits cannot alter the recorded old value.
        old_val = deepcopy(get_by_path(updated, tokens))

        if op_type == "add":
            set_by_path(updated, tokens, value)
        elif op_type == "replace":
            if value is None:
                remove_by_path(updated, tokens)
            else:
                set_by_path(updated, tokens, value)
        elif op_type == "remove":
            remove_by_path(updated, tokens)
        else:
            raise ValueError(f"Unsupported operation type: {op_type}")

        # Record change for non-list or nested paths
        new_val = None if op_type == "remove" else get_by_path(updated, tokens)
        changes[path] = (old_val, new_val)

    return updated, changes
|
|
|
|
|
|
# One path segment: an attribute name, optionally followed by a bracketed
# value filter, e.g. `emails` or `emails[type eq "work"]`.
_SCIM_PATH_SEGMENT_RE = re.compile(r"([^\.\[]+)(?:\[(.*?)\])?")
# The only supported value filter: `<subAttr> eq "<value>"`. SCIM operators
# are case-insensitive, hence IGNORECASE.
_SCIM_PATH_EQ_FILTER_RE = re.compile(r"\s*(\w+)\s+eq\s+\"([^\"]*)\"\s*", re.IGNORECASE)


def parse_scim_path(path):
    """
    Parse a SCIM-style path (e.g., 'emails[type eq "work"].value') into a list
    of tokens. Each token is either a string attribute name or a tuple
    (attr, filter_attr, filter_value) for list-filtering.

    Only a single `eq` comparison against a double-quoted string is supported
    inside brackets. Anything else, including compound filters such as
    `type eq "work" and primary eq true`, raises ValueError instead of being
    silently truncated to its first clause.
    """
    tokens = []
    for match in _SCIM_PATH_SEGMENT_RE.finditer(path):
        attr, filt = match.group(1), match.group(2)
        if not filt:
            tokens.append(attr)
            continue
        m = _SCIM_PATH_EQ_FILTER_RE.fullmatch(filt)
        if not m:
            raise ValueError(f"Unsupported filter expression: {filt}")
        tokens.append((attr, m.group(1), m.group(2)))
    return tokens
|
|
|
|
|
|
def get_by_path(doc, tokens):
    """
    Retrieve a value from nested dicts/lists using parsed tokens.

    String tokens read dict keys; `(attr, filter_attr, filter_value)` tuples
    select the first dict element of the list at `attr` whose `filter_attr`
    equals `filter_value`; integer tokens index into lists.

    Returns None if any step is missing or the document's shape does not
    match the path (e.g. a filter applied to a non-object value).
    """
    cur = doc
    for token in tokens:
        if isinstance(token, tuple):
            attr, fattr, fval = token
            if not isinstance(cur, dict):
                return None
            lst = cur.get(attr)
            if not isinstance(lst, list):
                return None
            # Matching elements are dicts, so None unambiguously means "no match".
            cur = next(
                (e for e in lst if isinstance(e, dict) and e.get(fattr) == fval),
                None,
            )
            if cur is None:
                return None
        elif isinstance(cur, dict):
            cur = cur.get(token)
        elif isinstance(cur, list) and isinstance(token, int) and 0 <= token < len(cur):
            cur = cur[token]
        else:
            return None
    return cur
|
|
|
|
|
|
def set_by_path(doc, tokens, value):
    """
    Set a value in nested dicts/lists using parsed tokens.

    Creates intermediate dicts/lists as needed. For a filter token
    `(attr, filter_attr, filter_value)` the first matching dict element of the
    list at `attr` is used; when none matches, a new element is appended and
    seeded with `{filter_attr: filter_value}`, so the same path resolves to
    it afterwards. A None value on a plain final attribute removes that
    attribute instead of storing None.

    Raises:
        ValueError: if the path traverses a value that is not an object, or a
            filter token targets an existing attribute that is not a list.
    """
    cur = doc
    for i, token in enumerate(tokens):
        last = i == len(tokens) - 1

        if isinstance(token, tuple):
            attr, fattr, fval = token
            if not isinstance(cur, dict):
                raise ValueError(f"Expected object when resolving attribute '{attr}'")
            lst = cur.setdefault(attr, [])
            if not isinstance(lst, list):
                raise ValueError(f"Expected list at attribute '{attr}'")
            # Find existing entry
            idx = next(
                (j for j, e in enumerate(lst) if isinstance(e, dict) and e.get(fattr) == fval),
                None,
            )
            if idx is None:
                if last:
                    # Seed a dict value with the filter attribute so it matches the path.
                    lst.append({fattr: fval, **value} if isinstance(value, dict) else value)
                    return
                new = {fattr: fval}
                lst.append(new)
                cur = new
            elif last:
                lst[idx] = value
                return
            else:
                cur = lst[idx]

        elif last:
            if value is None:
                if isinstance(cur, dict):
                    cur.pop(token, None)
            elif isinstance(cur, dict) or (isinstance(cur, list) and isinstance(token, int)):
                cur[token] = value
            else:
                raise ValueError(f"Cannot set attribute '{token}' on a non-object value")

        else:
            if not isinstance(cur, dict):
                raise ValueError(f"Cannot traverse attribute '{token}' on a non-object value")
            cur = cur.setdefault(token, {})
|
|
|
|
|
|
def remove_by_path(doc, tokens):
    """
    Remove a value in nested dicts/lists using parsed tokens.

    Does nothing if the path is not present or the document's shape does not
    match it. When the final token is a filter `(attr, filter_attr, filter_value)`,
    every matching element of the list is removed (RFC 7644 §3.5.2.2), not
    only the first; a filter in an intermediate position descends into the
    first matching element.
    """
    cur = doc
    for i, token in enumerate(tokens):
        last = i == len(tokens) - 1

        if isinstance(token, tuple):
            attr, fattr, fval = token
            if not isinstance(cur, dict):
                return
            lst = cur.get(attr)
            if not isinstance(lst, list):
                return
            if last:
                # Slice-assign so the list object held by the parent is edited in place.
                lst[:] = [
                    e for e in lst if not (isinstance(e, dict) and e.get(fattr) == fval)
                ]
                return
            cur = next(
                (e for e in lst if isinstance(e, dict) and e.get(fattr) == fval),
                None,
            )
            if cur is None:
                return

        elif last:
            if isinstance(cur, dict):
                cur.pop(token, None)
            elif isinstance(cur, list) and isinstance(token, int) and 0 <= token < len(cur):
                cur.pop(token)
            return

        elif isinstance(cur, dict):
            cur = cur.get(token)
        elif isinstance(cur, list) and isinstance(token, int):
            cur = cur[token] if 0 <= token < len(cur) else None
        else:
            return
|
|
|
|
|
|
class SCIMFilterParser:
|
|
_TOK_RE = re.compile(
|
|
r"""
|
|
(?:"[^"]*"|'[^']*')| # double- or single-quoted string
|
|
\band\b|\bor\b|\bnot\b|
|
|
\beq\b|\bne\b|\bco\b|\bsw\b|\bew\b|\bgt\b|\blt\b|\bge\b|\ble\b|\bpr\b|
|
|
[()]| # parentheses
|
|
[^\s()]+ # bare token
|
|
""",
|
|
re.IGNORECASE | re.VERBOSE,
|
|
)
|
|
_NUMERIC_RE = re.compile(r"^-?\d+(\.\d+)?$")
|
|
|
|
def __init__(self, text: str, attr_map: dict[str, str]):
|
|
self.tokens = [tok for tok in self._TOK_RE.findall(text)]
|
|
self.pos = 0
|
|
self.attr_map = attr_map
|
|
|
|
def peek(self) -> str | None:
|
|
return self.tokens[self.pos].lower() if self.pos < len(self.tokens) else None
|
|
|
|
def next(self) -> str:
|
|
tok = self.tokens[self.pos]
|
|
self.pos += 1
|
|
return tok
|
|
|
|
def parse(self) -> str:
|
|
expr = self._parse_or()
|
|
if self.pos != len(self.tokens):
|
|
raise ValueError(f"Unexpected token at end: {self.peek()}")
|
|
return expr
|
|
|
|
def _parse_or(self) -> str:
|
|
left = self._parse_and()
|
|
while self.peek() == "or":
|
|
self.next()
|
|
right = self._parse_and()
|
|
left = f"({left} OR {right})"
|
|
return left
|
|
|
|
def _parse_and(self) -> str:
|
|
left = self._parse_not()
|
|
while self.peek() == "and":
|
|
self.next()
|
|
right = self._parse_not()
|
|
left = f"({left} AND {right})"
|
|
return left
|
|
|
|
def _parse_not(self) -> str:
|
|
if self.peek() == "not":
|
|
self.next()
|
|
inner = self._parse_simple()
|
|
return f"(NOT {inner})"
|
|
return self._parse_simple()
|
|
|
|
def _parse_simple(self) -> str:
|
|
if self.peek() == "(":
|
|
self.next()
|
|
expr = self._parse_or()
|
|
if self.next() != ")":
|
|
raise ValueError("Missing closing parenthesis")
|
|
return f"({expr})"
|
|
return self._parse_comparison()
|
|
|
|
def _parse_comparison(self) -> str:
|
|
raw_attr = self.next()
|
|
col = self.attr_map.get(raw_attr, raw_attr)
|
|
op = self.next().lower()
|
|
|
|
if op == "pr":
|
|
return f"{col} IS NOT NULL"
|
|
|
|
val = self.next()
|
|
|
|
# strip quotes if present (single or double)
|
|
if (val.startswith('"') and val.endswith('"')) or (
|
|
val.startswith("'") and val.endswith("'")
|
|
):
|
|
inner = val[1:-1].replace("'", "''")
|
|
sql_val = f"'{inner}'"
|
|
elif self._NUMERIC_RE.match(val):
|
|
sql_val = val
|
|
else:
|
|
inner = val.replace("'", "''")
|
|
sql_val = f"'{inner}'"
|
|
|
|
if op == "eq":
|
|
return f"{col} = {sql_val}"
|
|
if op == "ne":
|
|
return f"{col} <> {sql_val}"
|
|
if op == "co":
|
|
return f"{col} LIKE '%' || {sql_val} || '%'"
|
|
if op == "sw":
|
|
return f"{col} LIKE {sql_val} || '%'"
|
|
if op == "ew":
|
|
return f"{col} LIKE '%' || {sql_val}"
|
|
if op in ("gt", "lt", "ge", "le"):
|
|
sql_ops = {"gt": ">", "lt": "<", "ge": ">=", "le": "<="}
|
|
return f"{col} {sql_ops[op]} {sql_val}"
|
|
|
|
raise ValueError(f"Unknown operator: {op}")
|
|
|
|
|
|
def scim_to_sql_where(filter_str: str | None, attr_map: dict[str, str]) -> str | None:
|
|
"""
|
|
Convert a SCIM filter into an SQL WHERE fragment,
|
|
mapping SCIM attributes per attr_map and correctly quoting
|
|
both single- and double-quoted strings.
|
|
"""
|
|
if filter_str is None:
|
|
return None
|
|
parser = SCIMFilterParser(filter_str, attr_map)
|
|
return parser.parse()
|