openreplay/ee/api/routers/scim/groups.py

327 lines
9.5 KiB
Python

from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import pg_client
from routers.scim.resource_config import (
ProviderResource,
ClientResource,
ResourceId,
ClientInput,
ProviderInput,
)
def convert_client_resource_update_input_to_provider_resource_update_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
result = {}
if "displayName" in client_input:
result["name"] = client_input["displayName"]
if "members" in client_input:
members = client_input["members"] or []
result["user_ids"] = [int(member["value"]) for member in members]
return result
def convert_provider_resource_to_client_resource(
provider_resource: ProviderResource,
) -> ClientResource:
members = provider_resource["users"] or []
return {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"id": str(provider_resource["role_id"]),
"meta": {
"resourceType": "Group",
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": provider_resource["updated_at"].strftime(
"%Y-%m-%dT%H:%M:%SZ"
),
"location": f"Groups/{provider_resource['role_id']}",
},
"displayName": provider_resource["name"],
"members": [
{
"value": str(member["user_id"]),
"$ref": f"Users/{member['user_id']}",
"type": "User",
}
for member in members
],
}
def get_active_resource_count(tenant_id: int) -> int:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT COUNT(*)
FROM public.roles
WHERE
roles.tenant_id = %(tenant_id)s
AND roles.deleted_at IS NULL
""",
{"tenant_id": tenant_id},
)
)
return cur.fetchone()["count"]
def get_provider_resource_chunk(
offset: int, tenant_id: int, limit: int
) -> list[ProviderResource]:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
roles.*,
COALESCE(
(
SELECT json_agg(users)
FROM public.users
WHERE users.role_id = roles.role_id
),
'[]'
) AS users
FROM public.roles
WHERE
roles.tenant_id = %(tenant_id)s
AND roles.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{
"offset": offset,
"limit": limit,
"tenant_id": tenant_id,
},
)
)
return cur.fetchall()
def get_provider_resource(
resource_id: ResourceId, tenant_id: int
) -> ProviderResource | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
roles.*,
COALESCE(
(
SELECT json_agg(users)
FROM public.users
WHERE users.role_id = roles.role_id
),
'[]'
) AS users
FROM public.roles
WHERE
roles.tenant_id = %(tenant_id)s
AND roles.role_id = %(resource_id)s
AND roles.deleted_at IS NULL
LIMIT 1;
""",
{"resource_id": resource_id, "tenant_id": tenant_id},
)
)
return cur.fetchone()
def get_provider_resource_from_unique_fields(
**kwargs: dict[str, Any],
) -> ProviderResource | None:
# note(jon): we do not really use this for scim.groups (openreplay.roles) as we don't have unique values outside
# of the primary key
return None
def convert_client_resource_creation_input_to_provider_resource_creation_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
return {
"name": client_input["displayName"],
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
}
def convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
tenant_id: int, client_input: ClientInput
) -> ProviderInput:
return {
"name": client_input["displayName"],
"user_ids": [
int(member["value"]) for member in client_input.get("members", [])
],
}
def create_provider_resource(
name: str,
tenant_id: int,
user_ids: list[str] | None = None,
**kwargs: dict[str, Any],
) -> ProviderResource:
with pg_client.PostgresClient() as cur:
kwargs["name"] = name
kwargs["tenant_id"] = tenant_id
column_fragments = [
cur.mogrify("%s", (AsIs(k),)).decode("utf-8") for k in kwargs.keys()
]
column_clause = ", ".join(column_fragments)
value_fragments = [
cur.mogrify("%s", (v,)).decode("utf-8") for v in kwargs.values()
]
value_clause = ", ".join(value_fragments)
user_ids = user_ids or []
user_id_fragments = [
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
cur.execute(
f"""
WITH
r AS (
INSERT INTO public.roles ({column_clause})
VALUES ({value_clause})
RETURNING *
),
linked_users AS (
UPDATE public.users
SET
updated_at = now(),
role_id = (SELECT r.role_id FROM r)
WHERE users.user_id = ANY({user_id_clause})
RETURNING *
)
SELECT
r.*,
COALESCE(
(
SELECT json_agg(linked_users.*)
FROM linked_users
),
'[]'
) AS users
FROM r
LIMIT 1;
"""
)
return cur.fetchone()
def _update_resource_sql(
resource_id: int,
tenant_id: int,
user_ids: list[int] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
with pg_client.PostgresClient() as cur:
kwargs["updated_at"] = datetime.now()
set_fragments = [
cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8")
for k, v in kwargs.items()
]
set_clause = ", ".join(set_fragments)
user_ids = user_ids or []
user_id_fragments = [
cur.mogrify("%s", (user_id,)).decode("utf-8") for user_id in user_ids
]
user_id_clause = f"ARRAY[{', '.join(user_id_fragments)}]::int[]"
cur.execute(
f"""
UPDATE public.users
SET role_id = NULL
WHERE users.role_id = {resource_id}
"""
)
cur.execute(
f"""
WITH
r AS (
UPDATE public.roles
SET {set_clause}
WHERE
roles.role_id = {resource_id}
AND roles.tenant_id = {tenant_id}
AND roles.deleted_at IS NULL
RETURNING *
),
linked_users AS (
UPDATE public.users
SET
updated_at = now(),
role_id = {resource_id}
WHERE users.user_id = ANY({user_id_clause})
RETURNING *
)
SELECT
r.*,
COALESCE(
(
SELECT json_agg(linked_users.*)
FROM linked_users
),
'[]'
) AS users
FROM r
LIMIT 1;
"""
)
return cur.fetchone()
def delete_provider_resource(resource_id: ResourceId, tenant_id: int) -> None:
_update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
deleted_at=datetime.now(),
)
def restore_provider_resource(
resource_id: int,
tenant_id: int,
name: str,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return _update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
name=name,
created_at=datetime.now(),
deleted_at=None,
**kwargs,
)
def rewrite_provider_resource(
resource_id: int,
tenant_id: int,
name: str,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
return _update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
name=name,
**kwargs,
)
def update_provider_resource(
resource_id: int,
tenant_id: int,
**kwargs: dict[str, Any],
):
return _update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
**kwargs,
)