diff --git a/ee/api/routers/scim/fixtures/group_schema.json b/ee/api/routers/scim/fixtures/group_schema.json index b31dd6dd1..ddb030b92 100644 --- a/ee/api/routers/scim/fixtures/group_schema.json +++ b/ee/api/routers/scim/fixtures/group_schema.json @@ -110,29 +110,6 @@ "returned": "default", "uniqueness": "none" }, - { - "name": "projectKeys", - "type": "complex", - "multiValued": true, - "description": "A list of project keys associated with the group.", - "required": false, - "caseExact": false, - "mutability": "readWrite", - "returned": "default", - "subAttributes": [ - { - "name": "value", - "type": "string", - "multiValued": false, - "description": "The unique project key.", - "required": true, - "mutability": "immutable", - "returned": "default", - "caseExact": true, - "uniqueness": "none" - } - ] - }, { "name": "members", "type": "complex", diff --git a/ee/api/routers/scim/fixtures/user_schema.json b/ee/api/routers/scim/fixtures/user_schema.json index c80a084c5..528c4a69a 100644 --- a/ee/api/routers/scim/fixtures/user_schema.json +++ b/ee/api/routers/scim/fixtures/user_schema.json @@ -334,18 +334,14 @@ }, { "name": "entitlements", - "type": "complex", + "type": "string", "multiValued": true, "description": "Entitlements granted to the user.", "required": false, + "caseExact": true, + "canonicalValues": ["SESSION_REPLAY", "METRICS", "ASSIST_LIVE", "ASSIST_CALL", "SPOT_PUBLIC"], "mutability": "readWrite", - "returned": "default", - "subAttributes": [ - { "name": "value", "type": "string", "multiValued": false, "description": "Entitlement value.", "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none" }, - { "name": "display", "type": "string", "multiValued": false, "description": "Display name.", "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none" }, - { "name": "type", "type": "string", "multiValued": false, "description": "Type label.", "required": false, "caseExact": false, "mutability": "readWrite", "returned": "default", "uniqueness": "none" }, - { "name": "primary", "type": "boolean", "multiValued": false, "description": "Primary flag; one per list.", "required": false, "mutability": "readWrite", "returned": "default" } - ] + "returned": "default" }, { "name": "roles", @@ -376,7 +372,17 @@ { "name": "type", "type": "string", "multiValued": false, "description": "Type label.", "required": false, "caseExact": false, "canonicalValues": [], "mutability": "readWrite", "returned": "default", "uniqueness": "none" }, { "name": "primary", "type": "boolean", "multiValued": false, "description": "Primary flag; one per list.", "required": false, "mutability": "readWrite", "returned": "default" } ] - } + }, + { + "name": "projectKeys", + "type": "string", + "multiValued": true, + "description": "A list of project keys associated with the group.", + "required": false, + "caseExact": false, + "mutability": "readWrite", + "returned": "default" + } ], "meta": { "resourceType": "Schema", diff --git a/ee/api/routers/scim/groups.py b/ee/api/routers/scim/groups.py index 09d36b231..8eb9447db 100644 --- a/ee/api/routers/scim/groups.py +++ b/ee/api/routers/scim/groups.py @@ -3,6 +3,7 @@ from datetime import datetime from psycopg2.extensions import AsIs from chalicelib.utils import pg_client +from routers.scim import helpers from routers.scim.resource_config import ( ProviderResource, ClientResource, @@ -21,8 +22,6 @@ def convert_client_resource_update_input_to_provider_resource_update_input( if "members" in client_input: members = client_input["members"] or [] result["user_ids"] = [int(member["value"]) for member in members] - if "projectKeys" in client_input: - result["project_keys"] = [item["value"] for item in client_input["projectKeys"]] return result @@ -50,9 +49,6 @@ def convert_provider_resource_to_client_resource( } for member in members ], - "projectKeys": [ - {"value": project_key} for project_key in provider_resource["project_keys"] - ], } @@ -141,7 +137,6 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input( "user_ids": [ int(member["value"]) for member in client_input.get("members", []) ], - "project_keys": [item["value"] for item in client_input.get("projectKeys", [])], } @@ -153,7 +148,6 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( "user_ids": [ int(member["value"]) for member in client_input.get("members", []) ], - "project_keys": [item["value"] for item in client_input.get("projectKeys", [])], } @@ -161,7 +155,6 @@ def create_provider_resource( name: str, tenant_id: int, user_ids: list[str] | None = None, - project_keys: list[str] | None = None, **kwargs: dict[str, Any], ) -> ProviderResource: with pg_client.PostgresClient() as cur: @@ -175,17 +168,7 @@ def create_provider_resource( cur.mogrify("%s", (v,)).decode("utf-8") for v in kwargs.values() ] value_clause = ", ".join(value_fragments) - user_ids = user_ids or [] - user_id_fragments = [ - cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids - ] - user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" - project_keys = project_keys or [] - project_key_fragments = [ - cur.mogrify("%s", (project_key,)).decode("utf-8") - for project_key in project_keys - ] - project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]" + user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur) cur.execute( f""" INSERT INTO public.roles ({column_clause}) @@ -203,18 +186,6 @@ def create_provider_resource( WHERE users.user_id = ANY({user_id_clause}) """ ) - cur.execute( - f""" - WITH ps AS ( - SELECT * - FROM public.projects - WHERE projects.project_key = ANY({project_key_clause}) - ) - INSERT INTO public.roles_projects (role_id, project_id) - SELECT {role_id}, ps.project_id - FROM ps - """ - ) cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1") return cur.fetchone() @@ -223,7 +194,6 @@ def _update_resource_sql( resource_id: int, tenant_id: int, user_ids: list[int] | None = None, - project_keys: list[str] | None = None, **kwargs: dict[str, Any], ) -> dict[str, Any]: with pg_client.PostgresClient() as cur: @@ -233,17 +203,7 @@ def _update_resource_sql( for k, v in kwargs.items() ] set_clause = ", ".join(set_fragments) - user_ids = user_ids or [] - user_id_fragments = [ - cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids - ] - user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]" - project_keys = project_keys or [] - project_key_fragments = [ - cur.mogrify("%s", (project_key,)).decode("utf-8") - for project_key in project_keys - ] - project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]" + user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur) cur.execute( f""" UPDATE public.users @@ -268,27 +228,6 @@ def _update_resource_sql( RETURNING * """ ) - cur.execute( - f""" - DELETE FROM public.roles_projects - USING public.projects - WHERE - projects.project_id = roles_projects.project_id - AND roles_projects.role_id = {resource_id} - AND projects.project_key != ALL({project_key_clause}) - """ - ) - cur.execute( - f""" - INSERT INTO public.roles_projects (role_id, project_id) - SELECT {resource_id}, projects.project_id - FROM public.projects - LEFT JOIN public.roles_projects USING (project_id) - WHERE - projects.project_key = ANY({project_key_clause}) - AND roles_projects.role_id IS NULL - """ - ) cur.execute( f""" UPDATE public.roles diff --git a/ee/api/routers/scim/helpers.py b/ee/api/routers/scim/helpers.py index ebf5f1b67..bb6c56fec 100644 --- a/ee/api/routers/scim/helpers.py +++ b/ee/api/routers/scim/helpers.py @@ -1,6 +1,18 @@ -from typing import Any +from typing import Any, Literal from copy import deepcopy import re +from chalicelib.utils import pg_client + + +def safe_mogrify_array( + items: list[Any] | None, + array_type: Literal["varchar", "int"], + cursor: pg_client.PostgresClient, +) -> str: + items = items or [] + fragments = [cursor.mogrify("%s", (item,)).decode("utf-8") for item in items] + result = f"ARRAY[{', '.join(fragments)}]::{array_type}[]" + return result def convert_query_str_to_list(query_str: str | None) -> list[str]: diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py index e1d67b58e..14639a083 100644 --- a/ee/api/routers/scim/users.py +++ b/ee/api/routers/scim/users.py @@ -1,6 +1,7 @@ from typing import Any from datetime import datetime from psycopg2.extensions import AsIs +from routers.scim import helpers from chalicelib.utils import pg_client from routers.scim.resource_config import ( @@ -10,6 +11,11 @@ from routers.scim.resource_config import ( ClientInput, ProviderInput, ) +from schemas.schemas_ee import ValidIdentityProviderPermissions + + +def _is_valid_permission_for_identity_provider(permission: str) -> bool: + return ValidIdentityProviderPermissions.has_value(permission) def convert_client_resource_update_input_to_provider_resource_update_input( @@ -28,6 +34,14 @@ def convert_client_resource_update_input_to_provider_resource_update_input( result["internal_id"] = client_input["externalId"] if "active" in client_input: result["deleted_at"] = None if client_input["active"] else datetime.now() + if "projectKeys" in client_input: + result["project_keys"] = [item["value"] for item in client_input["projectKeys"]] + if "entitlements" in client_input: + result["permissions"] = [ + item + for item in client_input["entitlements"] + if _is_valid_permission_for_identity_provider(item) + ] return result @@ -53,6 +67,12 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input( "email": client_input["userName"], "internal_id": client_input.get("externalId"), "name": name, + "project_keys": [item for item in client_input.get("projectKeys", [])], + "permissions": [ + item + for item in client_input.get("entitlements", []) + if _is_valid_permission_for_identity_provider(item) + ], } result = {k: v for k, v in result.items() if v is not None} return result @@ -80,6 +100,12 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input( "email": client_input["userName"], "internal_id": client_input.get("externalId"), "name": name, + "project_keys": [item["value"] for item in client_input.get("projectKeys", [])], + "permissions": [ + item + for item in client_input.get("entitlements", []) + if _is_valid_permission_for_identity_provider(item) + ], } result = {k: v for k, v in result.items() if v is not None} return result @@ -223,11 +249,57 @@ def get_provider_resource( return cur.fetchone() +def _update_role_projects_and_permissions( + role_id: int | None, + project_keys: list[str] | None, + permissions: list[str] | None, + cur: pg_client.PostgresClient, +) -> None: + all_projects = "true" if not project_keys else "false" + project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur) + permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur) + cur.execute( + f""" + UPDATE public.roles + SET + updated_at = now(), + all_projects = {all_projects}, + permissions = {permission_clause} + WHERE role_id = {role_id} + RETURNING * + """ + ) + cur.execute( + f""" + DELETE FROM public.roles_projects + USING public.projects + WHERE + projects.project_id = roles_projects.project_id + AND roles_projects.role_id = {role_id} + AND projects.project_key != ALL({project_key_clause}) + """ + ) + cur.execute( + f""" + INSERT INTO public.roles_projects (role_id, project_id) + SELECT {role_id}, projects.project_id + FROM public.projects + LEFT JOIN public.roles_projects USING (project_id) + WHERE + projects.project_key = ANY({project_key_clause}) + AND roles_projects.role_id IS NULL + RETURNING * + """ + ) + + def create_provider_resource( email: str, tenant_id: int, name: str = "", internal_id: str | None = None, + project_keys: list[str] | None = None, + permissions: list[str] | None = None, ) -> ProviderResource: with pg_client.PostgresClient() as cur: cur.execute( @@ -259,7 +331,11 @@ def create_provider_resource( }, ) ) - return cur.fetchone() + user = cur.fetchone() + _update_role_projects_and_permissions( + user["role_id"], project_keys, permissions, cur + ) + return user def restore_provider_resource( @@ -267,6 +343,8 @@ def restore_provider_resource( email: str, name: str = "", internal_id: str | None = None, + project_keys: list[str] | None = None, + permissions: list[str] | None = None, **kwargs: dict[str, Any], ) -> ProviderResource: with pg_client.PostgresClient() as cur: @@ -300,7 +378,42 @@ def restore_provider_resource( }, ) ) - return cur.fetchone() + user = cur.fetchone() + _update_role_projects_and_permissions( + user["role_id"], project_keys, permissions, cur + ) + return user + + +def _update_resource_sql( + resource_id: int, + tenant_id: int, + project_keys: list[str] | None = None, + permissions: list[str] | None = None, + **kwargs: dict[str, Any], +) -> dict[str, Any]: + with pg_client.PostgresClient() as cur: + kwargs["updated_at"] = datetime.now() + set_fragments = [ + cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8") + for k, v in kwargs.items() + ] + set_clause = ", ".join(set_fragments) + cur.execute( + f""" + UPDATE public.users + SET {set_clause} + WHERE + users.user_id = {resource_id} + AND users.tenant_id = {tenant_id} + AND users.deleted_at IS NULL + RETURNING * + """ + ) + user = cur.fetchone() + role_id = user["role_id"] + _update_role_projects_and_permissions(role_id, project_keys, permissions, cur) + return user def rewrite_provider_resource( @@ -309,37 +422,18 @@ def rewrite_provider_resource( email: str, name: str = "", internal_id: str | None = None, -): - with pg_client.PostgresClient() as cur: - cur.execute( - cur.mogrify( - """ - WITH u AS ( - UPDATE public.users - SET - email = %(email)s, - name = %(name)s, - internal_id = %(internal_id)s, - updated_at = now() - WHERE - users.user_id = %(user_id)s - AND users.tenant_id = %(tenant_id)s - AND users.deleted_at IS NULL - RETURNING * - ) - SELECT * - FROM u - """, - { - "tenant_id": tenant_id, - "user_id": resource_id, - "email": email, - "name": name, - "internal_id": internal_id, - }, - ) - ) - return cur.fetchone() + project_keys: list[str] | None = None, + permissions: list[str] | None = None, +) -> dict[str, Any]: + return _update_resource_sql( + resource_id, + tenant_id, + email=email, + name=name, + internal_id=internal_id, + project_keys=project_keys, + permissions=permissions, + ) def update_provider_resource( @@ -347,29 +441,4 @@ def update_provider_resource( tenant_id: int, **kwargs, ): - with pg_client.PostgresClient() as cur: - set_fragments = [] - kwargs["updated_at"] = datetime.now() - for k, v in kwargs.items(): - fragment = cur.mogrify( - "%s = %s", - (AsIs(k), v), - ).decode("utf-8") - set_fragments.append(fragment) - set_clause = ", ".join(set_fragments) - cur.execute( - f""" - WITH u AS ( - UPDATE public.users - SET {set_clause} - WHERE - users.user_id = {resource_id} - AND users.tenant_id = {tenant_id} - AND users.deleted_at IS NULL - RETURNING * - ) - SELECT * - FROM u - """ - ) - return cur.fetchone() + return _update_resource_sql(resource_id, tenant_id, **kwargs) diff --git a/ee/api/schemas/schemas_ee.py b/ee/api/schemas/schemas_ee.py index 394f88859..bde5c14d0 100644 --- a/ee/api/schemas/schemas_ee.py +++ b/ee/api/schemas/schemas_ee.py @@ -28,6 +28,14 @@ class ServicePermissions(str, Enum): READ_NOTES = "SERVICE_READ_NOTES" +class ValidIdentityProviderPermissions(str, Enum): + SESSION_REPLAY = "SESSION_REPLAY" + METRICS = "METRICS" + ASSIST_LIVE = "ASSIST_LIVE" + ASSIST_CALL = "ASSIST_CALL" + SPOT_PUBLIC = "SPOT_PUBLIC" + + class CurrentContext(schemas.CurrentContext): permissions: List[Union[Permissions, ServicePermissions]] = Field(...) service_account: bool = Field(default=False)