openreplay/ee/api/routers/scim/api.py

540 lines
19 KiB
Python

import logging
from copy import deepcopy
from enum import Enum
from urllib.parse import urlencode
from chalicelib.utils import pg_client
from fastapi import Depends, HTTPException, Query, Response, Request
from fastapi.responses import JSONResponse, RedirectResponse
from psycopg2 import errors
from chalicelib.core import roles
from chalicelib.utils.scim_auth import (
auth_optional,
auth_required,
create_tokens,
verify_refresh_token,
)
from routers.base import get_routers
from routers.scim.constants import (
SERVICE_PROVIDER_CONFIG,
RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS,
SCHEMA_IDS_TO_SCHEMA_DETAILS,
)
from routers.scim import helpers, groups, users
from routers.scim.resource_config import ResourceConfig
from routers.scim import resource_config as api_helper
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
# Routers created for the SCIM v2 API prefix. Every route visible in this
# module is registered on ``public_app``; ``app`` and ``app_apikey`` are not
# used by the visible routes.
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.post("/token")
async def post_token(r: Request):
    """OAuth 2.0 token endpoint used by SCIM clients.

    The client authenticates with form fields ``client_id`` (the numeric
    tenant id) and ``client_secret`` (that tenant's ``tenant_key``). On
    success a new access/refresh token pair is issued for the tenant.

    Grant handling:
      * ``grant_type=refresh_token``: the submitted ``refresh_token`` is
        passed to ``verify_refresh_token``.
      * any other ``grant_type``: treated as an authorization-code exchange.
        ``code`` must be an unused code issued to this tenant, and it is
        consumed by this call.

    Returns:
        dict with ``access_token``, ``token_type`` ("Bearer"),
        ``expires_in`` and ``refresh_token``.

    Raises:
        HTTPException: 401 when the client credentials are missing/invalid or
            the authorization code is unknown, already used, or belongs to
            another tenant.
    """
    form = await r.form()
    client_secret = form.get("client_secret")
    try:
        tenant_id = int(form.get("client_id"))
    except (TypeError, ValueError):
        # TypeError covers a missing client_id (None) or a non-text form
        # value; ValueError covers non-numeric text. Neither can match a
        # tenant, so both are reported as bad credentials rather than a 500.
        raise HTTPException(status_code=401, detail="Invalid credentials")

    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                SELECT tenant_id
                FROM public.tenants
                WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
                """,
                {"tenant_id": tenant_id, "tenant_key": client_secret},
            )
        )
        tenant = cur.fetchone()
    if not tenant:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    if form.get("grant_type") == "refresh_token":
        # NOTE(review): the return value is ignored here; confirm that
        # verify_refresh_token rejects tokens issued to a different tenant.
        verify_refresh_token(form.get("refresh_token"))
    else:
        code = form.get("code")
        with pg_client.PostgresClient() as cur:
            # Consume the code in a single conditional UPDATE. Only a row that
            # is still unused can match, so two concurrent exchanges of the
            # same code cannot both succeed (a SELECT followed by a separate
            # UPDATE allowed that race).
            cur.execute(
                cur.mogrify(
                    """
                    UPDATE public.scim_auth_codes
                    SET used=TRUE
                    WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
                    RETURNING auth_code
                    """,
                    {"auth_code": code, "tenant_id": tenant_id},
                )
            )
            consumed_code = cur.fetchone()
        if consumed_code is None:
            raise HTTPException(status_code=401, detail="Invalid code/client_id pair")

    access_token, refresh_token, expires_in = create_tokens(
        tenant_id=tenant["tenant_id"]
    )
    return {
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": expires_in,
        "refresh_token": refresh_token,
    }
# note(jon): this might be specific to okta. if so, we should probably specify that in the endpoint
@public_app.get("/authorize")
async def get_authorize(
    r: Request,
    response_type: str,
    client_id: str,
    redirect_uri: str,
    state: str | None = None,
):
    """OAuth 2.0 authorization endpoint: issue a code and redirect back.

    Every previously issued code for tenant ``client_id`` is marked used, and
    a fresh one is inserted. Because ``/token`` only accepts codes with
    ``used IS FALSE``, only the newest code can be exchanged. The browser is
    then redirected to ``redirect_uri`` with ``code`` (and ``state``, when
    provided) added to its query string.

    Args:
        r: The incoming request (unused).
        response_type: Accepted but not inspected.
        client_id: Numeric tenant id, as text.
        redirect_uri: Where to send the client with the issued code.
        state: Opaque client value echoed back unchanged.

    Raises:
        HTTPException: 400 when ``client_id`` is not an integer.

    NOTE(review): this route has no authentication dependency, and
    ``redirect_uri`` is not checked against a registered value here. That
    allows invalidating any tenant's pending codes and redirecting to
    arbitrary hosts. Confirm whether this is mitigated elsewhere. The
    ``/token`` endpoint does require the tenant secret to redeem a code.
    """
    from urllib.parse import urlsplit, urlunsplit

    try:
        tenant_id = int(client_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid client_id")

    with pg_client.PostgresClient() as cur:
        # Invalidate any outstanding codes so only the one issued below is valid.
        cur.execute(
            cur.mogrify(
                """
                UPDATE public.scim_auth_codes
                SET used=TRUE
                WHERE tenant_id=%(tenant_id)s
                """,
                {"tenant_id": tenant_id},
            )
        )
        # NOTE(review): an unknown tenant_id likely fails here on a foreign
        # key (500); schema not visible — confirm.
        cur.execute(
            cur.mogrify(
                """
                INSERT INTO public.scim_auth_codes (tenant_id)
                VALUES (%(tenant_id)s)
                RETURNING auth_code
                """,
                {"tenant_id": tenant_id},
            )
        )
        code = cur.fetchone()["auth_code"]

    params = {"code": code}
    if state:
        params["state"] = state
    # Merge into any query string already present on redirect_uri (RFC 6749
    # §3.1.2 requires it to be retained) instead of appending a second "?".
    # Splitting the URL also keeps a fragment, if any, after the query.
    parts = urlsplit(redirect_uri)
    query = urlencode(params)
    if parts.query:
        query = f"{parts.query}&{query}"
    url = urlunsplit(parts._replace(query=query))
    return RedirectResponse(url)
def _not_found_error_response(resource_id: int | str):
    """Build a SCIM 404 error response for a missing resource.

    Args:
        resource_id: Identifier echoed in ``detail``. This may be a numeric
            resource id or a string id such as a resource-type or schema URN,
            so it is typed ``int | str``.
    """
    return JSONResponse(
        status_code=404,
        content={
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
            "detail": f"Resource {resource_id} not found",
            "status": "404",
        },
    )
def _uniqueness_error_response():
    """Build a SCIM 409 error with ``scimType`` ``uniqueness``.

    Returned when a create or update would collide with an existing resource.
    """
    error_body = {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
        "detail": "One or more of the attribute values are already in use or are reserved.",
        "status": "409",
        "scimType": "uniqueness",
    }
    return JSONResponse(content=error_body, status_code=409)
def _mutability_error_response():
    """Build a SCIM 400 error with ``scimType`` ``mutability``.

    Returned when a request tries to modify an attribute that may not be
    changed.
    """
    error_body = {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
        "detail": "The attempted modification is not compatible with the target attribute's mutability or current state.",
        "status": "400",
        "scimType": "mutability",
    }
    return JSONResponse(content=error_body, status_code=400)
def _operation_not_permitted_error_response():
    """Build a SCIM 403 error for operations this provider refuses.

    The body carries no ``scimType``.
    """
    error_body = {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
        "detail": "Operation is not permitted based on the supplied authorization",
        "status": "403",
    }
    return JSONResponse(content=error_body, status_code=403)
def _invalid_value_error_response():
    """Build a SCIM 400 error with ``scimType`` ``invalidValue``.

    Returned when a required value is missing or a supplied value does not fit
    the operation, attribute type or resource schema.
    """
    return JSONResponse(
        status_code=400,
        content={
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
            # Client-facing text; the typo "attribtue" was corrected.
            "detail": "A required value was missing, or the value specified was not compatible with the operation or attribute type, or resource schema.",
            "status": "400",
            "scimType": "invalidValue",
        },
    )
def _internal_server_error_response(detail: str):
    """Build a SCIM 500 error response.

    Args:
        detail: Text placed verbatim in the ``detail`` field. It is sent to
            the client, so it should not contain internal details.
    """
    error_body = {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
        "detail": detail,
        "status": "500",
    }
    return JSONResponse(content=error_body, status_code=500)
# note(jon): it was recommended to make this endpoint partially open
# so that clients can view the `authenticationSchemes` prior to being authenticated.
@public_app.get("/ServiceProviderConfig")
async def get_service_provider_config(
    r: Request, tenant_id: str | None = Depends(auth_optional)
):
    """Return the service provider configuration.

    Authenticated callers get the full ``SERVICE_PROVIDER_CONFIG``.
    Anonymous callers get only ``schemas``, ``authenticationSchemes`` and
    ``meta``, so they can discover how to authenticate.
    """
    if tenant_id is not None:
        return JSONResponse(status_code=200, content=SERVICE_PROVIDER_CONFIG)
    anonymous_keys = ("schemas", "authenticationSchemes", "meta")
    return JSONResponse(
        status_code=200,
        content={key: SERVICE_PROVIDER_CONFIG[key] for key in anonymous_keys},
    )
@public_app.get("/ResourceTypes", dependencies=[Depends(auth_required)])
async def get_resource_types(filter_param: str | None = Query(None, alias="filter")):
    """List every supported resource type as a single SCIM ListResponse page.

    Filtering is not supported here. Any ``filter`` query parameter yields a
    403 SCIM error instead.
    """
    if filter_param is None:
        resource_types = list(RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS.values())
        page_size = len(resource_types)
        return JSONResponse(
            status_code=200,
            content={
                "totalResults": page_size,
                "itemsPerPage": page_size,
                "startIndex": 1,
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
                "Resources": resource_types,
            },
        )
    return _operation_not_permitted_error_response()
@public_app.get("/ResourceTypes/{resource_id}", dependencies=[Depends(auth_required)])
async def get_resource_type(resource_id: str):
    """Return the details of one resource type, or a SCIM 404 if unknown."""
    try:
        details = RESOURCE_TYPE_IDS_TO_RESOURCE_TYPE_DETAILS[resource_id]
    except KeyError:
        return _not_found_error_response(resource_id)
    return JSONResponse(status_code=200, content=details)
@public_app.get("/Schemas", dependencies=[Depends(auth_required)])
async def get_schemas(filter_param: str | None = Query(None, alias="filter")):
    """List every schema, ordered by schema id, as one SCIM ListResponse page.

    Filtering is not supported here. Any ``filter`` query parameter yields a
    403 SCIM error instead.
    """
    if filter_param is not None:
        return _operation_not_permitted_error_response()
    # Schema ids are unique dict keys, so sorting the ids gives the same
    # order as sorting the (id, details) pairs.
    schemas = [
        SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]
        for schema_id in sorted(SCHEMA_IDS_TO_SCHEMA_DETAILS)
    ]
    schema_count = len(schemas)
    return JSONResponse(
        status_code=200,
        content={
            "totalResults": schema_count,
            "itemsPerPage": schema_count,
            "startIndex": 1,
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
            "Resources": schemas,
        },
    )
@public_app.get("/Schemas/{schema_id}")
async def get_schema(schema_id: str, tenant_id=Depends(auth_required)):
    """Return one schema definition, or a SCIM 404 if the id is unknown.

    For the core User schema, a copy of the definition is returned whose
    ``userType`` attribute lists the tenant's role names as
    ``canonicalValues``. The shared constant is never mutated.
    """
    try:
        template = SCHEMA_IDS_TO_SCHEMA_DETAILS[schema_id]
    except KeyError:
        return _not_found_error_response(schema_id)
    schema = deepcopy(template)
    if schema_id == "urn:ietf:params:scim:schemas:core:2.0:User":
        tenant_role_names = [role["name"] for role in roles.get_roles(tenant_id)]
        user_type_attribute = next(
            attribute
            for attribute in schema["attributes"]
            if attribute["name"] == "userType"
        )
        user_type_attribute["canonicalValues"] = tenant_role_names
    return JSONResponse(status_code=200, content=schema)
# Binds the generic SCIM handlers to the user implementations in
# ``routers.scim.users``. ``max_chunk_size`` caps the page size of list
# requests (see get_resources).
user_config = ResourceConfig(
    resource_type_id="User",
    max_chunk_size=10,
    filter_attribute_mapping=users.filter_attribute_mapping,
    # Read path (list / get).
    get_active_resource_count=users.get_active_resource_count,
    get_provider_resource_chunk=users.get_provider_resource_chunk,
    get_provider_resource=users.get_provider_resource,
    convert_provider_resource_to_client_resource=users.convert_provider_resource_to_client_resource,
    # Create (POST). A match found through the unique fields that has
    # ``deleted_at`` set is restored instead of re-created.
    convert_client_resource_creation_input_to_provider_resource_creation_input=users.convert_client_resource_creation_input_to_provider_resource_creation_input,
    get_provider_resource_from_unique_fields=users.get_provider_resource_from_unique_fields,
    restore_provider_resource=users.restore_provider_resource,
    create_provider_resource=users.create_provider_resource,
    # Full replace (PUT).
    convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=users.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
    rewrite_provider_resource=users.rewrite_provider_resource,
    # Partial update (PATCH).
    convert_client_resource_update_input_to_provider_resource_update_input=users.convert_client_resource_update_input_to_provider_resource_update_input,
    update_provider_resource=users.update_provider_resource,
    # Delete.
    delete_provider_resource=users.delete_provider_resource,
)
# Binds the generic SCIM handlers to the group implementations in
# ``routers.scim.groups``.
group_config = ResourceConfig(
    resource_type_id="Group",
    max_chunk_size=10,
    filter_attribute_mapping=groups.filter_attribute_mapping,
    # Read path (list / get).
    get_active_resource_count=groups.get_active_resource_count,
    get_provider_resource_chunk=groups.get_provider_resource_chunk,
    get_provider_resource=groups.get_provider_resource,
    convert_provider_resource_to_client_resource=groups.convert_provider_resource_to_client_resource,
    # Create (POST). The unique-field lookup always returns None, so
    # create_resource never takes its restore branch for groups. That is why
    # no restore function is provided.
    convert_client_resource_creation_input_to_provider_resource_creation_input=groups.convert_client_resource_creation_input_to_provider_resource_creation_input,
    get_provider_resource_from_unique_fields=lambda **kwargs: None,
    restore_provider_resource=None,
    create_provider_resource=groups.create_provider_resource,
    # Full replace (PUT).
    convert_client_resource_rewrite_input_to_provider_resource_rewrite_input=groups.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input,
    rewrite_provider_resource=groups.rewrite_provider_resource,
    # Partial update (PATCH).
    convert_client_resource_update_input_to_provider_resource_update_input=groups.convert_client_resource_update_input_to_provider_resource_update_input,
    update_provider_resource=groups.update_provider_resource,
    # Delete.
    delete_provider_resource=groups.delete_provider_resource,
)
# Maps the ``{resource_type}`` path segment to its handler configuration.
# The keys equal the SCIMResource member values. Because SCIMResource
# subclasses ``str``, its members can be used directly as lookup keys.
RESOURCE_TYPE_TO_RESOURCE_CONFIG: dict[str, ResourceConfig] = dict(
    Users=user_config,
    Groups=group_config,
)
class SCIMResource(str, Enum):
    """Resource collections served by the generic ``/{resource_type}`` routes.

    This is the path-parameter type of those routes, so only these values are
    accepted. The member values equal the keys of
    ``RESOURCE_TYPE_TO_RESOURCE_CONFIG``. Subclassing ``str`` lets a member be
    used directly as a key into that string-keyed dict.
    """

    USERS = "Users"
    GROUPS = "Groups"
@public_app.get("/{resource_type}")
async def get_resources(
    resource_type: SCIMResource,
    tenant_id=Depends(auth_required),
    requested_start_index_one_indexed: int = Query(1, alias="startIndex"),
    requested_items_per_page: int | None = Query(None, alias="count"),
    attributes: str | None = Query(None),
    excluded_attributes: str | None = Query(None, alias="excludedAttributes"),
    filter: str | None = Query(None),
):
    """List one page of resources of ``resource_type`` as a SCIM ListResponse.

    Args:
        resource_type: Users or Groups.
        tenant_id: Authenticated tenant, from ``auth_required``.
        requested_start_index_one_indexed: 1-based ``startIndex``. Values
            below 1 are treated as 1.
        requested_items_per_page: ``count``. When omitted, the config's
            ``max_chunk_size`` is used. Negative values are treated as 0, and
            values above ``max_chunk_size`` are capped. ``count=0`` returns
            no resources, only ``totalResults`` (RFC 7644 §3.4.2.4).
        attributes / excluded_attributes: Forwarded to the client-resource
            conversion to project the returned attributes.
        filter: SCIM filter expression, translated by
            ``helpers.scim_to_sql_where``.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    filter_clause = helpers.scim_to_sql_where(filter, config.filter_attribute_mapping())
    total_resources = config.get_active_resource_count(tenant_id, filter_clause)

    start_index_one_indexed = max(1, requested_start_index_one_indexed)
    offset = start_index_one_indexed - 1
    # Fall back to the chunk size only when count was omitted. A falsy
    # ``or`` fallback would wrongly turn an explicit count=0 into a full page.
    if requested_items_per_page is None:
        requested_items_per_page = config.max_chunk_size
    limit = min(max(0, requested_items_per_page), config.max_chunk_size)

    provider_resources = (
        config.get_provider_resource_chunk(offset, tenant_id, limit, filter_clause)
        if limit > 0
        else []
    )
    # NOTE(review): attributes/excluded_attributes are plain strings here but
    # list[str] on the single-resource endpoints, and both shapes reach
    # api_helper.convert_provider_resource_to_client_resource. Confirm which
    # one it expects.
    client_resources = [
        api_helper.convert_provider_resource_to_client_resource(
            config, provider_resource, attributes, excluded_attributes
        )
        for provider_resource in provider_resources
    ]
    return JSONResponse(
        status_code=200,
        content={
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
            "totalResults": total_resources,
            "startIndex": start_index_one_indexed,
            "itemsPerPage": len(client_resources),
            "Resources": client_resources,
        },
    )
@public_app.get("/{resource_type}/{resource_id}")
async def get_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Return a single resource in client (SCIM) form, or a SCIM 404.

    ``attributes`` and ``excluded_attributes`` are forwarded to
    ``api_helper.get_resource`` to project the returned attributes.
    """
    found = api_helper.get_resource(
        RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type],
        resource_id,
        tenant_id,
        attributes,
        excluded_attributes,
    )
    if found:
        return JSONResponse(status_code=200, content=found)
    return _not_found_error_response(resource_id)
@public_app.post("/{resource_type}")
async def create_resource(
    resource_type: SCIMResource,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Create a resource from the JSON body and return it with a 201.

    If a resource with the same unique fields already exists and is not
    deleted, a SCIM 409 ``uniqueness`` error is returned. If the match has
    ``deleted_at`` set, it is restored instead of inserting a new row. The
    response carries a ``Location`` header taken from ``meta.location``.

    Returns a SCIM 400 ``invalidValue`` error when the payload cannot be
    converted (the converter raises KeyError).
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    payload = await r.json()
    try:
        provider_resource_input = config.convert_client_resource_creation_input_to_provider_resource_creation_input(
            tenant_id,
            payload,
        )
    except KeyError:
        return _invalid_value_error_response()

    existing_provider_resource = config.get_provider_resource_from_unique_fields(
        **provider_resource_input
    )
    if existing_provider_resource and existing_provider_resource.get("deleted_at") is None:
        return _uniqueness_error_response()
    # Past the check above, any existing match is soft-deleted, so revive it.
    write_provider_resource = (
        config.restore_provider_resource
        if existing_provider_resource
        else config.create_provider_resource
    )
    try:
        provider_resource = write_provider_resource(
            tenant_id=tenant_id, **provider_resource_input
        )
    except errors.UniqueViolation:
        # The lookup above is not atomic with the write. A concurrent create
        # of the same resource can still hit a database unique constraint.
        # Report it as 409, as put_resource/patch_resource do, instead of a 500.
        return _uniqueness_error_response()

    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    response = JSONResponse(status_code=201, content=client_resource)
    response.headers["Location"] = client_resource["meta"]["location"]
    return response
@public_app.delete("/{resource_type}/{resource_id}")
async def delete_resource(
    resource_type: SCIMResource,
    resource_id: str,
    tenant_id=Depends(auth_required),
):
    """Delete a resource, returning 204 on success or a SCIM 404 if absent.

    NOTE(review): ``resource_id`` is typed ``str`` here but ``int | str`` on
    the GET/PUT/PATCH routes. Confirm that the provider delete function
    accepts both forms.
    """
    # note(jon): this can be a soft or a hard delete
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    if api_helper.get_resource(config, resource_id, tenant_id):
        config.delete_provider_resource(resource_id, tenant_id)
        return Response(status_code=204, content="")
    return _not_found_error_response(resource_id)
@public_app.put("/{resource_type}/{resource_id}")
async def put_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Replace a resource with the JSON body (SCIM PUT).

    Error responses:
      * 404 when the resource does not exist.
      * 400 ``mutability`` when ``helpers.filter_mutable_attributes`` raises
        ValueError.
      * 400 ``invalidValue`` when the rewrite input cannot be converted
        (KeyError), matching create_resource.
      * 409 ``uniqueness`` on a database unique violation.
      * 500 for any other failure. The exception is logged server-side and
        is not echoed to the client.

    Returns the rewritten resource in client form with status 200.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    client_resource = api_helper.get_resource(config, resource_id, tenant_id)
    if not client_resource:
        return _not_found_error_response(resource_id)
    schema = api_helper.get_schema(config)
    payload = await r.json()
    try:
        client_resource_input = helpers.filter_mutable_attributes(
            schema, payload, client_resource
        )
    except ValueError:
        return _mutability_error_response()
    try:
        provider_resource_input = (
            config.convert_client_resource_rewrite_input_to_provider_resource_rewrite_input(
                tenant_id, client_resource_input
            )
        )
    except KeyError:
        return _invalid_value_error_response()
    try:
        provider_resource = config.rewrite_provider_resource(
            resource_id,
            tenant_id,
            **provider_resource_input,
        )
    except errors.UniqueViolation:
        return _uniqueness_error_response()
    except Exception:
        # Exception text can contain SQL or other internals. Keep it in the
        # server log (with traceback) rather than in the client response.
        logger.exception(
            "SCIM PUT %s/%s failed", resource_type.value, resource_id
        )
        return _internal_server_error_response(
            "An unexpected error occurred while replacing the resource."
        )
    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    return JSONResponse(status_code=200, content=client_resource)
@public_app.patch("/{resource_type}/{resource_id}")
async def patch_resource(
    resource_type: SCIMResource,
    resource_id: int | str,
    r: Request,
    tenant_id=Depends(auth_required),
    attributes: list[str] | None = Query(None),
    excluded_attributes: list[str] | None = Query(None, alias="excludedAttributes"),
):
    """Apply SCIM PATCH ``Operations`` to a resource and persist the changes.

    The operations are applied to the current client view of the resource
    via ``helpers.apply_scim_patch``. Only the attributes it reports as
    changed (their new values) are passed to the provider update.

    Error responses:
      * 404 when the resource does not exist.
      * 400 ``invalidValue`` when the body is not an object with an
        ``Operations`` list, or the update input cannot be converted
        (KeyError).
      * 409 ``uniqueness`` on a database unique violation.
      * 500 for any other failure while persisting. The exception is logged
        server-side and is not echoed to the client.

    Returns the updated resource in client form with status 200.
    """
    config = RESOURCE_TYPE_TO_RESOURCE_CONFIG[resource_type]
    client_resource = api_helper.get_resource(config, resource_id, tenant_id)
    if not client_resource:
        return _not_found_error_response(resource_id)
    schema = api_helper.get_schema(config)
    payload = await r.json()
    # A missing or malformed "Operations" member is a client error, not a
    # KeyError/TypeError that surfaces as a 500.
    operations = payload.get("Operations") if isinstance(payload, dict) else None
    if not isinstance(operations, list):
        return _invalid_value_error_response()
    _patched_resource, changes = helpers.apply_scim_patch(
        operations, client_resource, schema
    )
    client_resource_input = {
        key: new_value for key, (_old_value, new_value) in changes.items()
    }
    try:
        provider_resource_input = (
            config.convert_client_resource_update_input_to_provider_resource_update_input(
                tenant_id, client_resource_input
            )
        )
    except KeyError:
        return _invalid_value_error_response()
    try:
        provider_resource = config.update_provider_resource(
            resource_id, tenant_id, **provider_resource_input
        )
    except errors.UniqueViolation:
        return _uniqueness_error_response()
    except Exception:
        # Exception text can contain SQL or other internals. Keep it in the
        # server log (with traceback) rather than in the client response.
        logger.exception(
            "SCIM PATCH %s/%s failed", resource_type.value, resource_id
        )
        return _internal_server_error_response(
            "An unexpected error occurred while updating the resource."
        )
    client_resource = api_helper.convert_provider_resource_to_client_resource(
        config, provider_resource, attributes, excluded_attributes
    )
    return JSONResponse(status_code=200, content=client_resource)