diff --git a/ee/api/routers/scim/api.py b/ee/api/routers/scim/api.py index 2607566cb..affe33946 100644 --- a/ee/api/routers/scim/api.py +++ b/ee/api/routers/scim/api.py @@ -243,7 +243,7 @@ group_config = ResourceConfig( get_provider_resource_chunk=groups.get_provider_resource_chunk, get_provider_resource=groups.get_provider_resource, convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input, - get_provider_resource_from_unique_fields=groups.get_provider_resource_from_unique_fields, + get_provider_resource_from_unique_fields=lambda **kwargs: None, restore_provider_resource=None, create_provider_resource=groups.create_provider_resource, delete_provider_resource=groups.delete_provider_resource, @@ -382,7 +382,7 @@ async def delete_resource( @public_app.put("/{resource_type}/{resource_id}") async def put_resource( resource_type: SCIMResource, - resource_id: str, + resource_id: int | str, r: Request, tenant_id=Depends(auth_required), attributes: list[str] | None = Query(None), @@ -424,7 +424,7 @@ async def put_resource( @public_app.patch("/{resource_type}/{resource_id}") async def patch_resource( resource_type: SCIMResource, - resource_id: str, + resource_id: int | str, r: Request, tenant_id=Depends(auth_required), attributes: list[str] | None = Query(None), diff --git a/ee/api/routers/scim/fixtures/group_schema.json b/ee/api/routers/scim/fixtures/group_schema.json index ddb030b92..b31dd6dd1 100644 --- a/ee/api/routers/scim/fixtures/group_schema.json +++ b/ee/api/routers/scim/fixtures/group_schema.json @@ -110,6 +110,29 @@ "returned": "default", "uniqueness": "none" }, + { + "name": "projectKeys", + "type": "complex", + "multiValued": true, + "description": "A list of project keys associated with the group.", + "required": false, + "caseExact": false, + "mutability": "readWrite", + "returned": "default", + "subAttributes": [ + { + "name": "value", + "type": "string", + "multiValued": false, + "description": "The unique project key.", + "required": true, + "mutability": "immutable", + "returned": "default", + "caseExact": true, + "uniqueness": "none" + } + ] + }, { "name": "members", "type": "complex", diff --git a/ee/api/routers/scim/groups.py b/ee/api/routers/scim/groups.py index 53c6389ff..cc113eb16 100644 --- a/ee/api/routers/scim/groups.py +++ b/ee/api/routers/scim/groups.py @@ -21,6 +21,8 @@ def convert_client_resource_update_input_to_provider_resource_update_input( if "members" in client_input: members = client_input["members"] or [] result["user_ids"] = [int(member["value"]) for member in members] + if "projectKeys" in client_input: + result["project_keys"] = [item["value"] for item in client_input["projectKeys"]] return result @@ -48,6 +50,9 @@ def convert_provider_resource_to_client_resource( } for member in members ], + "projectKeys": [ + {"value": project_key} for project_key in provider_resource["project_keys"] + ], } @@ -68,37 +73,45 @@ def get_active_resource_count(tenant_id: int) -> int: return cur.fetchone()["count"] +def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str: + where_and_clauses = [ + f"roles.tenant_id = {tenant_id}", + "roles.deleted_at IS NULL", + ] + if resource_id is not None: + where_and_clauses.append(f"roles.role_id = {resource_id}") + where_clause = " AND ".join(where_and_clauses) + return f""" + SELECT + roles.*, + COALESCE( + ( + SELECT json_agg(users) + FROM public.users + WHERE users.role_id = roles.role_id + ), + '[]' + ) AS users, + COALESCE( + ( + SELECT json_agg(projects.project_key) + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE roles_projects.role_id = roles.role_id + ), + '[]' + ) AS project_keys + FROM public.roles + WHERE {where_clause} + """ + + def get_provider_resource_chunk( offset: int, tenant_id: int, limit: int ) -> list[ProviderResource]: + query = _main_select_query(tenant_id) with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT - roles.*, - COALESCE( - ( - SELECT json_agg(users) - FROM public.users - WHERE users.role_id = roles.role_id - ), - '[]' - ) AS users - FROM public.roles - WHERE - roles.tenant_id = %(tenant_id)s - AND roles.deleted_at IS NULL - LIMIT %(limit)s - OFFSET %(offset)s; - """, - { - "offset": offset, - "limit": limit, - "tenant_id": tenant_id, - }, - ) - ) + cur.execute(f"{query} LIMIT {limit} OFFSET {offset}") return cur.fetchall() @@ -106,40 +119,10 @@ def get_provider_resource( resource_id: ResourceId, tenant_id: int ) -> ProviderResource | None: with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - SELECT - roles.*, - COALESCE( - ( - SELECT json_agg(users) - FROM public.users - WHERE users.role_id = roles.role_id - ), - '[]' - ) AS users - FROM public.roles - WHERE - roles.tenant_id = %(tenant_id)s - AND roles.role_id = %(resource_id)s - AND roles.deleted_at IS NULL - LIMIT 1; - """, - {"resource_id": resource_id, "tenant_id": tenant_id}, - ) - ) + cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1") return cur.fetchone() -def get_provider_resource_from_unique_fields( - **kwargs: dict[str, Any], -) -> ProviderResource | None: - # note(jon): we do not really use this for scim.groups (openreplay.roles) as we don't have unique values outside - # of the primary key - return None - - def convert_client_resource_creation_input_to_provider_resource_creation_input( tenant_id: int, client_input: ClientInput ) -> ProviderInput: @@ -148,6 +131,7 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input( "user_ids": [ int(member["value"]) for member in client_input.get("members", []) ], + "project_keys": [item["value"] for item in client_input.get("projectKeys", [])], } @@ -159,6 +143,7 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( "user_ids": [ int(member["value"]) for member in client_input.get("members", []) ], + "project_keys": [item["value"] for item in client_input.get("projectKeys", [])], } @@ -166,6 +151,7 @@ def create_provider_resource( name: str, tenant_id: int, user_ids: list[str] | None = None, + project_keys: list[str] | None = None, **kwargs: dict[str, Any], ) -> ProviderResource: with pg_client.PostgresClient() as cur: @@ -184,35 +170,42 @@ def create_provider_resource( cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids ] user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" + project_keys = project_keys or [] + project_key_fragments = [ + cur.mogrify("%s", (project_key,)).decode("utf-8") + for project_key in project_keys + ] + project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]" cur.execute( f""" - WITH - r AS ( - INSERT INTO public.roles ({column_clause}) - VALUES ({value_clause}) - RETURNING * - ), - linked_users AS ( - UPDATE public.users - SET - updated_at = now(), - role_id = (SELECT r.role_id FROM r) - WHERE users.user_id = ANY({user_id_clause}) - RETURNING * - ) - SELECT - r.*, - COALESCE( - ( - SELECT json_agg(linked_users.*) - FROM linked_users - ), - '[]' - ) AS users - FROM r - LIMIT 1; + INSERT INTO public.roles ({column_clause}) + VALUES ({value_clause}) + RETURNING role_id """ ) + role_id = cur.fetchone()["role_id"] + cur.execute( + f""" + UPDATE public.users + SET + updated_at = now(), + role_id = {role_id} + WHERE users.user_id = ANY({user_id_clause}) + """ + ) + cur.execute( + f""" + WITH ps AS ( + SELECT * + FROM public.projects + WHERE projects.project_key = ANY({project_key_clause}) + ) + INSERT INTO public.roles_projects (role_id, project_id) + SELECT {role_id}, ps.project_id + FROM ps + """ + ) + cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1") return cur.fetchone() @@ -220,6 +213,7 @@ def _update_resource_sql( resource_id: int, tenant_id: int, user_ids: list[int] | None = None, + project_keys: list[str] | None = None, **kwargs: dict[str, Any], ) -> dict[str, Any]: with pg_client.PostgresClient() as cur: @@ -234,46 +228,68 @@ def _update_resource_sql( cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids ] user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" + project_keys = project_keys or [] + project_key_fragments = [ + cur.mogrify("%s", (project_key,)).decode("utf-8") + for project_key in project_keys + ] + project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]" cur.execute( f""" UPDATE public.users - SET role_id = NULL - WHERE users.role_id = {resource_id} + SET + updated_at = now(), + role_id = NULL + WHERE + users.role_id = {resource_id} + AND users.user_id != ALL({user_id_clause}) + RETURNING * """ ) cur.execute( f""" - WITH - r AS ( - UPDATE public.roles - SET {set_clause} - WHERE - roles.role_id = {resource_id} - AND roles.tenant_id = {tenant_id} - AND roles.deleted_at IS NULL - RETURNING * - ), - linked_users AS ( - UPDATE public.users - SET - updated_at = now(), - role_id = {resource_id} - WHERE users.user_id = ANY({user_id_clause}) - RETURNING * - ) - SELECT - r.*, - COALESCE( - ( - SELECT json_agg(linked_users.*) - FROM linked_users - ), - '[]' - ) AS users - FROM r - LIMIT 1; + UPDATE public.users + SET + updated_at = now(), + role_id = {resource_id} + WHERE + (users.role_id != {resource_id} OR users.role_id IS NULL) + AND users.user_id = ANY({user_id_clause}) + RETURNING * """ ) + cur.execute( + f""" + DELETE FROM public.roles_projects + USING public.projects + WHERE + projects.project_id = roles_projects.project_id + AND roles_projects.role_id = {resource_id} + AND projects.project_key != ALL({project_key_clause}) + """ + ) + cur.execute( + f""" + INSERT INTO public.roles_projects (role_id, project_id) + SELECT {resource_id}, projects.project_id + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE + projects.project_key = ANY({project_key_clause}) + AND roles_projects.role_id IS NULL + """ + ) + cur.execute( + f""" + UPDATE public.roles + SET {set_clause} + WHERE + roles.role_id = {resource_id} + AND roles.tenant_id = {tenant_id} + AND roles.deleted_at IS NULL + """ + ) + cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1") return cur.fetchone() @@ -285,22 +301,6 @@ def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None: ) -def restore_provider_resource( - resource_id: int, - tenant_id: int, - name: str, - **kwargs: dict[str, Any], -) -> dict[str, Any]: - return _update_resource_sql( - resource_id=resource_id, - tenant_id=tenant_id, - name=name, - created_at=datetime.now(), - deleted_at=None, - **kwargs, - ) - - def rewrite_provider_resource( resource_id: int, tenant_id: int, diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py index ccd17e4df..1a7fd1d17 100644 --- a/ee/api/routers/scim/users.py +++ b/ee/api/routers/scim/users.py @@ -3,7 +3,6 @@ from datetime import datetime from psycopg2.extensions import AsIs from chalicelib.utils import pg_client -from chalicelib.core import roles from routers.scim.resource_config import ( ProviderResource, ClientResource, @@ -35,21 +34,21 @@ def convert_client_resource_update_input_to_provider_resource_update_input( def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( tenant_id: int, client_input: ClientInput ) -> ProviderInput: - name = client_input.get("name", {}).get("formatted") - if not name: - name = " ".join( - [ - x - for x in [ - client_input.get("name", {}).get("honorificPrefix"), - client_input.get("name", {}).get("givenName"), - client_input.get("name", {}).get("middleName"), - client_input.get("name", {}).get("familyName"), - client_input.get("name", {}).get("honorificSuffix"), - ] - if x + name = " ".join( + [ + x + for x in [ + client_input.get("name", {}).get("honorificPrefix"), + client_input.get("name", {}).get("givenName"), + client_input.get("name", {}).get("middleName"), + client_input.get("name", {}).get("familyName"), + client_input.get("name", {}).get("honorificSuffix"), ] - ) + if x + ] + ) + if not name: + name = client_input.get("displayName") result = { "email": client_input["userName"], "internal_id": client_input.get("externalId"), @@ -62,21 +61,21 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( def convert_client_resource_creation_input_to_provider_resource_creation_input( tenant_id: int, client_input: ClientInput ) -> ProviderInput: - name = client_input.get("name", {}).get("formatted") - if not name: - name = " ".join( - [ - x - for x in [ - client_input.get("name", {}).get("honorificPrefix"), - client_input.get("name", {}).get("givenName"), - client_input.get("name", {}).get("middleName"), - client_input.get("name", {}).get("familyName"), - client_input.get("name", {}).get("honorificSuffix"), - ] - if x + name = " ".join( + [ + x + for x in [ + client_input.get("name", {}).get("honorificPrefix"), + client_input.get("name", {}).get("givenName"), + client_input.get("name", {}).get("middleName"), + client_input.get("name", {}).get("familyName"), + client_input.get("name", {}).get("honorificSuffix"), ] - ) + if x + ] + ) + if not name: + name = client_input.get("displayName") result = { "email": client_input["userName"], "internal_id": client_input.get("externalId"),