Fix pagination and implement all patch group methods

This commit is contained in:
Pavel Kim 2025-02-18 16:40:52 +01:00 committed by Jonathan Griffin
parent e13008c006
commit 937e4d244c
3 changed files with 209 additions and 211 deletions

View file

@ -199,8 +199,31 @@ def get_roles_with_uuid(tenant_id):
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
return helper.list_to_camel_case(rows)
def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
FROM public.roles
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
FROM roles_projects
INNER JOIN projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
WHERE tenant_id =%(tenant_id)s
AND data ? 'group_id'
AND deleted_at IS NULL
AND not service_role
AND name = COALESCE(%(name)s, name)
ORDER BY role_id
LIMIT %(count)s
OFFSET %(startIndex)s;""",
{"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count})
cur.execute(query=query)
rows = cur.fetchall()
return helper.list_to_camel_case(rows)
def get_role_by_name(tenant_id, name):
### "name" isn't unique in database
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT *
FROM public.roles
@ -303,4 +326,57 @@ def get_role_by_group_id(tenant_id, group_id):
row = cur.fetchone()
if row is not None:
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
return helper.dict_to_camel_case(row)
return helper.dict_to_camel_case(row)
def get_users_by_group_uuid(tenant_id, group_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT
u.user_id,
u.name,
u.data
FROM public.roles r
LEFT JOIN public.users u USING (role_id, tenant_id)
WHERE u.tenant_id = %(tenant_id)s
AND u.deleted_at IS NULL
AND r.data->>'group_id' = %(group_id)s
""",
{"tenant_id": tenant_id, "group_id": group_id})
cur.execute(query=query)
rows = cur.fetchall()
return helper.list_to_camel_case(rows)
def get_member_permissions(tenant_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT
r.permissions
FROM public.roles r
WHERE r.tenant_id = %(tenant_id)s
AND r.name = 'Member'
AND r.deleted_at IS NULL
""",
{"tenant_id": tenant_id})
cur.execute(query=query)
row = cur.fetchone()
return helper.dict_to_camel_case(row)
def remove_group_membership(tenant_id, group_id, user_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""WITH r AS (
SELECT role_id
FROM public.roles
WHERE data->>'group_id' = %(group_id)s
LIMIT 1
)
UPDATE public.users u
SET role_id= NULL
FROM r
WHERE u.data->>'user_id' = %(user_id)s
AND u.role_id = r.role_id
AND u.tenant_id = %(tenant_id)s
AND u.deleted_at IS NULL
RETURNING *;""",
{"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id})
cur.execute(query=query)
row = cur.fetchone()
return helper.dict_to_camel_case(row)

View file

@ -277,7 +277,7 @@ def get(user_id, tenant_id):
users.user_id = %(userId)s
AND users.tenant_id = %(tenant_id)s
AND users.deleted_at IS NULL
AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s)
AND (roles.role_id IS NULL OR roles.deleted_at IS NULL AND roles.tenant_id = %(tenant_id)s)
LIMIT 1;""",
{"userId": user_id, "tenant_id": tenant_id})
)
@ -318,7 +318,7 @@ def get_by_uuid(user_uuid, tenant_id):
)
r = cur.fetchone()
return helper.dict_to_camel_case(r)
def get_deleted_by_uuid(user_uuid, tenant_id):
with pg_client.PostgresClient() as cur:
cur.execute(
@ -481,37 +481,7 @@ def get_by_email_only(email):
r = cur.fetchone()
return helper.dict_to_camel_case(r)
def get_by_email_with_uuid(email):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
f"""SELECT
users.user_id,
users.tenant_id,
users.email,
users.role,
users.name,
users.data,
(CASE WHEN users.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
(CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
(CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member,
origin,
basic_authentication.password IS NOT NULL AS has_password,
role_id,
internal_id,
roles.name AS role_name
FROM public.users LEFT JOIN public.basic_authentication USING(user_id)
INNER JOIN public.roles USING(role_id)
WHERE users.email = %(email)s
AND users.deleted_at IS NULL
LIMIT 1;""",
{"email": email})
)
r = cur.fetchone()
return helper.dict_to_camel_case(r)
def get_users_paginated(start_index, count):
def get_users_paginated(start_index, count=None, email=None):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
@ -532,10 +502,12 @@ def get_users_paginated(start_index, count):
roles.name AS role_name
FROM public.users LEFT JOIN public.basic_authentication USING(user_id)
INNER JOIN public.roles USING(role_id)
WHERE users.deleted_at IS NULL AND users.data ? 'user_id'
WHERE users.deleted_at IS NULL
AND users.data ? 'user_id'
AND email = COALESCE(%(email)s, email)
LIMIT %(count)s
OFFSET %(startIndex)s;""",
{"startIndex": start_index - 1, "count": count})
OFFSET %(startIndex)s;;""",
{"startIndex": start_index - 1, "count": count, "email": email})
)
r = cur.fetchall()
if len(r):
@ -1313,52 +1285,6 @@ def restore_scim_user(
)
return helper.dict_to_camel_case(cur.fetchone())
def create_scim_user2(
tenant_id,
user_uuid,
username,
admin,
display_name,
full_name: dict,
emails,
origin,
locale,
role_id,
internal_id=None,
):
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""\
WITH u AS (
INSERT INTO public.users (tenant_id, email, role, name, data, origin, internal_id, role_id)
VALUES (%(tenant_id)s, %(email)s, %(role)s, %(name)s, %(data)s, %(origin)s, %(internal_id)s,
(SELECT COALESCE((SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND role_id = %(role_id)s),
(SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name = 'Member' LIMIT 1),
(SELECT role_id FROM roles WHERE tenant_id = %(tenant_id)s AND name != 'Owner' LIMIT 1))))
RETURNING *
),
au AS (
INSERT INTO public.basic_authentication(user_id)
VALUES ((SELECT user_id FROM u))
)
SELECT u.user_id AS id,
u.email,
u.role,
u.name,
u.data,
(CASE WHEN u.role = 'owner' THEN TRUE ELSE FALSE END) AS super_admin,
(CASE WHEN u.role = 'admin' THEN TRUE ELSE FALSE END) AS admin,
(CASE WHEN u.role = 'member' THEN TRUE ELSE FALSE END) AS member,
origin
FROM u;""",
{"tenant_id": tenant_id, "email": username, "internal_id": internal_id,
"role": "admin" if admin else "member", "name": display_name, "origin": origin,
"role_id": role_id, "data": json.dumps({"lastAnnouncementView": TimeUTC.now(), "user_id": user_uuid, "locale": locale, "name": full_name, "emails": emails})})
cur.execute(
query
)
return helper.dict_to_camel_case(cur.fetchone())
def get_user_settings(user_id):
# read user settings from users.settings:jsonb column
with pg_client.PostgresClient() as cur:

View file

@ -1,4 +1,5 @@
import logging
import re
import uuid
from typing import Optional
@ -13,39 +14,20 @@ from routers.base import get_routers
logger = logging.getLogger(__name__)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
# Authentication Dependency
def auth_required(authorization: str = Header(..., alias="Authorization")):
"""Dependency to check Authorization header."""
token = authorization.replace("Bearer ", "")
if token != config("OCTA_TOKEN"):
raise HTTPException(status_code=403, detail="Unauthorized")
return token
"""
Models:
USER
schemas -> hardcoded
id -> from db
userName -> email, comes from Okta
name:
givenName -> from Okta
middleName -> from Okta
familyName -> from Okta
emails:
primary -> from Okta
value -> from Okta
type -> from Okta
displayName -> from Okta (potentially, givenName+" "+familyName)
locale -> from Okta (e.g. en-US)
externalId -> from Okta
active -> ! doesn't exist, but represent deleted users
groups -> users: {"display": group.displayName, "value": group.id}
meta -> hardcoded
GROUP
schemas -> hardcoded
id -> from db
meta -> hardcoded
displayName -> from db
members -> users: {"display": user.userName, "value": user.id}
User endpoints
"""
class Name(BaseModel):
@ -61,7 +43,7 @@ class UserRequest(BaseModel):
schemas: list[str]
userName: str
name: Name
emails: list[Email] # ignore for now
emails: list[Email]
displayName: str
locale: str
externalId: str
@ -87,36 +69,19 @@ class PatchUserRequest(BaseModel):
Operations: list[dict]
# Authentication Dependency
def auth_required(authorization: str = Header(..., alias="Authorization")):
"""Dependency to check Authorization header."""
token = authorization.replace("Bearer ", "")
if token != config("OCTA_TOKEN"):
raise HTTPException(status_code=403, detail="Unauthorized")
return token
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.get("/Users", dependencies=[Depends(auth_required)])
async def get_users(
start_index: int = Query(1, alias="startIndex"),
count: Optional[int] = Query(1, alias="count"),
filter: Optional[str] = Query(None, alias="filter"),
count: Optional[int] = Query(None, alias="count"),
email: Optional[str] = Query(None, alias="filter"),
):
"""Get SCIM Users"""
if filter:
single_filter = filter.split(" ")
filter_value = single_filter[2].strip('"')
filtered_users = users.get_by_email_with_uuid(filter_value)
filtered_users = [filtered_users] if filtered_users else []
else:
filtered_users = users.get_users_paginated(start_index, count)
if email:
email = email.split(" ")[2].strip('"')
result_users = users.get_users_paginated(start_index, count, email)
serialized_users = []
for user in filtered_users:
logger.info(user)
for user in result_users:
serialized_users.append(
UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
@ -145,7 +110,8 @@ async def get_users(
@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
def get_user(user_id: str):
"""Get SCIM User"""
user = users.get_by_uuid(user_id, 1)
tenant_id = 1
user = users.get_by_uuid(user_id, tenant_id)
if not user:
return JSONResponse(
status_code=404,
@ -173,8 +139,8 @@ def get_user(user_id: str):
@public_app.post("/Users", dependencies=[Depends(auth_required)])
async def create_user(r: UserRequest):
## This needs to manage addition of previously deactivated users
"""Create SCIM User"""
tenant_id = 1
existing_user = users.get_by_email_only(r.userName)
deleted_user = users.get_deleted_user_by_email(r.userName)
@ -188,22 +154,21 @@ async def create_user(r: UserRequest):
}
)
elif deleted_user:
user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], 1)
user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id)
user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
else:
try:
# Need to handle groups later, for now ignore them
user = users.create_scim_user(tenant_id=1, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId) # role_id is set to 2 by default...
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"], # Transformed to camel case
id = user["data"]["userId"],
userName = r.userName,
name = r.name,
emails = r.emails,
@ -217,11 +182,11 @@ async def create_user(r: UserRequest):
@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)]) # insert your header later
@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)])
def update_user(user_id: str, r: UserRequest):
"""Update SCIM User"""
logger.info(r)
user = users.get_by_uuid(user_id, 1)
tenant_id = 1
user = users.get_by_uuid(user_id, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
@ -236,8 +201,7 @@ def update_user(user_id: str, r: UserRequest):
value_to_insert = v[0] if k == "emails" else v
changes["data"][k] = value_to_insert
try:
# Need to handle groups later, for now ignore them
users.update(1, user["userId"], changes)
users.update(tenant_id, user["userId"], changes)
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
@ -258,22 +222,23 @@ def update_user(user_id: str, r: UserRequest):
@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)])
def deactivate_user(user_id: str, r: PatchUserRequest):
logger.info(r)
"""Deactivate user, soft-delete"""
tenant_id = 1
active = r.model_dump(mode='json')["Operations"][0]["value"]["active"]
logger.info(active)
if active:
raise HTTPException(status_code=404, detail="Activating user is not supported")
user = users.get_by_uuid(user_id, 1)
user = users.get_by_uuid(user_id, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
logger.info(user)
users.delete_member_as_admin(1, user["userId"])
users.delete_member_as_admin(tenant_id, user["userId"])
return Response(status_code=204, content="")
@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)])
def delete_user(user_uuid: str):
user = users.get_by_uuid(user_uuid, 1)
"""Delete user from database, hard-delete"""
tenant_id = 1
user = users.get_by_uuid(user_uuid, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
@ -281,15 +246,9 @@ def delete_user(user_uuid: str):
return Response(status_code=204, content="")
"""
Group endpoints
Potential issues:
1. Every user can be assigned only to single role
2. Deleting the group might be constrained by existing users linked to the role,
since those can't be left orphans
3.
"""
class Operation(BaseModel):
@ -297,6 +256,13 @@ class Operation(BaseModel):
path: str = Field(default=None)
value: list[dict] | dict = Field(default=None)
class GroupGetResponse(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"])
totalResults: int
startIndex: int
itemsPerPage: int
resources: list = Field(alias="Resources")
class GroupRequest(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
displayName: str = Field(default=None)
@ -308,80 +274,111 @@ class GroupPatchRequest(BaseModel):
operations: list[Operation] = Field(alias="Operations")
class GroupResponse(BaseModel):
schemas: list[str]
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
id: str
meta: dict = Field(default={"resourceType": "Group"})
displayName: str
members: list
meta: dict = Field(default={"resourceType": "Group"})
@public_app.get("/Groups", dependencies=[Depends(auth_required)])
def get_groups(): # Might need to add query params later
groups = roles.get_roles_with_uuid(1)
def get_groups(
start_index: int = Query(1, alias="startIndex"),
count: Optional[int] = Query(None, alias="count"),
group_name: Optional[str] = Query(None, alias="filter"),
):
"""Get groups"""
tenant_id = 1
res = []
for group in groups:
res.append(GroupResponse(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group["data"]["groupId"],
displayName=group["name"],
members=[], # add later
).model_dump(mode='json'))
if group_name:
group_name = group_name.split(" ")[2].strip('"')
groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name)
res = [{
"id": group["data"]["groupId"],
"meta": {
"created": group["createdAt"],
"lastModified": "", # not currently a field
"version": "v1.0"
},
"displayName": group["name"]
} for group in groups
]
return JSONResponse(
status_code=200,
content=res
)
content=GroupGetResponse(
totalResults=len(groups),
startIndex=start_index,
itemsPerPage=len(groups),
Resources=res
).model_dump(mode='json'))
@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def get_group(group_id: str):
group = roles.get_role_by_group_id(1, group_id)
"""Get a group by id"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"])
members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members]
return JSONResponse(
status_code=200,
content=GroupResponse(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group["data"]["groupId"],
displayName=group["name"],
members=[], # add later
members=members,
).model_dump(mode='json'))
@public_app.post("/Groups", dependencies=[Depends(auth_required)])
def create_group(r: GroupRequest):
logger.info(r)
"""Create a group"""
tenant_id = 1
member_role = roles.get_member_permissions(tenant_id)
try:
data = schemas.RolePayloadSchema(name=r.displayName, permissions=[schemas.Permissions.SESSION_REPLAY]) # one permission for now
group = roles.create_as_admin(1, uuid.uuid4().hex, data)
data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"]) # permissions by default are same as for member role
group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
added_members = []
for member in r.members:
user = users.get_by_uuid(member["value"], tenant_id)
if user:
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
added_members.append({
"value": user["data"]["userId"],
"display": user["name"]
})
return JSONResponse(
status_code=200,
content=GroupResponse(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group["data"]["groupId"],
displayName=group["name"],
members=[], # add later
members=added_members,
).model_dump(mode='json'))
@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_put_group(group_id: str, r: GroupRequest):
# Possibly need to change GroupRequest object to accept a different structure
logger.info(r)
group = roles.get_role_by_group_id(1, group_id)
"""Update a group or members of the group (not used by anything yet)"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None:
roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"])
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
return Response(status_code=200, content="")
members = r.members
modified_members = []
for member in members:
user = users.get_by_uuid(member["value"], 1)
user = users.get_by_uuid(member["value"], tenant_id)
if user:
users.update(1, user["userId"], {"role_id": group["roleId"]})
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
modified_members.append({
"value": user["data"]["userId"],
"display": user["name"]
@ -390,44 +387,41 @@ def update_put_group(group_id: str, r: GroupRequest):
return JSONResponse(
status_code=200,
content=GroupResponse(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group_id,
displayName=group["name"],
members=modified_members,
).model_dump(mode='json'))
@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_patch_group(group_id: str, r: GroupPatchRequest):
logger.info(r)
group = roles.get_role_by_group_id(1, group_id)
"""Update a group or members of the group, used by AIW"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
if r.operations[0].op == "replace" and r.operations[0].path is None:
roles.update_group_name(1, group["data"]["groupId"], r.operations[0].value["displayName"])
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
return Response(status_code=200, content="")
if r.operations[0].op == "replace":
# find all members of that role, and for those that don't intersect with the list, set them to default role and return
pass
modified_members = []
for op in r.operations:
if op.op == "add":
if op.op == "add" or op.op == "replace":
# Both methods work as "replace"
for u in op.value:
user = users.get_by_uuid(u["value"], 1)
user = users.get_by_uuid(u["value"], tenant_id)
if user:
users.update(1, user["userId"], {"role_id": group["roleId"]})
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
modified_members.append({
"value": user["data"]["userId"],
"display": user["name"]
})
else:
# possibly remove by parsing the path?
pass
elif op.op == "remove":
user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1)
roles.remove_group_membership(tenant_id, group_id, user_id)
return JSONResponse(
status_code=200,
content=GroupResponse(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group_id,
displayName=group["name"],
members=modified_members,
@ -436,9 +430,11 @@ def update_patch_group(group_id: str, r: GroupPatchRequest):
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def delete_group(group_id: str):
group = roles.get_role_by_group_id(1, group_id)
"""Delete a group, hard-delete"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
roles.delete_scim_group(1, group["data"]["groupId"])
roles.delete_scim_group(tenant_id, group["data"]["groupId"])
return Response(status_code=200, content="")