From b0531ef223a37451deafb8ab035068617df5c20d Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Wed, 23 Apr 2025 08:46:18 +0200 Subject: [PATCH] at users.updated_at, handle queries with multiple attributes, remove dead code --- ee/api/chalicelib/core/users.py | 25 +++++++++++++---- ee/api/routers/scim.py | 27 ++++++++++--------- ee/api/routers/scim_groups.py | 20 ++++++++++---- ee/api/routers/scim_helpers.py | 20 +++++++------- .../db/init_dbs/postgresql/init_schema.sql | 1 + 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index f7e084b8a..173819092 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -350,6 +350,23 @@ def get(user_id, tenant_id): return helper.dict_to_camel_case(r) +def count_total_scim_users(tenant_id: int) -> int: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT COUNT(*) + FROM public.users + WHERE + users.tenant_id = %(tenant_id)s + AND users.deleted_at IS NULL + """, + {"tenant_id": tenant_id}, + ) + ) + return cur.fetchone()["count"] + + def get_scim_users_paginated(start_index, tenant_id, count=None): with pg_client.PostgresClient() as cur: cur.execute( @@ -444,8 +461,7 @@ def create_scim_user( def restore_scim_user( - userId: int, - tenantId: int, + tenant_id: int, email: str, name: str = "", internal_id: str | None = None, @@ -469,7 +485,7 @@ def restore_scim_user( api_key = default, jwt_iat = NULL, weekly_report = default - WHERE users.user_id = %(user_id)s + WHERE users.email = %(email)s RETURNING * ) SELECT @@ -478,8 +494,7 @@ def restore_scim_user( FROM u LEFT JOIN public.roles USING (role_id); """, { - "tenant_id": tenantId, - "user_id": userId, + "tenant_id": tenant_id, "email": email, "name": name, "internal_id": internal_id, diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index c50ea41a5..230e09ddf 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -270,8 +270,7 @@ def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]: "meta": { "resourceType": "User", "created": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), - # todo(jon): we currently don't keep track of this in the db - "lastModified": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), + "lastModified": db_user["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"), "location": f"Users/{db_user['userId']}", }, "userName": db_user["email"], @@ -322,6 +321,7 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "max_items_per_page": 10, "schema_id": "urn:ietf:params:scim:schemas:core:2.0:User", "db_to_scim_serializer": _serialize_db_user_to_scim_user, + "count_total_resources": users.count_total_scim_users, "get_paginated_resources": users.get_scim_users_paginated, "get_unique_resource": users.get_scim_user_by_id, "parse_post_payload": _parse_scim_user_input, @@ -336,11 +336,13 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "max_items_per_page": 10, "schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group", "db_to_scim_serializer": _serialize_db_group_to_scim_group, + "count_total_resources": scim_groups.count_total_resources, "get_paginated_resources": scim_groups.get_resources_paginated, "get_unique_resource": scim_groups.get_resource_by_id, "parse_post_payload": _parse_scim_group_input, "get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources, - "restore_resource": scim_groups.restore_resource, + # note(jon): we're not soft deleting groups, so we don't need this + "restore_resource": None, "create_resource": scim_groups.create_resource, "delete_resource": scim_groups.delete_resource, "parse_put_payload": _parse_scim_group_input, @@ -360,8 +362,8 @@ async def get_resources( tenant_id=Depends(auth_required), requested_start_index: int = Query(1, alias="startIndex"), requested_items_per_page: int | None = Query(None, alias="count"), - attributes: list[str] | None = Query(None), - excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"), + attributes: str | None = Query(None), + excluded_attributes: str | None = Query(None, alias="excludedAttributes"), ): resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] start_index = max(1, requested_start_index) @@ -369,12 +371,12 @@ async def get_resources( items_per_page = min( max(0, requested_items_per_page or max_items_per_page), max_items_per_page ) - # todo(jon): this might not be the most efficient thing to do. could be better to just do a count. - # but this is the fastest thing at the moment just to test that it's working - total_resources = resource_config["get_paginated_resources"](1, tenant_id) + total_resources = resource_config["count_total_resources"](tenant_id) db_resources = resource_config["get_paginated_resources"]( start_index, tenant_id, items_per_page ) + attributes = scim_helpers.convert_query_str_to_list(attributes) + excluded_attributes = scim_helpers.convert_query_str_to_list(excluded_attributes) scim_resources = [ _serialize_db_resource_to_scim_resource_with_attribute_awareness( db_resource, @@ -388,7 +390,7 @@ async def get_resources( return JSONResponse( status_code=200, content={ - "totalResults": len(total_resources), + "totalResults": total_resources, "startIndex": start_index, "itemsPerPage": len(scim_resources), "Resources": scim_resources, @@ -446,10 +448,9 @@ async def create_resource( if existing_db_resource and existing_db_resource.get("deletedAt") is None: return _uniqueness_error_response() if existing_db_resource and existing_db_resource.get("deletedAt") is not None: - # todo(jon): not a super elegant solution overwriting the existing db resource. - # maybe we should try something else. - existing_db_resource.update(db_payload) - db_resource = resource_config["restore_resource"](**existing_db_resource) + db_resource = resource_config["restore_resource"]( + tenant_id=tenant_id, **db_payload + ) else: db_resource = resource_config["create_resource"]( tenant_id=tenant_id, diff --git a/ee/api/routers/scim_groups.py b/ee/api/routers/scim_groups.py index d80bf818d..743d53512 100644 --- a/ee/api/routers/scim_groups.py +++ b/ee/api/routers/scim_groups.py @@ -3,6 +3,21 @@ from typing import Any from chalicelib.utils import helper, pg_client +def count_total_resources(tenant_id: int) -> int: + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT COUNT(*) + FROM public.groups + WHERE groups.tenant_id = %(tenant_id)s + """, + {"tenant_id": tenant_id}, + ) + ) + return cur.fetchone()["count"] + + def get_resources_paginated( offset_one_indexed: int, tenant_id: int, limit: int | None = None ) -> list[dict[str, Any]]: @@ -66,11 +81,6 @@ def get_existing_resource_by_unique_values_from_all_resources( return None -def restore_resource(**kwargs: dict[str, Any]) -> dict[str, Any] | None: - # note(jon): we're not soft deleting groups, so we don't need this - return None - - def create_resource( name: str, tenant_id: int, **kwargs: dict[str, Any] ) -> dict[str, Any]: diff --git a/ee/api/routers/scim_helpers.py b/ee/api/routers/scim_helpers.py index cda66b29c..477becc6c 100644 --- a/ee/api/routers/scim_helpers.py +++ b/ee/api/routers/scim_helpers.py @@ -2,6 +2,12 @@ from typing import Any from copy import deepcopy +def convert_query_str_to_list(query_str: str | None) -> list[str]: + if query_str is None: + return None + return query_str.split(",") + + def get_all_attribute_names(schema: dict[str, Any]) -> list[str]: result = [] @@ -102,14 +108,10 @@ def exclude_attributes( elif isinstance(value, list): new_list = [] for item in value: - if isinstance(item, dict): - new_item = exclude_attributes(item, subs) - new_list.append(new_item) - else: - new_list.append(item) + # note(jon): `item` should always be a dict here + new_item = exclude_attributes(item, subs) + new_list.append(new_item) new_resource[key] = new_list - else: - new_resource[key] = value else: # No exclusion for this key: copy safely if isinstance(value, (dict, list)): @@ -153,8 +155,4 @@ def filter_mutable_attributes( ) # If it matches, no change is needed (already set) - else: - # Unknown mutability: default to safe behavior (ignore) - continue - return valid_changes diff --git a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql index 96fb5a23b..521079621 100644 --- a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql +++ b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql @@ -142,6 +142,7 @@ CREATE TABLE public.users role user_role NOT NULL DEFAULT 'member', name text NOT NULL, created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), + updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), deleted_at timestamp without time zone NULL DEFAULT NULL, api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL, jwt_iat timestamp without time zone NULL DEFAULT NULL,