add filtering
This commit is contained in:
parent
891a0e31c1
commit
920fdd3455
6 changed files with 182 additions and 42 deletions
|
|
@ -234,6 +234,7 @@ user_config = ResourceConfig(
|
|||
rewrite_provider_resource=users.rewrite_provider_resource,
|
||||
convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input,
|
||||
update_provider_resource=users.update_provider_resource,
|
||||
filter_attribute_mapping=users.filter_attribute_mapping,
|
||||
)
|
||||
group_config = ResourceConfig(
|
||||
schema_id="urn:ietf:params:scim:schemas:core:2.0:Group",
|
||||
|
|
@ -251,6 +252,7 @@ group_config = ResourceConfig(
|
|||
rewrite_provider_resource=groups.rewrite_provider_resource,
|
||||
convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input,
|
||||
update_provider_resource=groups.update_provider_resource,
|
||||
filter_attribute_mapping=groups.filter_attribute_mapping,
|
||||
)
|
||||
|
||||
RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = {
|
||||
|
|
@ -272,15 +274,19 @@ async def get_resources(
|
|||
requested_items_per_page: int | None = Query(None, alias="count"),
|
||||
attributes: str | None = Query(None),
|
||||
excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
|
||||
filter: str | None = Query(None),
|
||||
):
|
||||
config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
|
||||
total_resources = config.get_active_resource_count(tenant_id)
|
||||
filter_clause = helpers.scim_to_sql_where(filter, config.filter_attribute_mapping())
|
||||
total_resources = config.get_active_resource_count(tenant_id, filter_clause)
|
||||
start_index_one_indexed = max(1, requested_start_index_one_indexed)
|
||||
offset = start_index_one_indexed - 1
|
||||
limit = min(
|
||||
max(0, requested_items_per_page or config.max_chunk_size), config.max_chunk_size
|
||||
)
|
||||
provider_resources = config.get_provider_resource_chunk(offset, tenant_id, limit)
|
||||
provider_resources = config.get_provider_resource_chunk(
|
||||
offset, tenant_id, limit, filter_clause
|
||||
)
|
||||
client_resources = [
|
||||
api_helper.convert_provider_resource_to_client_resource(
|
||||
config, provider_resource, attributes, excluded_attributes
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@
|
|||
"maxPayloadSize": 0
|
||||
},
|
||||
"filter": {
|
||||
"supported": false,
|
||||
"maxResults": 0
|
||||
"supported": true,
|
||||
"maxResults": 10
|
||||
},
|
||||
"changePassword": {
|
||||
"supported": false
|
||||
|
|
|
|||
|
|
@ -56,30 +56,36 @@ def convert_provider_resource_to_client_resource(
|
|||
}
|
||||
|
||||
|
||||
def get_active_resource_count(tenant_id: int) -> int:
|
||||
def get_active_resource_count(tenant_id: int, filter_clause: str | None = None) -> int:
|
||||
where_and_clauses = [
|
||||
f"roles.tenant_id = {tenant_id}",
|
||||
"roles.deleted_at IS NULL",
|
||||
]
|
||||
if filter_clause is not None:
|
||||
where_and_clauses.append(filter_clause)
|
||||
where_clause = " AND ".join(where_and_clauses)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.roles
|
||||
WHERE
|
||||
roles.tenant_id = %(tenant_id)s
|
||||
AND roles.deleted_at IS NULL
|
||||
""",
|
||||
{"tenant_id": tenant_id},
|
||||
)
|
||||
f"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.roles
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
)
|
||||
return cur.fetchone()["count"]
|
||||
|
||||
|
||||
def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str:
|
||||
def _main_select_query(
|
||||
tenant_id: int, resource_id: int | None = None, filter_clause: str | None = None
|
||||
) -> str:
|
||||
where_and_clauses = [
|
||||
f"roles.tenant_id = {tenant_id}",
|
||||
"roles.deleted_at IS NULL",
|
||||
]
|
||||
if resource_id is not None:
|
||||
where_and_clauses.append(f"roles.role_id = {resource_id}")
|
||||
if filter_clause is not None:
|
||||
where_and_clauses.append(filter_clause)
|
||||
where_clause = " AND ".join(where_and_clauses)
|
||||
return f"""
|
||||
SELECT
|
||||
|
|
@ -107,14 +113,18 @@ def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str:
|
|||
|
||||
|
||||
def get_provider_resource_chunk(
|
||||
offset: int, tenant_id: int, limit: int
|
||||
offset: int, tenant_id: int, limit: int, filter_clause: str | None = None
|
||||
) -> list[ProviderResource]:
|
||||
query = _main_select_query(tenant_id)
|
||||
query = _main_select_query(tenant_id, filter_clause=filter_clause)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(f"{query} LIMIT {limit} OFFSET {offset}")
|
||||
return cur.fetchall()
|
||||
|
||||
|
||||
def filter_attribute_mapping() -> dict[str, str]:
|
||||
return {"displayName": "roles.name"}
|
||||
|
||||
|
||||
def get_provider_resource(
|
||||
resource_id: ResourceId, tenant_id: int
|
||||
) -> ProviderResource | None:
|
||||
|
|
|
|||
|
|
@ -366,3 +366,118 @@ def remove_by_path(doc, tokens):
|
|||
cur = cur[token] if 0 <= token < len(cur) else None
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
class SCIMFilterParser:
|
||||
_TOK_RE = re.compile(
|
||||
r"""
|
||||
(?:"[^"]*"|'[^']*')| # double- or single-quoted string
|
||||
\band\b|\bor\b|\bnot\b|
|
||||
\beq\b|\bne\b|\bco\b|\bsw\b|\bew\b|\bgt\b|\blt\b|\bge\b|\ble\b|\bpr\b|
|
||||
[()]| # parentheses
|
||||
[^\s()]+ # bare token
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE,
|
||||
)
|
||||
_NUMERIC_RE = re.compile(r"^-?\d+(\.\d+)?$")
|
||||
|
||||
def __init__(self, text: str, attr_map: dict[str, str]):
|
||||
self.tokens = [tok for tok in self._TOK_RE.findall(text)]
|
||||
self.pos = 0
|
||||
self.attr_map = attr_map
|
||||
|
||||
def peek(self) -> str | None:
|
||||
return self.tokens[self.pos].lower() if self.pos < len(self.tokens) else None
|
||||
|
||||
def next(self) -> str:
|
||||
tok = self.tokens[self.pos]
|
||||
self.pos += 1
|
||||
return tok
|
||||
|
||||
def parse(self) -> str:
|
||||
expr = self._parse_or()
|
||||
if self.pos != len(self.tokens):
|
||||
raise ValueError(f"Unexpected token at end: {self.peek()}")
|
||||
return expr
|
||||
|
||||
def _parse_or(self) -> str:
|
||||
left = self._parse_and()
|
||||
while self.peek() == "or":
|
||||
self.next()
|
||||
right = self._parse_and()
|
||||
left = f"({left} OR {right})"
|
||||
return left
|
||||
|
||||
def _parse_and(self) -> str:
|
||||
left = self._parse_not()
|
||||
while self.peek() == "and":
|
||||
self.next()
|
||||
right = self._parse_not()
|
||||
left = f"({left} AND {right})"
|
||||
return left
|
||||
|
||||
def _parse_not(self) -> str:
|
||||
if self.peek() == "not":
|
||||
self.next()
|
||||
inner = self._parse_simple()
|
||||
return f"(NOT {inner})"
|
||||
return self._parse_simple()
|
||||
|
||||
def _parse_simple(self) -> str:
|
||||
if self.peek() == "(":
|
||||
self.next()
|
||||
expr = self._parse_or()
|
||||
if self.next() != ")":
|
||||
raise ValueError("Missing closing parenthesis")
|
||||
return f"({expr})"
|
||||
return self._parse_comparison()
|
||||
|
||||
def _parse_comparison(self) -> str:
|
||||
raw_attr = self.next()
|
||||
col = self.attr_map.get(raw_attr, raw_attr)
|
||||
op = self.next().lower()
|
||||
|
||||
if op == "pr":
|
||||
return f"{col} IS NOT NULL"
|
||||
|
||||
val = self.next()
|
||||
|
||||
# strip quotes if present (single or double)
|
||||
if (val.startswith('"') and val.endswith('"')) or (
|
||||
val.startswith("'") and val.endswith("'")
|
||||
):
|
||||
inner = val[1:-1].replace("'", "''")
|
||||
sql_val = f"'{inner}'"
|
||||
elif self._NUMERIC_RE.match(val):
|
||||
sql_val = val
|
||||
else:
|
||||
inner = val.replace("'", "''")
|
||||
sql_val = f"'{inner}'"
|
||||
|
||||
if op == "eq":
|
||||
return f"{col} = {sql_val}"
|
||||
if op == "ne":
|
||||
return f"{col} <> {sql_val}"
|
||||
if op == "co":
|
||||
return f"{col} LIKE '%' || {sql_val} || '%'"
|
||||
if op == "sw":
|
||||
return f"{col} LIKE {sql_val} || '%'"
|
||||
if op == "ew":
|
||||
return f"{col} LIKE '%' || {sql_val}"
|
||||
if op in ("gt", "lt", "ge", "le"):
|
||||
sql_ops = {"gt": ">", "lt": "<", "ge": ">=", "le": "<="}
|
||||
return f"{col} {sql_ops[op]} {sql_val}"
|
||||
|
||||
raise ValueError(f"Unknown operator: {op}")
|
||||
|
||||
|
||||
def scim_to_sql_where(filter_str: str | None, attr_map: dict[str, str]) -> str | None:
|
||||
"""
|
||||
Convert a SCIM filter into an SQL WHERE fragment,
|
||||
mapping SCIM attributes per attr_map and correctly quoting
|
||||
both single- and double-quoted strings.
|
||||
"""
|
||||
if filter_str is None:
|
||||
return None
|
||||
parser = SCIMFilterParser(filter_str, attr_map)
|
||||
return parser.parse()
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class ResourceConfig:
|
|||
[int, ClientInput], ProviderInput
|
||||
]
|
||||
update_provider_resource: Callable[..., ProviderResource]
|
||||
filter_attribute_mapping: Callable[None, dict[str, str]]
|
||||
|
||||
|
||||
def get_schema(config: ResourceConfig) -> Schema:
|
||||
|
|
|
|||
|
|
@ -85,6 +85,10 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input(
|
|||
return result
|
||||
|
||||
|
||||
def filter_attribute_mapping() -> dict[str, str]:
|
||||
return {"userName": "users.email"}
|
||||
|
||||
|
||||
def get_provider_resource_from_unique_fields(
|
||||
email: str, **kwargs: dict[str, Any]
|
||||
) -> ProviderResource | None:
|
||||
|
|
@ -153,40 +157,44 @@ def convert_provider_resource_to_client_resource(
|
|||
}
|
||||
|
||||
|
||||
def get_active_resource_count(tenant_id: int) -> int:
|
||||
def get_active_resource_count(tenant_id: int, filter_clause: str | None = None) -> int:
|
||||
where_and_statements = [
|
||||
f"users.tenant_id = {tenant_id}",
|
||||
"users.deleted_at IS NULL",
|
||||
]
|
||||
if filter_clause is not None:
|
||||
where_and_statements.append(filter_clause)
|
||||
where_clause = " AND ".join(where_and_statements)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.users
|
||||
WHERE
|
||||
users.tenant_id = %(tenant_id)s
|
||||
AND users.deleted_at IS NULL
|
||||
""",
|
||||
{"tenant_id": tenant_id},
|
||||
)
|
||||
f"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.users
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
)
|
||||
return cur.fetchone()["count"]
|
||||
|
||||
|
||||
def get_provider_resource_chunk(
|
||||
offset: int, tenant_id: int, limit: int
|
||||
offset: int, tenant_id: int, limit: int, filter_clause: str | None = None
|
||||
) -> list[ProviderResource]:
|
||||
where_and_statements = [
|
||||
f"users.tenant_id = {tenant_id}",
|
||||
"users.deleted_at IS NULL",
|
||||
]
|
||||
if filter_clause is not None:
|
||||
where_and_statements.append(filter_clause)
|
||||
where_clause = " AND ".join(where_and_statements)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT *
|
||||
FROM public.users
|
||||
WHERE
|
||||
users.tenant_id = %(tenant_id)s
|
||||
AND users.deleted_at IS NULL
|
||||
LIMIT %(limit)s
|
||||
OFFSET %(offset)s;
|
||||
""",
|
||||
{"offset": offset, "limit": limit, "tenant_id": tenant_id},
|
||||
)
|
||||
f"""
|
||||
SELECT *
|
||||
FROM public.users
|
||||
WHERE {where_clause}
|
||||
LIMIT {limit}
|
||||
OFFSET {offset};
|
||||
"""
|
||||
)
|
||||
return cur.fetchall()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue