reformat files and remove unnecessary imports
This commit is contained in:
parent
a8d36d40b5
commit
464b9b1b47
6 changed files with 588 additions and 300 deletions
|
|
@ -9,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC
|
|||
|
||||
def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify(f"""SELECT EXISTS(SELECT 1
|
||||
query = cur.mogrify(
|
||||
f"""SELECT EXISTS(SELECT 1
|
||||
FROM public.roles
|
||||
WHERE tenant_id = %(tenant_id)s
|
||||
AND name ILIKE %(name)s
|
||||
AND deleted_at ISNULL
|
||||
{"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
|
||||
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id})
|
||||
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row = cur.fetchone()
|
||||
return row["exists"]
|
||||
|
|
@ -27,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
|
|||
if not admin["admin"] and not admin["superAdmin"]:
|
||||
return {"errors": ["unauthorized"]}
|
||||
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
|
||||
)
|
||||
|
||||
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
||||
return {"errors": ["must specify a project or all projects"]}
|
||||
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
||||
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
|
||||
data.projects = projects.is_authorized_batch(
|
||||
project_ids=data.projects, tenant_id=tenant_id
|
||||
)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""SELECT 1
|
||||
query = cur.mogrify(
|
||||
"""SELECT 1
|
||||
FROM public.roles
|
||||
WHERE role_id = %(role_id)s
|
||||
AND tenant_id = %(tenant_id)s
|
||||
AND protected = TRUE
|
||||
LIMIT 1;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id})
|
||||
{"tenant_id": tenant_id, "role_id": role_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
if cur.fetchone() is not None:
|
||||
return {"errors": ["this role is protected"]}
|
||||
query = cur.mogrify("""UPDATE public.roles
|
||||
query = cur.mogrify(
|
||||
"""UPDATE public.roles
|
||||
SET name= %(name)s,
|
||||
description= %(description)s,
|
||||
permissions= %(permissions)s,
|
||||
|
|
@ -56,23 +65,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
|
|||
RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
|
||||
FROM roles_projects
|
||||
WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()})
|
||||
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row = cur.fetchone()
|
||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
||||
if not data.all_projects:
|
||||
d_projects = [i for i in row["projects"] if i not in data.projects]
|
||||
if len(d_projects) > 0:
|
||||
query = cur.mogrify("""DELETE FROM roles_projects
|
||||
query = cur.mogrify(
|
||||
"""DELETE FROM roles_projects
|
||||
WHERE role_id=%(role_id)s
|
||||
AND project_id IN %(project_ids)s""",
|
||||
{"role_id": role_id, "project_ids": tuple(d_projects)})
|
||||
{"role_id": role_id, "project_ids": tuple(d_projects)},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
n_projects = [i for i in data.projects if i not in row["projects"]]
|
||||
if len(n_projects) > 0:
|
||||
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
|
||||
query = cur.mogrify(
|
||||
f"""INSERT INTO roles_projects(role_id, project_id)
|
||||
VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
|
||||
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}})
|
||||
{
|
||||
"role_id": role_id,
|
||||
**{f"project_id_{i}": p for i, p in enumerate(n_projects)},
|
||||
},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row["projects"] = data.projects
|
||||
|
||||
|
|
@ -86,28 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
|
|||
return {"errors": ["unauthorized"]}
|
||||
|
||||
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
|
||||
)
|
||||
|
||||
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
|
||||
return {"errors": ["must specify a project or all projects"]}
|
||||
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
|
||||
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
|
||||
data.projects = projects.is_authorized_batch(
|
||||
project_ids=data.projects, tenant_id=tenant_id
|
||||
)
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
|
||||
query = cur.mogrify(
|
||||
"""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
|
||||
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
|
||||
RETURNING *;""",
|
||||
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
|
||||
"permissions": data.permissions, "all_projects": data.all_projects})
|
||||
{
|
||||
"tenant_id": tenant_id,
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"permissions": data.permissions,
|
||||
"all_projects": data.all_projects,
|
||||
},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row = cur.fetchone()
|
||||
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
|
||||
row["projects"] = []
|
||||
if not data.all_projects:
|
||||
role_id = row["role_id"]
|
||||
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
|
||||
query = cur.mogrify(
|
||||
f"""INSERT INTO roles_projects(role_id, project_id)
|
||||
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
|
||||
RETURNING project_id;""",
|
||||
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
|
||||
{
|
||||
"role_id": role_id,
|
||||
**{f"project_id_{i}": p for i, p in enumerate(data.projects)},
|
||||
},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row["projects"] = [r["project_id"] for r in cur.fetchall()]
|
||||
return helper.dict_to_camel_case(row)
|
||||
|
|
@ -115,7 +148,8 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
|
|||
|
||||
def get_roles(tenant_id):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
||||
query = cur.mogrify(
|
||||
"""SELECT roles.*, COALESCE(projects, '{}') AS projects
|
||||
FROM public.roles
|
||||
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
|
||||
FROM roles_projects
|
||||
|
|
@ -126,7 +160,8 @@ def get_roles(tenant_id):
|
|||
AND deleted_at IS NULL
|
||||
AND not service_role
|
||||
ORDER BY role_id;""",
|
||||
{"tenant_id": tenant_id})
|
||||
{"tenant_id": tenant_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
rows = cur.fetchall()
|
||||
for r in rows:
|
||||
|
|
@ -136,12 +171,14 @@ def get_roles(tenant_id):
|
|||
|
||||
def get_role_by_name(tenant_id, name):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""SELECT *
|
||||
query = cur.mogrify(
|
||||
"""SELECT *
|
||||
FROM public.roles
|
||||
WHERE tenant_id =%(tenant_id)s
|
||||
AND deleted_at IS NULL
|
||||
AND name ILIKE %(name)s;""",
|
||||
{"tenant_id": tenant_id, "name": name})
|
||||
{"tenant_id": tenant_id, "name": name},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row = cur.fetchone()
|
||||
if row is not None:
|
||||
|
|
@ -155,45 +192,53 @@ def delete(tenant_id, user_id, role_id):
|
|||
if not admin["admin"] and not admin["superAdmin"]:
|
||||
return {"errors": ["unauthorized"]}
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""SELECT 1
|
||||
query = cur.mogrify(
|
||||
"""SELECT 1
|
||||
FROM public.roles
|
||||
WHERE role_id = %(role_id)s
|
||||
AND tenant_id = %(tenant_id)s
|
||||
AND protected = TRUE
|
||||
LIMIT 1;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id})
|
||||
{"tenant_id": tenant_id, "role_id": role_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
if cur.fetchone() is not None:
|
||||
return {"errors": ["this role is protected"]}
|
||||
query = cur.mogrify("""SELECT 1
|
||||
query = cur.mogrify(
|
||||
"""SELECT 1
|
||||
FROM public.users
|
||||
WHERE role_id = %(role_id)s
|
||||
AND tenant_id = %(tenant_id)s
|
||||
LIMIT 1;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id})
|
||||
{"tenant_id": tenant_id, "role_id": role_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
if cur.fetchone() is not None:
|
||||
return {"errors": ["this role is already attached to other user(s)"]}
|
||||
query = cur.mogrify("""UPDATE public.roles
|
||||
query = cur.mogrify(
|
||||
"""UPDATE public.roles
|
||||
SET deleted_at = timezone('utc'::text, now())
|
||||
WHERE role_id = %(role_id)s
|
||||
AND tenant_id = %(tenant_id)s
|
||||
AND protected = FALSE;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id})
|
||||
{"tenant_id": tenant_id, "role_id": role_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
return get_roles(tenant_id=tenant_id)
|
||||
|
||||
|
||||
def get_role(tenant_id, role_id):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
query = cur.mogrify("""SELECT roles.*
|
||||
query = cur.mogrify(
|
||||
"""SELECT roles.*
|
||||
FROM public.roles
|
||||
WHERE tenant_id =%(tenant_id)s
|
||||
AND deleted_at IS NULL
|
||||
AND not service_role
|
||||
AND role_id = %(role_id)s
|
||||
LIMIT 1;""",
|
||||
{"tenant_id": tenant_id, "role_id": role_id})
|
||||
{"tenant_id": tenant_id, "role_id": role_id},
|
||||
)
|
||||
cur.execute(query=query)
|
||||
row = cur.fetchone()
|
||||
if row is not None:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -13,8 +13,8 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
|
|||
ALGORITHM = config("SCIM_JWT_ALGORITHM")
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
|
||||
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
|
||||
AUDIENCE="okta_client"
|
||||
ISSUER=config("JWT_ISSUER"),
|
||||
AUDIENCE = "okta_client"
|
||||
ISSUER = (config("JWT_ISSUER"),)
|
||||
|
||||
# Simulated Okta Client Credentials
|
||||
# OKTA_CLIENT_ID = "okta-client"
|
||||
|
|
@ -23,7 +23,7 @@ ISSUER=config("JWT_ISSUER"),
|
|||
# class TokenRequest(BaseModel):
|
||||
# client_id: str
|
||||
# client_secret: str
|
||||
|
||||
|
||||
# async def authenticate_client(token_request: TokenRequest):
|
||||
# """Validate Okta Client Credentials and issue JWT"""
|
||||
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
|
||||
|
|
@ -31,6 +31,7 @@ ISSUER=config("JWT_ISSUER"),
|
|||
|
||||
# return {"access_token": create_jwt(), "token_type": "bearer"}
|
||||
|
||||
|
||||
def create_tokens(tenant_id):
|
||||
curr_time = time.time()
|
||||
access_payload = {
|
||||
|
|
@ -38,7 +39,7 @@ def create_tokens(tenant_id):
|
|||
"sub": "scim_server",
|
||||
"aud": AUDIENCE,
|
||||
"iss": ISSUER,
|
||||
"exp": ""
|
||||
"exp": "",
|
||||
}
|
||||
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
|
||||
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
|
@ -49,18 +50,24 @@ def create_tokens(tenant_id):
|
|||
|
||||
return access_token, refresh_token
|
||||
|
||||
|
||||
def verify_access_token(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
||||
payload = jwt.decode(
|
||||
token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
def verify_refresh_token(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
||||
payload = jwt.decode(
|
||||
token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
|
|
@ -69,6 +76,8 @@ def verify_refresh_token(token: str):
|
|||
|
||||
|
||||
required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
# Authentication Dependency
|
||||
def auth_required(token: str = Depends(required_oauth2_scheme)):
|
||||
"""Dependency to check Authorization header."""
|
||||
|
|
@ -78,6 +87,8 @@ def auth_required(token: str = Depends(required_oauth2_scheme)):
|
|||
|
||||
|
||||
optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def auth_optional(token: str | None = Depends(optional_oauth2_scheme)):
|
||||
if token is None:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
import copy
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -9,11 +7,15 @@ from decouple import config
|
|||
from fastapi import Depends, HTTPException, Header, Query, Response, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
import schemas
|
||||
from chalicelib.core import users, roles, tenants
|
||||
from chalicelib.utils.scim_auth import auth_optional, auth_required, create_tokens, verify_refresh_token
|
||||
from chalicelib.core import users, tenants
|
||||
from chalicelib.utils.scim_auth import (
|
||||
auth_optional,
|
||||
auth_required,
|
||||
create_tokens,
|
||||
verify_refresh_token,
|
||||
)
|
||||
from routers.base import get_routers
|
||||
from routers.scim_constants import RESOURCE_TYPES, SCHEMAS, SERVICE_PROVIDER_CONFIG
|
||||
from routers import scim_helpers
|
||||
|
|
@ -26,29 +28,41 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
|||
|
||||
"""Authentication endpoints"""
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
# Login endpoint to generate tokens
|
||||
@public_app.post("/token")
|
||||
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
async def login(
|
||||
host: str = Header(..., alias="Host"),
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
subdomain = host.split(".")[0]
|
||||
|
||||
# Missing authentication part, to add
|
||||
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
|
||||
if form_data.username != config("SCIM_USER") or form_data.password != config(
|
||||
"SCIM_PASSWORD"
|
||||
):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
tenant = tenants.get_by_name(subdomain)
|
||||
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
|
||||
|
||||
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
# Refresh token endpoint
|
||||
@public_app.post("/refresh")
|
||||
async def refresh_token(r: RefreshRequest):
|
||||
|
||||
payload = verify_refresh_token(r.refresh_token)
|
||||
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
|
||||
|
||||
|
|
@ -68,7 +82,7 @@ def _not_found_error_response(resource_id: str):
|
|||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
||||
"detail": f"Resource {resource_id} not found",
|
||||
"status": "404",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -80,7 +94,7 @@ def _uniqueness_error_response():
|
|||
"detail": "One or more of the attribute values are already in use or are reserved.",
|
||||
"status": "409",
|
||||
"scimType": "uniqueness",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -92,7 +106,7 @@ def _mutability_error_response():
|
|||
"detail": "The attempted modification is not compatible with the target attribute's mutability or current state.",
|
||||
"status": "400",
|
||||
"scimType": "mutability",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -105,7 +119,7 @@ async def get_resource_types(filter_param: str | None = Query(None, alias="filte
|
|||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
||||
"detail": "Operation is not permitted based on the supplied authorization",
|
||||
"status": "403",
|
||||
}
|
||||
},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
|
|
@ -130,8 +144,7 @@ async def get_resource_type(resource_id: str):
|
|||
|
||||
|
||||
SCHEMA_IDS_TO_SCHEMA_DETAILS = {
|
||||
schema_detail["id"]: schema_detail
|
||||
for schema_detail in SCHEMAS
|
||||
schema_detail["id"]: schema_detail for schema_detail in SCHEMAS
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -144,7 +157,7 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
|
|||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
|
||||
"detail": "Operation is not permitted based on the supplied authorization",
|
||||
"status": "403",
|
||||
}
|
||||
},
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
|
|
@ -154,9 +167,8 @@ async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
|
|||
"startIndex": 1,
|
||||
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
|
||||
"Resources": [
|
||||
value
|
||||
for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
|
||||
]
|
||||
value for _, value in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS.items())
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -174,7 +186,9 @@ async def get_schema(schema_id: str):
|
|||
# note(jon): it was recommended to make this endpoint partially open
|
||||
# so that clients can view the `authenticationSchemes` prior to being authenticated.
|
||||
@public_app.get("/ServiceProviderConfig")
|
||||
async def get_service_provider_config(r: Request, tenant_id: str | None = Depends(auth_optional)):
|
||||
async def get_service_provider_config(
|
||||
r: Request, tenant_id: str | None = Depends(auth_optional)
|
||||
):
|
||||
content = copy.deepcopy(SERVICE_PROVIDER_CONFIG)
|
||||
content["meta"]["location"] = str(r.url)
|
||||
is_authenticated = tenant_id is not None
|
||||
|
|
@ -193,6 +207,8 @@ async def get_service_provider_config(r: Request, tenant_id: str | None = Depend
|
|||
"""
|
||||
User endpoints
|
||||
"""
|
||||
|
||||
|
||||
class UserRequest(BaseModel):
|
||||
userName: str
|
||||
|
||||
|
|
@ -203,7 +219,9 @@ class PatchUserRequest(BaseModel):
|
|||
|
||||
|
||||
class ResourceMetaResponse(BaseModel):
|
||||
resourceType: Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None = None
|
||||
resourceType: (
|
||||
Literal["ServiceProviderConfig", "ResourceType", "Schema", "User"] | None
|
||||
) = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
|
|
@ -231,12 +249,16 @@ class CommonResourceResponse(BaseModel):
|
|||
|
||||
|
||||
class UserResponse(CommonResourceResponse):
|
||||
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = ["urn:ietf:params:scim:schemas:core:2.0:User"]
|
||||
schemas: list[Literal["urn:ietf:params:scim:schemas:core:2.0:User"]] = [
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
]
|
||||
userName: str | None = None
|
||||
|
||||
|
||||
class QueryResourceResponse(BaseModel):
|
||||
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = ["urn:ietf:params:scim:api:messages:2.0:ListResponse"]
|
||||
schemas: list[Literal["urn:ietf:params:scim:api:messages:2.0:ListResponse"]] = [
|
||||
"urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
]
|
||||
totalResults: int
|
||||
# todo(jon): add the other schemas
|
||||
Resources: list[UserResponse]
|
||||
|
|
@ -247,21 +269,33 @@ class QueryResourceResponse(BaseModel):
|
|||
MAX_USERS_PER_PAGE = 10
|
||||
|
||||
|
||||
def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str] | None = None, excluded_attributes: list[str] | None = None) -> UserResponse:
|
||||
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
|
||||
def _convert_db_user_to_scim_user(
|
||||
db_user: dict[str, Any],
|
||||
attributes: list[str] | None = None,
|
||||
excluded_attributes: list[str] | None = None,
|
||||
) -> UserResponse:
|
||||
user_schema = SCHEMA_IDS_TO_SCHEMA_DETAILS[
|
||||
"urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
]
|
||||
all_attributes = scim_helpers.get_all_attribute_names(user_schema)
|
||||
attributes = attributes or all_attributes
|
||||
always_returned_attributes = scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
|
||||
always_returned_attributes = (
|
||||
scim_helpers.get_all_attribute_names_where_returned_is_always(user_schema)
|
||||
)
|
||||
included_attributes = list(set(attributes).union(set(always_returned_attributes)))
|
||||
excluded_attributes = excluded_attributes or []
|
||||
excluded_attributes = list(set(excluded_attributes).difference(set(always_returned_attributes)))
|
||||
excluded_attributes = list(
|
||||
set(excluded_attributes).difference(set(always_returned_attributes))
|
||||
)
|
||||
scim_user = {
|
||||
"id": str(db_user["userId"]),
|
||||
"meta": {
|
||||
"resourceType": "User",
|
||||
"created": db_user["createdAt"],
|
||||
"lastModified": db_user["createdAt"], # todo(jon): we currently don't keep track of this in the db
|
||||
"location": f"Users/{db_user['userId']}"
|
||||
"lastModified": db_user[
|
||||
"createdAt"
|
||||
], # todo(jon): we currently don't keep track of this in the db
|
||||
"location": f"Users/{db_user['userId']}",
|
||||
},
|
||||
"userName": db_user["email"],
|
||||
}
|
||||
|
|
@ -272,14 +306,16 @@ def _convert_db_user_to_scim_user(db_user: dict[str, Any], attributes: list[str]
|
|||
|
||||
@public_app.get("/Users")
|
||||
async def get_users(
|
||||
tenant_id = Depends(auth_required),
|
||||
tenant_id=Depends(auth_required),
|
||||
requested_start_index: int = Query(1, alias="startIndex"),
|
||||
requested_items_per_page: int | None = Query(None, alias="count"),
|
||||
attributes: list[str] | None = Query(None),
|
||||
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
|
||||
):
|
||||
start_index = max(1, requested_start_index)
|
||||
items_per_page = min(max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE)
|
||||
items_per_page = min(
|
||||
max(0, requested_items_per_page or MAX_USERS_PER_PAGE), MAX_USERS_PER_PAGE
|
||||
)
|
||||
# todo(jon): this might not be the most efficient thing to do. could be better to just do a count.
|
||||
# but this is the fastest thing at the moment just to test that it's working
|
||||
total_users = users.get_users_paginated(1, tenant_id)
|
||||
|
|
@ -302,7 +338,7 @@ async def get_users(
|
|||
@public_app.get("/Users/{user_id}")
|
||||
def get_user(
|
||||
user_id: str,
|
||||
tenant_id = Depends(auth_required),
|
||||
tenant_id=Depends(auth_required),
|
||||
attributes: list[str] | None = Query(None),
|
||||
excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
|
||||
):
|
||||
|
|
@ -311,13 +347,12 @@ def get_user(
|
|||
return _not_found_error_response(user_id)
|
||||
scim_user = _convert_db_user_to_scim_user(db_user, attributes, excluded_attributes)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=scim_user.model_dump(mode="json", exclude_none=True)
|
||||
status_code=200, content=scim_user.model_dump(mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
|
||||
@public_app.post("/Users")
|
||||
async def create_user(r: UserRequest, tenant_id = Depends(auth_required)):
|
||||
async def create_user(r: UserRequest, tenant_id=Depends(auth_required)):
|
||||
# note(jon): this method will return soft deleted users as well
|
||||
existing_db_user = users.get_existing_scim_user_by_unique_values(r.userName)
|
||||
if existing_db_user and existing_db_user["deletedAt"] is None:
|
||||
|
|
@ -334,23 +369,26 @@ async def create_user(r: UserRequest, tenant_id = Depends(auth_required)):
|
|||
)
|
||||
scim_user = _convert_db_user_to_scim_user(db_user)
|
||||
response = JSONResponse(
|
||||
status_code=201,
|
||||
content=scim_user.model_dump(mode="json", exclude_none=True)
|
||||
status_code=201, content=scim_user.model_dump(mode="json", exclude_none=True)
|
||||
)
|
||||
response.headers["Location"] = scim_user.meta.location
|
||||
return response
|
||||
|
||||
|
||||
@public_app.put("/Users/{user_id}")
|
||||
def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required)):
|
||||
def update_user(user_id: str, r: UserRequest, tenant_id=Depends(auth_required)):
|
||||
db_resource = users.get_scim_user_by_id(user_id, tenant_id)
|
||||
if not db_resource:
|
||||
return _not_found_error_response(user_id)
|
||||
current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump(mode="json", exclude_none=True)
|
||||
current_scim_resource = _convert_db_user_to_scim_user(db_resource).model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
changes = r.model_dump(mode="json")
|
||||
schema = SCHEMA_IDS_TO_SCHEMA_DETAILS["urn:ietf:params:scim:schemas:core:2.0:User"]
|
||||
try:
|
||||
valid_mutable_changes = scim_helpers.filter_mutable_attributes(schema, changes, current_scim_resource)
|
||||
valid_mutable_changes = scim_helpers.filter_mutable_attributes(
|
||||
schema, changes, current_scim_resource
|
||||
)
|
||||
except ValueError:
|
||||
# todo(jon): will need to add a test for this once we have an immutable field
|
||||
return _mutability_error_response()
|
||||
|
|
@ -371,7 +409,7 @@ def update_user(user_id: str, r: UserRequest, tenant_id = Depends(auth_required)
|
|||
|
||||
|
||||
@public_app.delete("/Users/{user_id}")
|
||||
def delete_user(user_id: str, tenant_id = Depends(auth_required)):
|
||||
def delete_user(user_id: str, tenant_id=Depends(auth_required)):
|
||||
user = users.get_scim_user_by_id(user_id, tenant_id)
|
||||
if not user:
|
||||
return _not_found_error_response(user_id)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
# note(jon): please see https://datatracker.ietf.org/doc/html/rfc7643 for details on these constants
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _attribute_characteristics(
|
||||
name: str,
|
||||
description: str,
|
||||
type: str="string",
|
||||
sub_attributes: dict[str, Any] | None=None,
|
||||
# note(jon): no default for multiValued is defined in the docs and it is marked as optional.
|
||||
# from our side, we'll default it to False.
|
||||
multi_valued: bool=False,
|
||||
required: bool=False,
|
||||
canonical_values: list[str] | None=None,
|
||||
case_exact: bool=False,
|
||||
mutability: str="readWrite",
|
||||
returned: str="default",
|
||||
uniqueness: str="none",
|
||||
reference_types: list[str] | None=None,
|
||||
name: str,
|
||||
description: str,
|
||||
type: str = "string",
|
||||
sub_attributes: dict[str, Any] | None = None,
|
||||
# note(jon): no default for multiValued is defined in the docs and it is marked as optional.
|
||||
# from our side, we'll default it to False.
|
||||
multi_valued: bool = False,
|
||||
required: bool = False,
|
||||
canonical_values: list[str] | None = None,
|
||||
case_exact: bool = False,
|
||||
mutability: str = "readWrite",
|
||||
returned: str = "default",
|
||||
uniqueness: str = "none",
|
||||
reference_types: list[str] | None = None,
|
||||
):
|
||||
characteristics = {
|
||||
"name": name,
|
||||
|
|
@ -33,14 +33,16 @@ def _attribute_characteristics(
|
|||
"referenceTypes": reference_types,
|
||||
}
|
||||
characteristics_without_none = {
|
||||
key: value
|
||||
for key, value in characteristics.items()
|
||||
if value is not None
|
||||
key: value for key, value in characteristics.items() if value is not None
|
||||
}
|
||||
return characteristics_without_none
|
||||
|
||||
|
||||
def _multi_valued_attributes(type_canonical_values: list[str], type_required: bool=False, type_mutability="readWrite"):
|
||||
def _multi_valued_attributes(
|
||||
type_canonical_values: list[str],
|
||||
type_required: bool = False,
|
||||
type_mutability="readWrite",
|
||||
):
|
||||
return [
|
||||
_attribute_characteristics(
|
||||
name="type",
|
||||
|
|
@ -68,7 +70,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo
|
|||
name="$ref",
|
||||
type="reference",
|
||||
reference_types=["uri"],
|
||||
description="The reference URI of a target resource."
|
||||
description="The reference URI of a target resource.",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -77,7 +79,7 @@ def _multi_valued_attributes(type_canonical_values: list[str], type_required: bo
|
|||
# in section 3.1 of RFC7643, it is specified that ResourceType and
|
||||
# ServiceProviderConfig are not included in the common attributes. but
|
||||
# in other references, they treat them as a resource.
|
||||
def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none"):
|
||||
def _common_resource_attributes(id_required: bool = True, id_uniqueness: str = "none"):
|
||||
return [
|
||||
_attribute_characteristics(
|
||||
name="id",
|
||||
|
|
@ -151,7 +153,6 @@ def _common_resource_attributes(id_required: bool=True, id_uniqueness: str="none
|
|||
]
|
||||
|
||||
|
||||
|
||||
SERVICE_PROVIDER_CONFIG_SCHEMA = {
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
|
||||
"id": "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig",
|
||||
|
|
@ -339,7 +340,7 @@ SERVICE_PROVIDER_CONFIG_SCHEMA = {
|
|||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -409,9 +410,9 @@ RESOURCE_TYPE_SCHEMA = {
|
|||
required=True,
|
||||
mutability="readOnly",
|
||||
),
|
||||
]
|
||||
],
|
||||
),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
SCHEMA_SCHEMA = {
|
||||
|
|
@ -548,7 +549,7 @@ SCHEMA_SCHEMA = {
|
|||
canonical_values=[
|
||||
# todo(jon): add "User" and "Group" once those are done.
|
||||
"external",
|
||||
"uri"
|
||||
"uri",
|
||||
],
|
||||
case_exact=True,
|
||||
),
|
||||
|
|
@ -659,15 +660,15 @@ SCHEMA_SCHEMA = {
|
|||
canonical_values=[
|
||||
# todo(jon): add "User" and "Group" once those are done.
|
||||
"external",
|
||||
"uri"
|
||||
"uri",
|
||||
],
|
||||
case_exact=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
]
|
||||
],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -749,7 +750,7 @@ SERVICE_PROVIDER_CONFIG = {
|
|||
# and then updating these timestamps from an api and such. for now, if we update
|
||||
# the configuration, we should update the timestamp here.
|
||||
"lastModified": "2025-04-15T15:45:00Z",
|
||||
"location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint
|
||||
"location": "", # note(jon): this field will be computed in the /ServiceProviderConfig endpoint
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from copy import deepcopy
|
|||
|
||||
def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
|
||||
result = []
|
||||
|
||||
def _walk(attrs, prefix=None):
|
||||
for attr in attrs:
|
||||
name = attr["name"]
|
||||
|
|
@ -12,12 +13,16 @@ def get_all_attribute_names(schema: dict[str, Any]) -> list[str]:
|
|||
if attr["type"] == "complex":
|
||||
sub = attr.get("subAttributes") or attr.get("attributes") or []
|
||||
_walk(sub, path)
|
||||
|
||||
_walk(schema["attributes"])
|
||||
return result
|
||||
|
||||
|
||||
def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) -> list[str]:
|
||||
def get_all_attribute_names_where_returned_is_always(
|
||||
schema: dict[str, Any],
|
||||
) -> list[str]:
|
||||
result = []
|
||||
|
||||
def _walk(attrs, prefix=None):
|
||||
for attr in attrs:
|
||||
name = attr["name"]
|
||||
|
|
@ -27,11 +32,14 @@ def get_all_attribute_names_where_returned_is_always(schema: dict[str, Any]) ->
|
|||
if attr["type"] == "complex":
|
||||
sub = attr.get("subAttributes") or attr.get("attributes") or []
|
||||
_walk(sub, path)
|
||||
|
||||
_walk(schema["attributes"])
|
||||
return result
|
||||
|
||||
|
||||
def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict[str, Any]:
|
||||
def filter_attributes(
|
||||
resource: dict[str, Any], include_list: list[str]
|
||||
) -> dict[str, Any]:
|
||||
result = {}
|
||||
for attr in include_list:
|
||||
parts = attr.split(".", 1)
|
||||
|
|
@ -63,7 +71,9 @@ def filter_attributes(resource: dict[str, Any], include_list: list[str]) -> dict
|
|||
return result
|
||||
|
||||
|
||||
def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dict[str, Any]:
|
||||
def exclude_attributes(
|
||||
resource: dict[str, Any], exclude_list: list[str]
|
||||
) -> dict[str, Any]:
|
||||
exclude_map = {}
|
||||
for attr in exclude_list:
|
||||
parts = attr.split(".", 1)
|
||||
|
|
@ -105,7 +115,11 @@ def exclude_attributes(resource: dict[str, Any], exclude_list: list[str]) -> dic
|
|||
return new_resource
|
||||
|
||||
|
||||
def filter_mutable_attributes(schema: dict[str, Any], requested_changes: dict[str, Any], current: dict[str, Any]) -> dict[str, Any]:
|
||||
def filter_mutable_attributes(
|
||||
schema: dict[str, Any],
|
||||
requested_changes: dict[str, Any],
|
||||
current_values: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
attributes = {attr.get("name"): attr for attr in schema.get("attributes", [])}
|
||||
|
||||
valid_changes = {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue