diff --git a/ee/api/routers/scim/api.py b/ee/api/routers/scim/api.py
index affe33946..2508fc7bf 100644
--- a/ee/api/routers/scim/api.py
+++ b/ee/api/routers/scim/api.py
@@ -234,6 +234,7 @@ user_config = ResourceConfig(
     rewrite_provider_resource=users.rewrite_provider_resource,
     convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input,
     update_provider_resource=users.update_provider_resource,
+    filter_attribute_mapping=users.filter_attribute_mapping,
 )
 group_config = ResourceConfig(
     schema_id="urn:ietf:params:scim:schemas:core:2.0:Group",
@@ -251,6 +252,7 @@ group_config = ResourceConfig(
     rewrite_provider_resource=groups.rewrite_provider_resource,
     convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input,
     update_provider_resource=groups.update_provider_resource,
+    filter_attribute_mapping=groups.filter_attribute_mapping,
 )
 
 RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = {
@@ -272,15 +274,19 @@ async def get_resources(
     requested_items_per_page: int | None = Query(None, alias="count"),
     attributes: str | None = Query(None),
     excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
+    filter: str | None = Query(None),
 ):
     config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
-    total_resources = config.get_active_resource_count(tenant_id)
+    filter_clause = helpers.scim_to_sql_where(filter, config.filter_attribute_mapping())
+    total_resources = config.get_active_resource_count(tenant_id, filter_clause)
     start_index_one_indexed = max(1, requested_start_index_one_indexed)
     offset = start_index_one_indexed - 1
     limit = min(
         max(0, requested_items_per_page or config.max_chunk_size), config.max_chunk_size
     )
-    provider_resources = config.get_provider_resource_chunk(offset, tenant_id, limit)
+    provider_resources = config.get_provider_resource_chunk(
+        offset, tenant_id, limit, filter_clause
+    )
     client_resources = [
         api_helper.convert_provider_resource_to_client_resource(
             config, provider_resource, attributes, excluded_attributes
diff --git a/ee/api/routers/scim/fixtures/service_provider_config.json b/ee/api/routers/scim/fixtures/service_provider_config.json
index dbcbff942..38a5079ae 100644
--- a/ee/api/routers/scim/fixtures/service_provider_config.json
+++ b/ee/api/routers/scim/fixtures/service_provider_config.json
@@ -11,8 +11,8 @@
         "maxPayloadSize": 0
     },
     "filter": {
-        "supported": false,
-        "maxResults": 0
+        "supported": true,
+        "maxResults": 10
     },
     "changePassword": {
         "supported": false
diff --git a/ee/api/routers/scim/groups.py b/ee/api/routers/scim/groups.py
index cc113eb16..09d36b231 100644
--- a/ee/api/routers/scim/groups.py
+++ b/ee/api/routers/scim/groups.py
@@ -56,30 +56,36 @@ def convert_provider_resource_to_client_resource(
     }
 
 
-def get_active_resource_count(tenant_id: int) -> int:
+def get_active_resource_count(tenant_id: int, filter_clause: str | None = None) -> int:
+    where_and_clauses = [
+        f"roles.tenant_id = {tenant_id}",
+        "roles.deleted_at IS NULL",
+    ]
+    if filter_clause is not None:
+        where_and_clauses.append(filter_clause)
+    where_clause = " AND ".join(where_and_clauses)
     with pg_client.PostgresClient() as cur:
         cur.execute(
-            cur.mogrify(
-                """
-                SELECT COUNT(*)
-                FROM public.roles
-                WHERE
-                    roles.tenant_id = %(tenant_id)s
-                    AND roles.deleted_at IS NULL
-                """,
-                {"tenant_id": tenant_id},
-            )
+            f"""
+            SELECT COUNT(*)
+            FROM public.roles
+            WHERE {where_clause}
+            """
         )
         return cur.fetchone()["count"]
 
 
-def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str:
+def _main_select_query(
+    tenant_id: int, resource_id: int | None = None, filter_clause: str | None = None
+) -> str:
     where_and_clauses = [
         f"roles.tenant_id = {tenant_id}",
         "roles.deleted_at IS NULL",
     ]
     if resource_id is not None:
         where_and_clauses.append(f"roles.role_id = {resource_id}")
+    if filter_clause is not None:
+        where_and_clauses.append(filter_clause)
     where_clause = " AND ".join(where_and_clauses)
     return f"""
         SELECT
@@ -107,14 +113,18 @@ def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str:
 
 
 def get_provider_resource_chunk(
-    offset: int, tenant_id: int, limit: int
+    offset: int, tenant_id: int, limit: int, filter_clause: str | None = None
 ) -> list[ProviderResource]:
-    query = _main_select_query(tenant_id)
+    query = _main_select_query(tenant_id, filter_clause=filter_clause)
     with pg_client.PostgresClient() as cur:
         cur.execute(f"{query} LIMIT {limit} OFFSET {offset}")
         return cur.fetchall()
 
 
+def filter_attribute_mapping() -> dict[str, str]:
+    return {"displayName": "roles.name"}
+
+
 def get_provider_resource(
     resource_id: ResourceId, tenant_id: int
 ) -> ProviderResource | None:
diff --git a/ee/api/routers/scim/helpers.py b/ee/api/routers/scim/helpers.py
index a94806d14..ebf5f1b67 100644
--- a/ee/api/routers/scim/helpers.py
+++ b/ee/api/routers/scim/helpers.py
@@ -366,3 +366,139 @@ def remove_by_path(doc, tokens):
             cur = cur[token] if 0 <= token < len(cur) else None
         else:
             return
+
+
+class SCIMFilterParser:
+    """Translate a SCIM filter expression into a SQL WHERE fragment.
+
+    Only attributes present in ``attr_map`` can be filtered on; any other
+    attribute name raises ``ValueError`` so client-supplied text is never
+    interpolated into the query as an identifier. Malformed filters also
+    raise ``ValueError``.
+    """
+
+    _TOK_RE = re.compile(
+        r"""
+        (?:"[^"]*"|'[^']*')|      # double- or single-quoted string
+        \band\b|\bor\b|\bnot\b|
+        \beq\b|\bne\b|\bco\b|\bsw\b|\bew\b|\bgt\b|\blt\b|\bge\b|\ble\b|\bpr\b|
+        [()]|                     # parentheses
+        [^\s()]+                  # bare token
+        """,
+        re.IGNORECASE | re.VERBOSE,
+    )
+    _NUMERIC_RE = re.compile(r"^-?\d+(\.\d+)?$")
+    _COMPARISON_OPS = {
+        "eq": "=",
+        "ne": "<>",
+        "gt": ">",
+        "lt": "<",
+        "ge": ">=",
+        "le": "<=",
+    }
+    _LIKE_PATTERNS = {"co": "%{}%", "sw": "{}%", "ew": "%{}"}
+
+    def __init__(self, text: str, attr_map: dict[str, str]):
+        self.tokens = self._TOK_RE.findall(text)
+        self.pos = 0
+        self.attr_map = attr_map
+        # SCIM attribute names are case-insensitive, so look columns up by
+        # their lower-cased name.
+        self._columns = {name.lower(): col for name, col in attr_map.items()}
+
+    def peek(self) -> str | None:
+        return self.tokens[self.pos].lower() if self.pos < len(self.tokens) else None
+
+    def next(self) -> str:
+        # Fail as a parse error, not an IndexError, when the filter ends early.
+        if self.pos >= len(self.tokens):
+            raise ValueError("Unexpected end of filter")
+        tok = self.tokens[self.pos]
+        self.pos += 1
+        return tok
+
+    def parse(self) -> str:
+        expr = self._parse_or()
+        if self.pos != len(self.tokens):
+            raise ValueError(f"Unexpected token at end: {self.peek()}")
+        return expr
+
+    def _parse_or(self) -> str:
+        left = self._parse_and()
+        while self.peek() == "or":
+            self.next()
+            right = self._parse_and()
+            left = f"({left} OR {right})"
+        return left
+
+    def _parse_and(self) -> str:
+        left = self._parse_not()
+        while self.peek() == "and":
+            self.next()
+            right = self._parse_not()
+            left = f"({left} AND {right})"
+        return left
+
+    def _parse_not(self) -> str:
+        if self.peek() == "not":
+            self.next()
+            inner = self._parse_simple()
+            return f"(NOT {inner})"
+        return self._parse_simple()
+
+    def _parse_simple(self) -> str:
+        if self.peek() == "(":
+            self.next()
+            expr = self._parse_or()
+            if self.next() != ")":
+                raise ValueError("Missing closing parenthesis")
+            return f"({expr})"
+        return self._parse_comparison()
+
+    def _parse_comparison(self) -> str:
+        raw_attr = self.next()
+        # Never fall back to the raw token: it is client-controlled and would
+        # be spliced into the SQL as an identifier.
+        col = self._columns.get(raw_attr.lower())
+        if col is None:
+            raise ValueError(f"Unsupported filter attribute: {raw_attr}")
+        op = self.next().lower()
+
+        if op == "pr":
+            return f"{col} IS NOT NULL"
+        if op not in self._COMPARISON_OPS and op not in self._LIKE_PATTERNS:
+            raise ValueError(f"Unknown operator: {op}")
+
+        val = self.next()
+        # A quoted token is a string literal; drop its delimiters.
+        quoted = len(val) >= 2 and val[0] == val[-1] and val[0] in ("'", '"')
+        text = val[1:-1] if quoted else val
+
+        if op in self._LIKE_PATTERNS:
+            # Escape LIKE metacharacters so the value is matched literally
+            # (backslash is PostgreSQL's default LIKE escape character).
+            for char in ("\\", "%", "_"):
+                text = text.replace(char, "\\" + char)
+            pattern = self._LIKE_PATTERNS[op].format(text).replace("'", "''")
+            return f"{col} LIKE '{pattern}'"
+
+        if not quoted and self._NUMERIC_RE.match(val):
+            sql_val = val
+        else:
+            sql_val = "'" + text.replace("'", "''") + "'"
+        return f"{col} {self._COMPARISON_OPS[op]} {sql_val}"
+
+
+def scim_to_sql_where(filter_str: str | None, attr_map: dict[str, str]) -> str | None:
+    """
+    Convert a SCIM filter into an SQL WHERE fragment,
+    mapping SCIM attributes per attr_map and correctly quoting
+    both single- and double-quoted strings.
+
+    Returns None when no filter was supplied; raises ValueError for a
+    malformed filter or an attribute that is not in attr_map.
+    """
+    if filter_str is None:
+        return None
+    parser = SCIMFilterParser(filter_str, attr_map)
+    return parser.parse()
diff --git a/ee/api/routers/scim/resource_config.py b/ee/api/routers/scim/resource_config.py
index afae5eed6..f89877f65 100644
--- a/ee/api/routers/scim/resource_config.py
+++ b/ee/api/routers/scim/resource_config.py
@@ -40,6 +40,7 @@ class ResourceConfig:
         [int, ClientInput], ProviderInput
     ]
     update_provider_resource: Callable[..., ProviderResource]
+    filter_attribute_mapping: Callable[[], dict[str, str]]
 
 
 def get_schema(config: ResourceConfig) -> Schema:
diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py
index 1a7fd1d17..e1d67b58e 100644
--- a/ee/api/routers/scim/users.py
+++ b/ee/api/routers/scim/users.py
@@ -85,6 +85,10 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input(
     return result
 
 
+def filter_attribute_mapping() -> dict[str, str]:
+    return {"userName": "users.email"}
+
+
 def get_provider_resource_from_unique_fields(
     email: str, **kwargs: dict[str, Any]
 ) -> ProviderResource | None:
@@ -153,40 +157,44 @@ def convert_provider_resource_to_client_resource(
     }
 
 
-def get_active_resource_count(tenant_id: int) -> int:
+def get_active_resource_count(tenant_id: int, filter_clause: str | None = None) -> int:
+    where_and_statements = [
+        f"users.tenant_id = {tenant_id}",
+        "users.deleted_at IS NULL",
+    ]
+    if filter_clause is not None:
+        where_and_statements.append(filter_clause)
+    where_clause = " AND ".join(where_and_statements)
     with pg_client.PostgresClient() as cur:
         cur.execute(
-            cur.mogrify(
-                """
-                SELECT COUNT(*)
-                FROM public.users
-                WHERE
-                    users.tenant_id = %(tenant_id)s
-                    AND users.deleted_at IS NULL
-                """,
-                {"tenant_id": tenant_id},
-            )
+            f"""
+            SELECT COUNT(*)
+            FROM public.users
+            WHERE {where_clause}
+            """
         )
         return cur.fetchone()["count"]
 
 
 def get_provider_resource_chunk(
-    offset: int, tenant_id: int, limit: int
+    offset: int, tenant_id: int, limit: int, filter_clause: str | None = None
 ) -> list[ProviderResource]:
+    where_and_statements = [
+        f"users.tenant_id = {tenant_id}",
+        "users.deleted_at IS NULL",
+    ]
+    if filter_clause is not None:
+        where_and_statements.append(filter_clause)
+    where_clause = " AND ".join(where_and_statements)
     with pg_client.PostgresClient() as cur:
         cur.execute(
-            cur.mogrify(
-                """
-                SELECT *
-                FROM public.users
-                WHERE
-                    users.tenant_id = %(tenant_id)s
-                    AND users.deleted_at IS NULL
-                LIMIT %(limit)s
-                OFFSET %(offset)s;
-                """,
-                {"offset": offset, "limit": limit, "tenant_id": tenant_id},
-            )
+            f"""
+            SELECT *
+            FROM public.users
+            WHERE {where_clause}
+            LIMIT {limit}
+            OFFSET {offset};
+            """
         )
         return cur.fetchall()
 