at users.updated_at, handle queries with multiple attributes, remove dead code
This commit is contained in:
parent
9057637b84
commit
b0531ef223
5 changed files with 59 additions and 34 deletions
|
|
@ -350,6 +350,23 @@ def get(user_id, tenant_id):
|
|||
return helper.dict_to_camel_case(r)
|
||||
|
||||
|
||||
def count_total_scim_users(tenant_id: int) -> int:
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.users
|
||||
WHERE
|
||||
users.tenant_id = %(tenant_id)s
|
||||
AND users.deleted_at IS NULL
|
||||
""",
|
||||
{"tenant_id": tenant_id},
|
||||
)
|
||||
)
|
||||
return cur.fetchone()["count"]
|
||||
|
||||
|
||||
def get_scim_users_paginated(start_index, tenant_id, count=None):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
|
|
@ -444,8 +461,7 @@ def create_scim_user(
|
|||
|
||||
|
||||
def restore_scim_user(
|
||||
userId: int,
|
||||
tenantId: int,
|
||||
tenant_id: int,
|
||||
email: str,
|
||||
name: str = "",
|
||||
internal_id: str | None = None,
|
||||
|
|
@ -469,7 +485,7 @@ def restore_scim_user(
|
|||
api_key = default,
|
||||
jwt_iat = NULL,
|
||||
weekly_report = default
|
||||
WHERE users.user_id = %(user_id)s
|
||||
WHERE users.email = %(email)s
|
||||
RETURNING *
|
||||
)
|
||||
SELECT
|
||||
|
|
@ -478,8 +494,7 @@ def restore_scim_user(
|
|||
FROM u LEFT JOIN public.roles USING (role_id);
|
||||
""",
|
||||
{
|
||||
"tenant_id": tenantId,
|
||||
"user_id": userId,
|
||||
"tenant_id": tenant_id,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"internal_id": internal_id,
|
||||
|
|
|
|||
|
|
@ -270,8 +270,7 @@ def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]:
|
|||
"meta": {
|
||||
"resourceType": "User",
|
||||
"created": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
# todo(jon): we currently don't keep track of this in the db
|
||||
"lastModified": db_user["createdAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
"lastModified": db_user["updatedAt"].strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
"location": f"Users/{db_user['userId']}",
|
||||
},
|
||||
"userName": db_user["email"],
|
||||
|
|
@ -322,6 +321,7 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
|
|||
"max_items_per_page": 10,
|
||||
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:User",
|
||||
"db_to_scim_serializer": _serialize_db_user_to_scim_user,
|
||||
"count_total_resources": users.count_total_scim_users,
|
||||
"get_paginated_resources": users.get_scim_users_paginated,
|
||||
"get_unique_resource": users.get_scim_user_by_id,
|
||||
"parse_post_payload": _parse_scim_user_input,
|
||||
|
|
@ -336,11 +336,13 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = {
|
|||
"max_items_per_page": 10,
|
||||
"schema_id": "urn:ietf:params:scim:schemas:core:2.0:Group",
|
||||
"db_to_scim_serializer": _serialize_db_group_to_scim_group,
|
||||
"count_total_resources": scim_groups.count_total_resources,
|
||||
"get_paginated_resources": scim_groups.get_resources_paginated,
|
||||
"get_unique_resource": scim_groups.get_resource_by_id,
|
||||
"parse_post_payload": _parse_scim_group_input,
|
||||
"get_resource_by_unique_values": scim_groups.get_existing_resource_by_unique_values_from_all_resources,
|
||||
"restore_resource": scim_groups.restore_resource,
|
||||
# note(jon): we're not soft deleting groups, so we don't need this
|
||||
"restore_resource": None,
|
||||
"create_resource": scim_groups.create_resource,
|
||||
"delete_resource": scim_groups.delete_resource,
|
||||
"parse_put_payload": _parse_scim_group_input,
|
||||
|
|
@ -360,8 +362,8 @@ async def get_resources(
|
|||
tenant_id=Depends(auth_required),
|
||||
requested_start_index: int = Query(1, alias="startIndex"),
|
||||
requested_items_per_page: int | None = Query(None, alias="count"),
|
||||
attributes: list[str] | None = Query(None),
|
||||
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
|
||||
attributes: str | None = Query(None),
|
||||
excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
|
||||
):
|
||||
resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
|
||||
start_index = max(1, requested_start_index)
|
||||
|
|
@ -369,12 +371,12 @@ async def get_resources(
|
|||
items_per_page = min(
|
||||
max(0, requested_items_per_page or max_items_per_page), max_items_per_page
|
||||
)
|
||||
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
|
||||
# but this is the fastest thing at the moment just to test that it's working
|
||||
total_resources = resource_config["get_paginated_resources"](1, tenant_id)
|
||||
total_resources = resource_config["count_total_resources"](tenant_id)
|
||||
db_resources = resource_config["get_paginated_resources"](
|
||||
start_index, tenant_id, items_per_page
|
||||
)
|
||||
attributes = scim_helpers.convert_query_str_to_list(attributes)
|
||||
excluded_attributes = scim_helpers.convert_query_str_to_list(excluded_attributes)
|
||||
scim_resources = [
|
||||
_serialize_db_resource_to_scim_resource_with_attribute_awareness(
|
||||
db_resource,
|
||||
|
|
@ -388,7 +390,7 @@ async def get_resources(
|
|||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"totalResults": len(total_resources),
|
||||
"totalResults": total_resources,
|
||||
"startIndex": start_index,
|
||||
"itemsPerPage": len(scim_resources),
|
||||
"Resources": scim_resources,
|
||||
|
|
@ -446,10 +448,9 @@ async def create_resource(
|
|||
if existing_db_resource and existing_db_resource.get("deletedAt") is None:
|
||||
return _uniqueness_error_response()
|
||||
if existing_db_resource and existing_db_resource.get("deletedAt") is not None:
|
||||
# todo(jon): not a super elegant solution overwriting the existing db resource.
|
||||
# maybe we should try something else.
|
||||
existing_db_resource.update(db_payload)
|
||||
db_resource = resource_config["restore_resource"](**existing_db_resource)
|
||||
db_resource = resource_config["restore_resource"](
|
||||
tenant_id=tenant_id, **db_payload
|
||||
)
|
||||
else:
|
||||
db_resource = resource_config["create_resource"](
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,21 @@ from typing import Any
|
|||
from chalicelib.utils import helper, pg_client
|
||||
|
||||
|
||||
def count_total_resources(tenant_id: int) -> int:
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM public.groups
|
||||
WHERE groups.tenant_id = %(tenant_id)s
|
||||
""",
|
||||
{"tenant_id": tenant_id},
|
||||
)
|
||||
)
|
||||
return cur.fetchone()["count"]
|
||||
|
||||
|
||||
def get_resources_paginated(
|
||||
offset_one_indexed: int, tenant_id: int, limit: int | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -66,11 +81,6 @@ def get_existing_resource_by_unique_values_from_all_resources(
|
|||
return None
|
||||
|
||||
|
||||
def restore_resource(**kwargs: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# note(jon): we're not soft deleting groups, so we don't need this
|
||||
return None
|
||||
|
||||
|
||||
def create_resource(
|
||||
name: str, tenant_id: int, **kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,12 @@ from typing import Any
|
|||
from copy import deepcopy
|
||||
|
||||
|
||||
def convert_query_str_to_list(query_str: str | None) -> list[str]:
|
||||
if query_str is None:
|
||||
return None
|
||||
return query_str.split(",")
|
||||
|
||||
|
||||
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
|
||||
result = []
|
||||
|
||||
|
|
@ -102,14 +108,10 @@ def exclude_attributes(
|
|||
elif isinstance(value, list):
|
||||
new_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_item = exclude_attributes(item, subs)
|
||||
new_list.append(new_item)
|
||||
else:
|
||||
new_list.append(item)
|
||||
# note(jon): `item` should always be a dict here
|
||||
new_item = exclude_attributes(item, subs)
|
||||
new_list.append(new_item)
|
||||
new_resource[key] = new_list
|
||||
else:
|
||||
new_resource[key] = value
|
||||
else:
|
||||
# No exclusion for this key: copy safely
|
||||
if isinstance(value, (dict, list)):
|
||||
|
|
@ -153,8 +155,4 @@ def filter_mutable_attributes(
|
|||
)
|
||||
# If it matches, no change is needed (already set)
|
||||
|
||||
else:
|
||||
# Unknown mutability: default to safe behavior (ignore)
|
||||
continue
|
||||
|
||||
return valid_changes
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ CREATE TABLE public.users
|
|||
role user_role NOT NULL DEFAULT 'member',
|
||||
name text NOT NULL,
|
||||
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||
updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||
deleted_at timestamp without time zone NULL DEFAULT NULL,
|
||||
api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
|
||||
jwt_iat timestamp without time zone NULL DEFAULT NULL,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue