openreplay/ee/api/routers/scim/providers.py
2025-06-02 16:39:00 +02:00

280 lines
11 KiB
Python

import traceback
from typing import Union
from scim2_server import provider
from scim2_models import (
AuthenticationScheme,
ServiceProviderConfig,
Patch,
Bulk,
Filter,
Sort,
ETag,
Meta,
ChangePassword,
Error,
ResourceType,
Context,
ListResponse,
PatchOp,
)
from werkzeug import Request, Response
from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed
from pydantic import ValidationError
from werkzeug.routing.exceptions import RequestRedirect
from scim2_server.utils import SCIMException, merge_resources
from chalicelib.utils.scim_auth import verify_access_token
class MultiTenantProvider(provider.SCIMProvider):
def check_auth(self, request: Request):
auth = request.headers.get("Authorization")
if not auth or not auth.startswith("Bearer "):
return None
token = auth[len("Bearer ") :]
if not token:
return Response(
"Missing or invalid Authorization header",
status=401,
headers={"WWW-Authenticate": 'Bearer realm="login required"'},
)
payload = verify_access_token(token)
tenant_id = payload["tenant_id"]
return tenant_id
def get_service_provider_config(self):
auth_schemes = [
AuthenticationScheme(
type="oauthbearertoken",
name="OAuth Bearer Token",
description="Authentication scheme using the OAuth Bearer Token Standard. The access token should be sent in the 'Authorization' header using the Bearer schema.",
spec_uri="https://datatracker.ietf.org/doc/html/rfc6750",
)
]
return ServiceProviderConfig(
# todo(jon): write correct documentation uri
documentation_uri="https://www.example.com/",
patch=Patch(supported=True),
bulk=Bulk(supported=False),
filter=Filter(supported=True, max_results=1000),
change_password=ChangePassword(supported=False),
sort=Sort(supported=True),
etag=ETag(supported=False),
authentication_schemes=auth_schemes,
meta=Meta(resource_type="ServiceProviderConfig"),
)
def query_resource(
self, request: Request, tenant_id: int, resource: ResourceType | None
):
search_request = self.build_search_request(request)
kwargs = {}
if resource is not None:
kwargs["resource_type_id"] = resource.id
total_results, results = self.backend.query_resources(
search_request=search_request, tenant_id=tenant_id, **kwargs
)
for r in results:
self.adjust_location(request, r)
resources = [
s.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
attributes=search_request.attributes,
excluded_attributes=search_request.excluded_attributes,
)
for s in results
]
return ListResponse[Union[tuple(self.backend.get_models())]]( # noqa: UP007
total_results=total_results,
items_per_page=search_request.count,
start_index=search_request.start_index,
resources=resources,
)
def call_resource(
self, request: Request, resource_endpoint: str, **kwargs
) -> Response:
resource_type = self.backend.get_resource_type_by_endpoint(
"/" + resource_endpoint
)
if not resource_type:
raise NotFound
if "tenant_id" not in kwargs:
raise Exception
tenant_id = kwargs["tenant_id"]
match request.method:
case "GET":
return self.make_response(
self.query_resource(request, tenant_id, resource_type).model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
)
case _: # "POST"
payload = request.json
resource = self.backend.get_model(resource_type.id).model_validate(
payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST
)
created_resource = self.backend.create_resource(
tenant_id,
resource_type.id,
resource,
)
self.adjust_location(request, created_resource)
return self.make_response(
created_resource.model_dump(
scim_ctx=Context.RESOURCE_CREATION_RESPONSE
),
status=201,
headers={"Location": created_resource.meta.location},
)
def call_single_resource(
self, request: Request, resource_endpoint: str, resource_id: str, **kwargs
) -> Response:
find_endpoint = "/" + resource_endpoint
resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint)
if not resource_type:
raise NotFound
if "tenant_id" not in kwargs:
raise Exception
tenant_id = kwargs["tenant_id"]
match request.method:
case "GET":
if resource := self.backend.get_resource(
tenant_id, resource_type.id, resource_id
):
if self.continue_etag(request, resource):
response_args = self.get_attrs_from_request(request)
self.adjust_location(request, resource)
return self.make_response(
resource.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
**response_args,
)
)
else:
return self.make_response(None, status=304)
raise NotFound
case "DELETE":
if self.backend.delete_resource(
tenant_id, resource_type.id, resource_id
):
return self.make_response(None, 204)
else:
raise NotFound
case "PUT":
response_args = self.get_attrs_from_request(request)
resource = self.backend.get_resource(
tenant_id, resource_type.id, resource_id
)
if resource is None:
raise NotFound
if not self.continue_etag(request, resource):
raise PreconditionFailed
updated_attributes = self.backend.get_model(
resource_type.id
).model_validate(request.json)
merge_resources(resource, updated_attributes)
updated = self.backend.update_resource(
tenant_id, resource_type.id, resource
)
self.adjust_location(request, updated)
return self.make_response(
updated.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
**response_args,
)
)
case _: # "PATCH"
payload = request.json
# MS Entra sometimes passes a "id" attribute
if "id" in payload:
del payload["id"]
operations = payload.get("Operations", [])
for operation in operations:
if "name" in operation:
# MS Entra sometimes passes a "name" attribute
del operation["name"]
patch_operation = PatchOp.model_validate(payload)
response_args = self.get_attrs_from_request(request)
resource = self.backend.get_resource(
tenant_id, resource_type.id, resource_id
)
if resource is None:
raise NotFound
if not self.continue_etag(request, resource):
raise PreconditionFailed
self.apply_patch_operation(resource, patch_operation)
updated = self.backend.update_resource(
tenant_id, resource_type.id, resource
)
if response_args:
self.adjust_location(request, updated)
return self.make_response(
updated.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
**response_args,
)
)
else:
# RFC 7644, section 3.5.2:
# A PATCH operation MAY return a 204 (no content)
# if no attributes were requested
return self.make_response(
None, 204, headers={"ETag": updated.meta.version}
)
def wsgi_app(self, request: Request, environ):
try:
if environ.get("PATH_INFO", "").endswith(".scim"):
# RFC 7644, Section 3.8
# Just strip .scim suffix, the provider always returns application/scim+json
environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim")
urls = self.url_map.bind_to_environ(environ)
endpoint, args = urls.match()
tenant_id = None
if endpoint != "service_provider_config":
# RFC7643, Section 5: skip authentication for ServiceProviderConfig
tenant_id = self.check_auth(request)
# Wrap the entire call in a transaction. Should probably be optimized (use transaction only when necessary).
with self.backend:
if endpoint == "service_provider_config" or endpoint == "schema":
response = getattr(self, f"call_{endpoint}")(request, **args)
else:
response = getattr(self, f"call_{endpoint}")(
request, **args, tenant_id=tenant_id
)
return response
except RequestRedirect as e:
# urls.match may cause a redirect, handle it as a special case of HTTPException
self.log.exception(e)
return e.get_response(environ)
except HTTPException as e:
self.log.exception(e)
return self.make_error(Error(status=e.code, detail=e.description))
except SCIMException as e:
self.log.exception(e)
return self.make_error(e.scim_error)
except ValidationError as e:
self.log.exception(e)
return self.make_error(Error(status=400, detail=str(e)))
except Exception as e:
self.log.exception(e)
tb = traceback.format_exc()
return self.make_error(Error(status=500, detail=str(e) + "\n" + tb))