fix users.name error and handle role/project interactions

This commit is contained in:
Jonathan Griffin 2025-05-06 14:15:31 +02:00
parent c02f52b413
commit 891a0e31c1
4 changed files with 184 additions and 162 deletions

View file

@ -243,7 +243,7 @@ group_config = ResourceConfig(
get_provider_resource_chunk=groups.get_provider_resource_chunk,
get_provider_resource=groups.get_provider_resource,
convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input,
get_provider_resource_from_unique_fields=groups.get_provider_resource_from_unique_fields,
get_provider_resource_from_unique_fields=lambda **kwargs: None,
restore_provider_resource=None,
create_provider_resource=groups.create_provider_resource,
delete_provider_resource=groups.delete_provider_resource,
@ -382,7 +382,7 @@ async def delete_resource(
@public_app.put("/{resource_type}/{resource_id}")
async def put_resource(
resource_type: SCIMResource,
resource_id: str,
resource_id: int | str,
r: Request,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),
@ -424,7 +424,7 @@ async def put_resource(
@public_app.patch("/{resource_type}/{resource_id}")
async def patch_resource(
resource_type: SCIMResource,
resource_id: str,
resource_id: int | str,
r: Request,
tenant_id=Depends(auth_required),
attributes: list[str] | None = Query(None),

View file

@ -110,6 +110,29 @@
"returned": "default",
"uniqueness": "none"
},
{
"name": "projectKeys",
"type": "complex",
"multiValued": true,
"description": "A list of project keys associated with the group.",
"required": false,
"caseExact": false,
"mutability": "readWrite",
"returned": "default",
"subAttributes": [
{
"name": "value",
"type": "string",
"multiValued": false,
"description": "The unique project key.",
"required": true,
"mutability": "immutable",
"returned": "default",
"caseExact": true,
"uniqueness": "none"
}
]
},
{
"name": "members",
"type": "complex",

View file

@ -21,6 +21,8 @@ def convert_client_resource_update_input_to_provider_resource_update_input(
if "members" in client_input:
members = client_input["members"] or []
result["user_ids"] = [int(member["value"]) for member in members]
if "projectKeys" in client_input:
result["project_keys"] = [item["value"] for item in client_input["projectKeys"]]
return result
@ -48,6 +50,9 @@ def convert_provider_resource_to_client_resource(
}
for member in members
],
"projectKeys": [
{"value": project_key} for project_key in provider_resource["project_keys"]
],
}
@ -68,13 +73,15 @@ def get_active_resource_count(tenant_id: int) -> int:
return cur.fetchone()["count"]
def get_provider_resource_chunk(
offset: int, tenant_id: int, limit: int
) -> list[ProviderResource]:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
def _main_select_query(tenant_id: int, resource_id: int | None = None) -> str:
where_and_clauses = [
f"roles.tenant_id = {tenant_id}",
"roles.deleted_at IS NULL",
]
if resource_id is not None:
where_and_clauses.append(f"roles.role_id = {resource_id}")
where_clause = " AND ".join(where_and_clauses)
return f"""
SELECT
roles.*,
COALESCE(
@ -84,21 +91,27 @@ def get_provider_resource_chunk(
WHERE users.role_id = roles.role_id
),
'[]'
) AS users
) AS users,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM public.roles
WHERE
roles.tenant_id = %(tenant_id)s
AND roles.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{
"offset": offset,
"limit": limit,
"tenant_id": tenant_id,
},
)
)
WHERE {where_clause}
"""
def get_provider_resource_chunk(
offset: int, tenant_id: int, limit: int
) -> list[ProviderResource]:
query = _main_select_query(tenant_id)
with pg_client.PostgresClient() as cur:
cur.execute(f"{query} LIMIT {limit} OFFSET {offset}")
return cur.fetchall()
@ -106,40 +119,10 @@ def get_provider_resource(
resource_id: ResourceId, tenant_id: int
) -> ProviderResource | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
roles.*,
COALESCE(
(
SELECT json_agg(users)
FROM public.users
WHERE users.role_id = roles.role_id
),
'[]'
) AS users
FROM public.roles
WHERE
roles.tenant_id = %(tenant_id)s
AND roles.role_id = %(resource_id)s
AND roles.deleted_at IS NULL
LIMIT 1;
""",
{"resource_id": resource_id, "tenant_id": tenant_id},
)
)
cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1")
return cur.fetchone()
def get_provider_resource_from_unique_fields(
**kwargs: dict[str, Any],
) -> ProviderResource | None:
# note(jon): we do not really use this for scim.groups (openreplay.roles) as we don't have unique values outside
# of the primary key
return None
def convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
@ -148,6 +131,7 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input(
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
"project_keys": [item["value"] for item in client_input.get("projectKeys", [])],
}
@ -159,6 +143,7 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
"project_keys": [item["value"] for item in client_input.get("projectKeys", [])],
}
@ -166,6 +151,7 @@ def create_provider_resource(
name: str,
tenant_id: int,
user_ids: list[str] | None = None,
project_keys: list[str] | None = None,
**kwargs: dict[str, Any],
) -> ProviderResource:
with pg_client.PostgresClient() as cur:
@ -184,35 +170,42 @@ def create_provider_resource(
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
project_keys = project_keys or []
project_key_fragments = [
cur.mogrify("%s", (project_key,)).decode("utf-8")
for project_key in project_keys
]
project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]"
cur.execute(
f"""
WITH
r AS (
INSERT INTO public.roles ({column_clause})
VALUES ({value_clause})
RETURNING *
),
linked_users AS (
RETURNING role_id
"""
)
role_id = cur.fetchone()["role_id"]
cur.execute(
f"""
UPDATE public.users
SET
updated_at = now(),
role_id = (SELECT r.role_id FROM r)
role_id = {role_id}
WHERE users.user_id = ANY({user_id_clause})
RETURNING *
)
SELECT
r.*,
COALESCE(
(
SELECT json_agg(linked_users.*)
FROM linked_users
),
'[]'
) AS users
FROM r
LIMIT 1;
"""
)
cur.execute(
f"""
WITH ps AS (
SELECT *
FROM public.projects
WHERE projects.project_key = ANY({project_key_clause})
)
INSERT INTO public.roles_projects (role_id, project_id)
SELECT {role_id}, ps.project_id
FROM ps
"""
)
cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1")
return cur.fetchone()
@ -220,6 +213,7 @@ def _update_resource_sql(
resource_id: int,
tenant_id: int,
user_ids: list[int] | None = None,
project_keys: list[str] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
with pg_client.PostgresClient() as cur:
@ -234,46 +228,68 @@ def _update_resource_sql(
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
project_keys = project_keys or []
project_key_fragments = [
cur.mogrify("%s", (project_key,)).decode("utf-8")
for project_key in project_keys
]
project_key_clause = f"ARRAY[{', '.join(project_key_fragments)}]::varchar[]"
cur.execute(
f"""
UPDATE public.users
SET role_id = NULL
WHERE users.role_id = {resource_id}
SET
updated_at = now(),
role_id = NULL
WHERE
users.role_id = {resource_id}
AND users.user_id != ALL({user_id_clause})
RETURNING *
"""
)
cur.execute(
f"""
UPDATE public.users
SET
updated_at = now(),
role_id = {resource_id}
WHERE
(users.role_id != {resource_id} OR users.role_id IS NULL)
AND users.user_id = ANY({user_id_clause})
RETURNING *
"""
)
cur.execute(
f"""
DELETE FROM public.roles_projects
USING public.projects
WHERE
projects.project_id = roles_projects.project_id
AND roles_projects.role_id = {resource_id}
AND projects.project_key != ALL({project_key_clause})
"""
)
cur.execute(
f"""
INSERT INTO public.roles_projects (role_id, project_id)
SELECT {resource_id}, projects.project_id
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE
projects.project_key = ANY({project_key_clause})
AND roles_projects.role_id IS NULL
"""
)
cur.execute(
f"""
WITH
r AS (
UPDATE public.roles
SET {set_clause}
WHERE
roles.role_id = {resource_id}
AND roles.tenant_id = {tenant_id}
AND roles.deleted_at IS NULL
RETURNING *
),
linked_users AS (
UPDATE public.users
SET
updated_at = now(),
role_id = {resource_id}
WHERE users.user_id = ANY({user_id_clause})
RETURNING *
)
SELECT
r.*,
COALESCE(
(
SELECT json_agg(linked_users.*)
FROM linked_users
),
'[]'
) AS users
FROM r
LIMIT 1;
"""
)
cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1")
return cur.fetchone()
@ -285,22 +301,6 @@ def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None:
)
def restore_provider_resource(
resource_id: int,
tenant_id: int,
name: str,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return _update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
name=name,
created_at=datetime.now(),
deleted_at=None,
**kwargs,
)
def rewrite_provider_resource(
resource_id: int,
tenant_id: int,

View file

@ -3,7 +3,6 @@ from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import pg_client
from chalicelib.core import roles
from routers.scim.resource_config import (
ProviderResource,
ClientResource,
@ -35,8 +34,6 @@ def convert_client_resource_update_input_to_provider_resource_update_input(
def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
name = client_input.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
@ -50,6 +47,8 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
if x
]
)
if not name:
name = client_input.get("displayName")
result = {
"email": client_input["userName"],
"internal_id": client_input.get("externalId"),
@ -62,8 +61,6 @@ def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
def convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
name = client_input.get("name", {}).get("formatted")
if not name:
name = " ".join(
[
x
@ -77,6 +74,8 @@ def convert_client_resource_creation_input_to_provider_resource_creation_input(
if x
]
)
if not name:
name = client_input.get("displayName")
result = {
"email": client_input["userName"],
"internal_id": client_input.get("externalId"),