at users.updated_at, handle queries with multiple attributes, remove dead code

This commit is contained in:
Jonathan Griffin 2025-04-23 08:46:18 +02:00
parent 9057637b84
commit b0531ef223
5 changed files with 59 additions and 34 deletions

View file

@ -350,6 +350,23 @@ def get(user_id, tenant_id):
return helper.dict_to_camel_case(r)
def count_total_scim_users(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT COUNT(*)
FROM public.users
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
""",
{"tenant_id": tenant_id},
)
)
return cur.fetchone()["count"]
def get_scim_users_paginated(start_index, tenant_id, count=None):
with pg_client.PostgresClient() as cur:
cur.execute(
@ -444,8 +461,7 @@ def create_scim_user(
def restore_scim_user(
userId: int,
tenantId: int,
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
@ -469,7 +485,7 @@ def restore_scim_user(
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.user_id = %(user_id)s
WHERE users.email = %(email)s
RETURNING *
)
SELECT
@ -478,8 +494,7 @@ def restore_scim_user(
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenantId,
"user_id": userId,
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,

View file

@ -270,8 +270,7 @@ def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]:
"meta": {
"resourceType": "User",
"created": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
# todo(jon): we currently don't keep track of this in the db
"lastModified": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": db_user["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"location": f"Users/{db_user['userId']}",
},
"userName": db_user["email"],
@ -322,6 +321,7 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"max_items_per_page": 10,
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:User",
"db_to_scim_serializer": _serialize_db_user_to_scim_user,
"count_total_resources": users.count_total_scim_users,
"get_paginated_resources": users.get_scim_users_paginated,
"get_unique_resource": users.get_scim_user_by_id,
"parse_post_payload": _parse_scim_user_input,
@ -336,11 +336,13 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
"max_items_per_page": 10,
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group",
"db_to_scim_serializer": _serialize_db_group_to_scim_group,
"count_total_resources": scim_groups.count_total_resources,
"get_paginated_resources": scim_groups.get_resources_paginated,
"get_unique_resource": scim_groups.get_resource_by_id,
"parse_post_payload": _parse_scim_group_input,
"get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources,
"restore_resource": scim_groups.restore_resource,
# note(jon): we're not soft deleting groups, so we don't need this
"restore_resource": None,
"create_resource": scim_groups.create_resource,
"delete_resource": scim_groups.delete_resource,
"parse_put_payload": _parse_scim_group_input,
@ -360,8 +362,8 @@ async def get_resources(
tenant_id=Depends(auth_required),
requested_start_index: int = Query(1, alias="startIndex"),
requested_items_per_page: int | None = Query(None, alias="count"),
attributes: list[str] | None = Query(None),
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
attributes: str | None = Query(None),
excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
):
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
start_index = max(1, requested_start_index)
@ -369,12 +371,12 @@ async def get_resources(
items_per_page = min(
max(0, requested_items_per_page or max_items_per_page), max_items_per_page
)
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
# but this is the fastest thing at the moment just to test that it's working
total_resources = resource_config["get_paginated_resources"](1, tenant_id)
total_resources = resource_config["count_total_resources"](tenant_id)
db_resources = resource_config["get_paginated_resources"](
start_index, tenant_id, items_per_page
)
attributes = scim_helpers.convert_query_str_to_list(attributes)
excluded_attributes = scim_helpers.convert_query_str_to_list(excluded_attributes)
scim_resources = [
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
db_resource,
@ -388,7 +390,7 @@ async def get_resources(
return JSONResponse(
status_code=200,
content={
"totalResults": len(total_resources),
"totalResults": total_resources,
"startIndex": start_index,
"itemsPerPage": len(scim_resources),
"Resources": scim_resources,
@ -446,10 +448,9 @@ async def create_resource(
if existing_db_resource and existing_db_resource.get("deletedAt") is None:
return _uniqueness_error_response()
if existing_db_resource and existing_db_resource.get("deletedAt") is not None:
# todo(jon): not a super elegant solution overwriting the existing db resource.
# maybe we should try something else.
existing_db_resource.update(db_payload)
db_resource = resource_config["restore_resource"](**existing_db_resource)
db_resource = resource_config["restore_resource"](
tenant_id=tenant_id, **db_payload
)
else:
db_resource = resource_config["create_resource"](
tenant_id=tenant_id,

View file

@ -3,6 +3,21 @@ from typing import Any
from chalicelib.utils import helper, pg_client
def count_total_resources(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT COUNT(*)
FROM public.groups
WHERE groups.tenant_id = %(tenant_id)s
""",
{"tenant_id": tenant_id},
)
)
return cur.fetchone()["count"]
def get_resources_paginated(
offset_one_indexed: int, tenant_id: int, limit: int | None = None
) -> list[dict[str, Any]]:
@ -66,11 +81,6 @@ def get_existing_resource_by_unique_values_from_all_resources(
return None
def restore_resource(**kwargs: dict[str, Any]) -> dict[str, Any] | None:
# note(jon): we're not soft deleting groups, so we don't need this
return None
def create_resource(
name: str, tenant_id: int, **kwargs: dict[str, Any]
) -> dict[str, Any]:

View file

@ -2,6 +2,12 @@ from typing import Any
from copy import deepcopy
def convert_query_str_to_list(query_str: str | None) -> list[str]:
if query_str is None:
return None
return query_str.split(",")
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
result = []
@ -102,14 +108,10 @@ def exclude_attributes(
elif isinstance(value, list):
new_list = []
for item in value:
if isinstance(item, dict):
# note(jon): `item` should always be a dict here
new_item = exclude_attributes(item, subs)
new_list.append(new_item)
else:
new_list.append(item)
new_resource[key] = new_list
else:
new_resource[key] = value
else:
# No exclusion for this key: copy safely
if isinstance(value, (dict, list)):
@ -153,8 +155,4 @@ def filter_mutable_attributes(
)
# If it matches, no change is needed (already set)
else:
# Unknown mutability: default to safe behavior (ignore)
continue
return valid_changes

View file

@ -142,6 +142,7 @@ CREATE TABLE public.users
role user_role NOT NULL DEFAULT 'member',
name text NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
deleted_at timestamp without time zone NULL DEFAULT NULL,
api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
jwt_iat timestamp without time zone NULL DEFAULT NULL,