diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 173819092..b1b014735 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -2,9 +2,11 @@ import json import logging import secrets from typing import Optional +from datetime import datetime from decouple import config from fastapi import BackgroundTasks, HTTPException +from psycopg2.extensions import AsIs from psycopg2.extras import Json from pydantic import BaseModel, model_validator from starlette import status @@ -482,6 +484,7 @@ def restore_scim_user( role_id = %(role_id)s, deleted_at = NULL, created_at = default, + updated_at = default, api_key = default, jwt_iat = NULL, weekly_report = default @@ -523,7 +526,8 @@ def update_scim_user( email = %(email)s, name = %(name)s, internal_id = %(internal_id)s, - role_id = %(role_id)s + role_id = %(role_id)s, + updated_at = default WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s @@ -548,6 +552,39 @@ def update_scim_user( return helper.dict_to_camel_case(cur.fetchone()) +def patch_scim_user( + user_id: int, + tenant_id: int, + **kwargs, +): + with pg_client.PostgresClient() as cur: + set_fragments = [] + kwargs["updated_at"] = datetime.now() + for k, v in kwargs.items(): + fragment = cur.mogrify( + "%s = %s", + (AsIs(k), v), + ).decode("utf-8") + set_fragments.append(fragment) + set_clause = ", ".join(set_fragments) + query = f""" + WITH u AS ( + UPDATE public.users + SET {set_clause} + WHERE + users.user_id = {user_id} + AND users.tenant_id = {tenant_id} + AND users.deleted_at IS NULL + RETURNING * + ) + SELECT + u.*, + roles.name as role_name + FROM u LEFT JOIN public.roles USING (role_id);""" + cur.execute(query) + return helper.dict_to_camel_case(cur.fetchone()) + + def generate_new_api_key(user_id): with pg_client.PostgresClient() as cur: cur.execute( @@ -1224,7 +1261,9 @@ def soft_delete_scim_user_by_id(user_id, tenant_id): cur.mogrify( """ UPDATE public.users - SET deleted_at = NULL + SET + deleted_at = NULL, + updated_at = default WHERE users.user_id = %(user_id)s AND users.tenant_id = %(tenant_id)s diff --git a/ee/api/routers/fixtures/service_provider_config.json b/ee/api/routers/fixtures/service_provider_config.json index ebb37bec4..dbcbff942 100644 --- a/ee/api/routers/fixtures/service_provider_config.json +++ b/ee/api/routers/fixtures/service_provider_config.json @@ -3,7 +3,7 @@ "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig" ], "patch": { - "supported": false + "supported": true }, "bulk": { "supported": false, diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 230e09ddf..a00822969 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -2,6 +2,7 @@ from copy import deepcopy import logging from typing import Any, Callable from enum import Enum +from datetime import datetime from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response, Request @@ -263,6 +264,15 @@ def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, An return result +def _parse_user_patch_operations(data: dict[str, Any]) -> dict[str, Any]: + result = {} + operations = data["Operations"] + for operation in operations: + if operation["op"] == "replace" and "active" in operation["value"]: + result["deleted_at"] = None if operation["value"]["active"] is True else datetime.now() + return result + + def _serialize_db_user_to_scim_user(db_user: dict[str, Any]) -> dict[str, Any]: return { "id": str(db_user["userId"]), @@ -331,6 +341,8 @@ RESOURCE_TYPE_TO_RESOURCE_CONFIG = { "delete_resource": users.soft_delete_scim_user_by_id, "parse_put_payload": _parse_scim_user_input, "update_resource": users.update_scim_user, + "parse_patch_operations": _parse_user_patch_operations, + "patch_resource": users.patch_scim_user, }, "Groups": { "max_items_per_page": 10, @@ -430,6 +442,7 @@ class PostResourceType(str, Enum): GROUPS = "Groups" + @public_app.post("/{resource_type}") async def create_resource( resource_type: PostResourceType, @@ -492,7 +505,7 @@ class PutResourceType(str, Enum): @public_app.put("/{resource_type}/{resource_id}") -async def update_resource( +async def put_resource( resource_type: PutResourceType, resource_id: str, r: Request, @@ -539,3 +552,44 @@ async def update_resource( return _uniqueness_error_response() except Exception as e: return _internal_server_error_response(str(e)) + + +class PatchResourceType(str, Enum): + USERS = "Users" + + +@public_app.patch("/{resource_type}/{resource_id}") +async def patch_resource( + resource_type: PatchResourceType, + resource_id: str, + r: Request, + tenant_id=Depends(auth_required), +): + resource_config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type] + db_resource = resource_config["get_unique_resource"](resource_id, tenant_id) + if not db_resource: + return _not_found_error_response(resource_id) + current_scim_resource = ( + _serialize_db_resource_to_scim_resource_with_attribute_awareness( + db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + ) + ) + payload = await r.json() + parsed_payload = resource_config["parse_patch_operations"](payload) + # note(jon): we don't need to handle uniqueness contraints and etc. like in PUT + # because we are only covering the User resource and the field `active` + updated_db_resource = resource_config["patch_resource"]( + resource_id, + tenant_id, + **parsed_payload, + ) + updated_scim_resource = ( + _serialize_db_resource_to_scim_resource_with_attribute_awareness( + updated_db_resource, + resource_config["schema_id"], + resource_config["db_to_scim_serializer"], + ) + ) + return JSONResponse(status_code=200, content=updated_scim_resource)