added new fields to user endpoints

This commit is contained in:
Jonathan Griffin 2025-04-22 09:00:47 +02:00
parent d0a1e894d6
commit ebeff746cb
2 changed files with 221 additions and 144 deletions

View file

@ -162,37 +162,6 @@ def reset_member(tenant_id, editor_id, user_id_to_update):
return {"data": {"invitationLink": generate_new_invitation(user_id_to_update)}} return {"data": {"invitationLink": generate_new_invitation(user_id_to_update)}}
def update_scim_user(
user_id: int,
tenant_id: int,
email: str,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET email = %(email)s
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
RETURNING *
)
SELECT *
FROM u;
""",
{
"tenant_id": tenant_id,
"user_id": user_id,
"email": email,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def update(tenant_id, user_id, changes, output=True): def update(tenant_id, user_id, changes, output=True):
AUTH_KEYS = [ AUTH_KEYS = [
"password", "password",
@ -381,13 +350,39 @@ def get(user_id, tenant_id):
return helper.dict_to_camel_case(r) return helper.dict_to_camel_case(r)
def get_scim_users_paginated(start_index, tenant_id, count=None):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT
users.*,
roles.name AS role_name
FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{"offset": start_index - 1, "limit": count, "tenant_id": tenant_id},
)
)
r = cur.fetchall()
return helper.list_to_camel_case(r)
def get_scim_user_by_id(user_id, tenant_id): def get_scim_user_by_id(user_id, tenant_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
cur.execute( cur.execute(
cur.mogrify( cur.mogrify(
""" """
SELECT * SELECT
users.*,
roles.name AS role_name
FROM public.users FROM public.users
LEFT JOIN public.roles USING (role_id)
WHERE WHERE
users.user_id = %(user_id)s users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s AND users.tenant_id = %(tenant_id)s
@ -403,6 +398,140 @@ def get_scim_user_by_id(user_id, tenant_id):
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def create_scim_user(
email: str,
tenant_id: int,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
INSERT INTO public.users (
tenant_id,
email,
name,
internal_id,
role_id
)
VALUES (
%(tenant_id)s,
%(email)s,
%(name)s,
%(internal_id)s,
%(role_id)s
)
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def restore_scim_user(
user_id: int,
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s,
deleted_at = NULL,
created_at = default,
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.user_id = %(user_id)s
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"user_id": user_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def update_scim_user(
user_id: int,
tenant_id: int,
email: str,
name: str = "",
internal_id: str | None = None,
role_id: int | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
role_id = %(role_id)s
WHERE
users.user_id = %(user_id)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
RETURNING *
)
SELECT
u.*,
roles.name as role_name
FROM u LEFT JOIN public.roles USING (role_id);
""",
{
"tenant_id": tenant_id,
"user_id": user_id,
"email": email,
"name": name,
"internal_id": internal_id,
"role_id": role_id,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def generate_new_api_key(user_id): def generate_new_api_key(user_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
cur.execute( cur.execute(
@ -513,7 +642,7 @@ def edit_member(
return {"data": user} return {"data": user}
def get_existing_scim_user_by_unique_values(email): def get_existing_scim_user_by_unique_values_from_all_users(email):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
cur.execute( cur.execute(
cur.mogrify( cur.mogrify(
@ -558,26 +687,6 @@ def get_by_email_only(email):
return helper.dict_to_camel_case(r) return helper.dict_to_camel_case(r)
def get_users_paginated(start_index, tenant_id, count=None):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.users
WHERE
users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
LIMIT %(limit)s
OFFSET %(offset)s;
""",
{"offset": start_index - 1, "limit": count, "tenant_id": tenant_id},
)
)
r = cur.fetchall()
return helper.list_to_camel_case(r)
def get_member(tenant_id, user_id): def get_member(tenant_id, user_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
cur.execute( cur.execute(
@ -1093,41 +1202,6 @@ def create_sso_user(tenant_id, email, admin, name, origin, role_id, internal_id=
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def create_scim_user(
email,
name,
tenant_id,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
INSERT INTO public.users (
tenant_id,
email,
name
)
VALUES (
%(tenant_id)s,
%(email)s,
%(name)s
)
RETURNING *
)
SELECT *
FROM u;
""",
{
"tenant_id": tenant_id,
"email": email,
"name": name,
},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def soft_delete_scim_user_by_id(user_id, tenant_id): def soft_delete_scim_user_by_id(user_id, tenant_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
cur.execute( cur.execute(
@ -1314,35 +1388,6 @@ def restore_sso_user(
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def restore_scim_user(
user_id,
tenant_id,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
deleted_at = NULL,
created_at = default,
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.user_id = %(user_id)s
RETURNING *
)
SELECT *
FROM u;
""",
{"tenant_id": tenant_id, "user_id": user_id},
)
)
return helper.dict_to_camel_case(cur.fetchone())
def get_user_settings(user_id): def get_user_settings(user_id):
# read user settings from users.settings:jsonb column # read user settings from users.settings:jsonb column
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:

View file

@ -1,3 +1,4 @@
from copy import deepcopy
import logging import logging
from typing import Any from typing import Any
@ -7,7 +8,7 @@ from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from chalicelib.core import users, tenants from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import ( from chalicelib.utils.scim_auth import (
auth_optional, auth_optional,
auth_required, auth_required,
@ -171,13 +172,21 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
) )
@public_app.get("/Schemas/{schema_id}", dependencies=[Depends(auth_required)]) @public_app.get("/Schemas/{schema_id}")
async def get_schema(schema_id: str): async def get_schema(schema_id: str, tenant_id=Depends(auth_required)):
if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS: if schema_id not in SCHEMA_IDS_TO_SCHEMA_DETAILS:
return _not_found_error_response(schema_id) return _not_found_error_response(schema_id)
schema = deepcopy(SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id])
if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User":
db_roles = roles.get_roles(tenant_id)
role_names = [role["name"] for role in db_roles]
user_type_attribute = next(
filter(lambda x: x["name"] == "userType", schema["attributes"])
)
user_type_attribute["canonicalValues"] = role_names
return JSONResponse( return JSONResponse(
status_code=200, status_code=200,
content=SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id], content=schema,
) )
@ -205,7 +214,22 @@ async def get_service_provider_config(
MAX_USERS_PER_PAGE = 10 MAX_USERS_PER_PAGE = 10
def _convert_db_user_to_scim_user( def _parse_scim_user_input(data: dict[str, Any], tenant_id: str) -> dict[str, Any]:
role_id = None
if "userType" in data:
role = roles.get_role_by_name(tenant_id, data["userType"])
role_id = role["roleId"] if role else None
result = {
"email": data["userName"],
"internal_id": data.get("externalId"),
"name": data.get("name", {}).get("formatted") or data.get("displayName"),
"role_id": role_id,
}
result = {k: v for k, v in result.items() if v is not None}
return result
def _serialize_db_user_to_scim_user(
db_user: dict[str, Any], db_user: dict[str, Any],
attributes: list[str] | None = None, attributes: list[str] | None = None,
excluded_attributes: list[str] | None = None, excluded_attributes: list[str] | None = None,
@ -239,7 +263,8 @@ def _convert_db_user_to_scim_user(
"formatted": db_user["name"], "formatted": db_user["name"],
}, },
"displayName": db_user["name"] or db_user["email"], "displayName": db_user["name"] or db_user["email"],
"userType": db_user.get("roleName"),
"active": db_user["deletedAt"] is None,
} }
scim_user = scim_helpers.filter_attributes(scim_user, included_attributes) scim_user = scim_helpers.filter_attributes(scim_user, included_attributes)
scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes) scim_user = scim_helpers.exclude_attributes(scim_user, excluded_attributes)
@ -260,12 +285,12 @@ async def get_users(
) )
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count. # todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
# but this is the fastest thing at the moment just to test that it's working # but this is the fastest thing at the moment just to test that it's working
total_resources = users.get_users_paginated(1, tenant_id) total_resources = users.get_scim_users_paginated(1, tenant_id)
db_resources = users.get_users_paginated( db_resources = users.get_scim_users_paginated(
start_index, tenant_id, count=items_per_page start_index, tenant_id, count=items_per_page
) )
scim_resources = [ scim_resources = [
_convert_db_user_to_scim_user(resource, attributes, excluded_attributes) _serialize_db_user_to_scim_user(resource, attributes, excluded_attributes)
for resource in db_resources for resource in db_resources
] ]
return JSONResponse( return JSONResponse(
@ -289,7 +314,7 @@ async def get_user(
db_resource = users.get_scim_user_by_id(user_id, tenant_id) db_resource = users.get_scim_user_by_id(user_id, tenant_id)
if not db_resource: if not db_resource:
return _not_found_error_response(user_id) return _not_found_error_response(user_id)
scim_resource = _convert_db_user_to_scim_user( scim_resource = _serialize_db_user_to_scim_user(
db_resource, attributes, excluded_attributes db_resource, attributes, excluded_attributes
) )
return JSONResponse(status_code=200, content=scim_resource) return JSONResponse(status_code=200, content=scim_resource)
@ -297,26 +322,28 @@ async def get_user(
@public_app.post("/Users") @public_app.post("/Users")
async def create_user(r: Request, tenant_id=Depends(auth_required)): async def create_user(r: Request, tenant_id=Depends(auth_required)):
payload = await r.json() scim_payload = await r.json()
if "userName" not in payload: try:
db_payload = _parse_scim_user_input(scim_payload, tenant_id)
except KeyError:
return _invalid_value_error_response() return _invalid_value_error_response()
# note(jon): this method will return soft deleted users as well existing_db_resource = users.get_existing_scim_user_by_unique_values_from_all_users(
existing_db_resource = users.get_existing_scim_user_by_unique_values( db_payload["email"]
payload["userName"]
) )
if existing_db_resource and existing_db_resource["deletedAt"] is None: if existing_db_resource and existing_db_resource["deletedAt"] is None:
return _uniqueness_error_response() return _uniqueness_error_response()
if existing_db_resource and existing_db_resource["deletedAt"] is not None: if existing_db_resource and existing_db_resource["deletedAt"] is not None:
db_resource = users.restore_scim_user(existing_db_resource["userId"], tenant_id) db_resource = users.restore_scim_user(
user_id=existing_db_resource["userId"],
tenant_id=tenant_id,
**db_payload,
)
else: else:
db_resource = users.create_scim_user( db_resource = users.create_scim_user(
email=payload["userName"],
# note(jon): scim schema does not require the `name.formatted` attribute, but we require `name`.
# so, we have to define the value ourselves here
name="",
tenant_id=tenant_id, tenant_id=tenant_id,
**db_payload,
) )
scim_resource = _convert_db_user_to_scim_user(db_resource) scim_resource = _serialize_db_user_to_scim_user(db_resource)
response = JSONResponse(status_code=201, content=scim_resource) response = JSONResponse(status_code=201, content=scim_resource)
response.headers["Location"] = scim_resource["meta"]["location"] response.headers["Location"] = scim_resource["meta"]["location"]
return response return response
@ -327,22 +354,26 @@ async def update_user(user_id: str, r: Request, tenant_id=Depends(auth_required)
db_resource = users.get_scim_user_by_id(user_id, tenant_id) db_resource = users.get_scim_user_by_id(user_id, tenant_id)
if not db_resource: if not db_resource:
return _not_found_error_response(user_id) return _not_found_error_response(user_id)
current_scim_resource = _convert_db_user_to_scim_user(db_resource) current_scim_resource = _serialize_db_user_to_scim_user(db_resource)
changes = await r.json() requested_scim_changes = await r.json()
schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"] schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
try: try:
valid_mutable_changes = scim_helpers.filter_mutable_attributes( valid_mutable_scim_changes = scim_helpers.filter_mutable_attributes(
schema, changes, current_scim_resource schema, requested_scim_changes, current_scim_resource
) )
except ValueError: except ValueError:
return _mutability_error_response() return _mutability_error_response()
valid_mutable_db_changes = _parse_scim_user_input(
valid_mutable_scim_changes,
tenant_id,
)
try: try:
updated_db_resource = users.update_scim_user( updated_db_resource = users.update_scim_user(
user_id, user_id,
tenant_id, tenant_id,
email=valid_mutable_changes["userName"], **valid_mutable_db_changes,
) )
updated_scim_resource = _convert_db_user_to_scim_user(updated_db_resource) updated_scim_resource = _serialize_db_user_to_scim_user(updated_db_resource)
return JSONResponse(status_code=200, content=updated_scim_resource) return JSONResponse(status_code=200, content=updated_scim_resource)
except Exception: except Exception:
# note(jon): for now, this is the only error that would happen when updating the scim user # note(jon): for now, this is the only error that would happen when updating the scim user
@ -351,6 +382,7 @@ async def update_user(user_id: str, r: Request, tenant_id=Depends(auth_required)
@public_app.delete("/Users/{user_id}") @public_app.delete("/Users/{user_id}")
async def delete_user(user_id: str, tenant_id=Depends(auth_required)): async def delete_user(user_id: str, tenant_id=Depends(auth_required)):
# note(jon): this is a soft delete
db_resource = users.get_scim_user_by_id(user_id, tenant_id) db_resource = users.get_scim_user_by_id(user_id, tenant_id)
if not db_resource: if not db_resource:
return _not_found_error_response(user_id) return _not_found_error_response(user_id)