280 lines
11 KiB
Python
280 lines
11 KiB
Python
import traceback
|
|
from typing import Union
|
|
|
|
from scim2_server import provider
|
|
|
|
from scim2_models import (
|
|
AuthenticationScheme,
|
|
ServiceProviderConfig,
|
|
Patch,
|
|
Bulk,
|
|
Filter,
|
|
Sort,
|
|
ETag,
|
|
Meta,
|
|
ChangePassword,
|
|
Error,
|
|
ResourceType,
|
|
Context,
|
|
ListResponse,
|
|
PatchOp,
|
|
)
|
|
|
|
from werkzeug import Request, Response
|
|
from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed
|
|
from pydantic import ValidationError
|
|
from werkzeug.routing.exceptions import RequestRedirect
|
|
from scim2_server.utils import SCIMException, merge_resources
|
|
|
|
from chalicelib.utils.scim_auth import verify_access_token
|
|
|
|
|
|
class MultiTenantProvider(provider.SCIMProvider):
|
|
def check_auth(self, request: Request):
|
|
auth = request.headers.get("Authorization")
|
|
if not auth or not auth.startswith("Bearer "):
|
|
return None
|
|
token = auth[len("Bearer ") :]
|
|
if not token:
|
|
return Response(
|
|
"Missing or invalid Authorization header",
|
|
status=401,
|
|
headers={"WWW-Authenticate": 'Bearer realm="login required"'},
|
|
)
|
|
payload = verify_access_token(token)
|
|
tenant_id = payload["tenant_id"]
|
|
return tenant_id
|
|
|
|
def get_service_provider_config(self):
|
|
auth_schemes = [
|
|
AuthenticationScheme(
|
|
type="oauthbearertoken",
|
|
name="OAuth Bearer Token",
|
|
description="Authentication scheme using the OAuth Bearer Token Standard. The access token should be sent in the 'Authorization' header using the Bearer schema.",
|
|
spec_uri="https://datatracker.ietf.org/doc/html/rfc6750",
|
|
)
|
|
]
|
|
return ServiceProviderConfig(
|
|
# todo(jon): write correct documentation uri
|
|
documentation_uri="https://www.example.com/",
|
|
patch=Patch(supported=True),
|
|
bulk=Bulk(supported=False),
|
|
filter=Filter(supported=True, max_results=1000),
|
|
change_password=ChangePassword(supported=False),
|
|
sort=Sort(supported=True),
|
|
etag=ETag(supported=False),
|
|
authentication_schemes=auth_schemes,
|
|
meta=Meta(resource_type="ServiceProviderConfig"),
|
|
)
|
|
|
|
def query_resource(
|
|
self, request: Request, tenant_id: int, resource: ResourceType | None
|
|
):
|
|
search_request = self.build_search_request(request)
|
|
|
|
kwargs = {}
|
|
if resource is not None:
|
|
kwargs["resource_type_id"] = resource.id
|
|
total_results, results = self.backend.query_resources(
|
|
search_request=search_request, tenant_id=tenant_id, **kwargs
|
|
)
|
|
for r in results:
|
|
self.adjust_location(request, r)
|
|
|
|
resources = [
|
|
s.model_dump(
|
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
|
attributes=search_request.attributes,
|
|
excluded_attributes=search_request.excluded_attributes,
|
|
)
|
|
for s in results
|
|
]
|
|
|
|
return ListResponse[Union[tuple(self.backend.get_models())]]( # noqa: UP007
|
|
total_results=total_results,
|
|
items_per_page=search_request.count,
|
|
start_index=search_request.start_index,
|
|
resources=resources,
|
|
)
|
|
|
|
def call_resource(
|
|
self, request: Request, resource_endpoint: str, **kwargs
|
|
) -> Response:
|
|
resource_type = self.backend.get_resource_type_by_endpoint(
|
|
"/" + resource_endpoint
|
|
)
|
|
if not resource_type:
|
|
raise NotFound
|
|
|
|
if "tenant_id" not in kwargs:
|
|
raise Exception
|
|
tenant_id = kwargs["tenant_id"]
|
|
|
|
match request.method:
|
|
case "GET":
|
|
return self.make_response(
|
|
self.query_resource(request, tenant_id, resource_type).model_dump(
|
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
|
)
|
|
)
|
|
case _: # "POST"
|
|
payload = request.json
|
|
resource = self.backend.get_model(resource_type.id).model_validate(
|
|
payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST
|
|
)
|
|
created_resource = self.backend.create_resource(
|
|
tenant_id,
|
|
resource_type.id,
|
|
resource,
|
|
)
|
|
self.adjust_location(request, created_resource)
|
|
return self.make_response(
|
|
created_resource.model_dump(
|
|
scim_ctx=Context.RESOURCE_CREATION_RESPONSE
|
|
),
|
|
status=201,
|
|
headers={"Location": created_resource.meta.location},
|
|
)
|
|
|
|
def call_single_resource(
|
|
self, request: Request, resource_endpoint: str, resource_id: str, **kwargs
|
|
) -> Response:
|
|
find_endpoint = "/" + resource_endpoint
|
|
resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint)
|
|
if not resource_type:
|
|
raise NotFound
|
|
|
|
if "tenant_id" not in kwargs:
|
|
raise Exception
|
|
tenant_id = kwargs["tenant_id"]
|
|
|
|
match request.method:
|
|
case "GET":
|
|
if resource := self.backend.get_resource(
|
|
tenant_id, resource_type.id, resource_id
|
|
):
|
|
if self.continue_etag(request, resource):
|
|
response_args = self.get_attrs_from_request(request)
|
|
self.adjust_location(request, resource)
|
|
return self.make_response(
|
|
resource.model_dump(
|
|
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
|
|
**response_args,
|
|
)
|
|
)
|
|
else:
|
|
return self.make_response(None, status=304)
|
|
raise NotFound
|
|
case "DELETE":
|
|
if self.backend.delete_resource(
|
|
tenant_id, resource_type.id, resource_id
|
|
):
|
|
return self.make_response(None, 204)
|
|
else:
|
|
raise NotFound
|
|
case "PUT":
|
|
response_args = self.get_attrs_from_request(request)
|
|
resource = self.backend.get_resource(
|
|
tenant_id, resource_type.id, resource_id
|
|
)
|
|
if resource is None:
|
|
raise NotFound
|
|
if not self.continue_etag(request, resource):
|
|
raise PreconditionFailed
|
|
|
|
updated_attributes = self.backend.get_model(
|
|
resource_type.id
|
|
).model_validate(request.json)
|
|
merge_resources(resource, updated_attributes)
|
|
updated = self.backend.update_resource(
|
|
tenant_id, resource_type.id, resource
|
|
)
|
|
self.adjust_location(request, updated)
|
|
return self.make_response(
|
|
updated.model_dump(
|
|
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
|
|
**response_args,
|
|
)
|
|
)
|
|
case _: # "PATCH"
|
|
payload = request.json
|
|
# MS Entra sometimes passes a "id" attribute
|
|
if "id" in payload:
|
|
del payload["id"]
|
|
operations = payload.get("Operations", [])
|
|
for operation in operations:
|
|
if "name" in operation:
|
|
# MS Entra sometimes passes a "name" attribute
|
|
del operation["name"]
|
|
|
|
patch_operation = PatchOp.model_validate(payload)
|
|
response_args = self.get_attrs_from_request(request)
|
|
resource = self.backend.get_resource(
|
|
tenant_id, resource_type.id, resource_id
|
|
)
|
|
if resource is None:
|
|
raise NotFound
|
|
if not self.continue_etag(request, resource):
|
|
raise PreconditionFailed
|
|
|
|
self.apply_patch_operation(resource, patch_operation)
|
|
updated = self.backend.update_resource(
|
|
tenant_id, resource_type.id, resource
|
|
)
|
|
|
|
if response_args:
|
|
self.adjust_location(request, updated)
|
|
return self.make_response(
|
|
updated.model_dump(
|
|
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
|
|
**response_args,
|
|
)
|
|
)
|
|
else:
|
|
# RFC 7644, section 3.5.2:
|
|
# A PATCH operation MAY return a 204 (no content)
|
|
# if no attributes were requested
|
|
return self.make_response(
|
|
None, 204, headers={"ETag": updated.meta.version}
|
|
)
|
|
|
|
def wsgi_app(self, request: Request, environ):
|
|
try:
|
|
if environ.get("PATH_INFO", "").endswith(".scim"):
|
|
# RFC 7644, Section 3.8
|
|
# Just strip .scim suffix, the provider always returns application/scim+json
|
|
environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim")
|
|
urls = self.url_map.bind_to_environ(environ)
|
|
endpoint, args = urls.match()
|
|
|
|
tenant_id = None
|
|
if endpoint != "service_provider_config":
|
|
# RFC7643, Section 5: skip authentication for ServiceProviderConfig
|
|
tenant_id = self.check_auth(request)
|
|
|
|
# Wrap the entire call in a transaction. Should probably be optimized (use transaction only when necessary).
|
|
with self.backend:
|
|
if endpoint == "service_provider_config" or endpoint == "schema":
|
|
response = getattr(self, f"call_{endpoint}")(request, **args)
|
|
else:
|
|
response = getattr(self, f"call_{endpoint}")(
|
|
request, **args, tenant_id=tenant_id
|
|
)
|
|
return response
|
|
except RequestRedirect as e:
|
|
# urls.match may cause a redirect, handle it as a special case of HTTPException
|
|
self.log.exception(e)
|
|
return e.get_response(environ)
|
|
except HTTPException as e:
|
|
self.log.exception(e)
|
|
return self.make_error(Error(status=e.code, detail=e.description))
|
|
except SCIMException as e:
|
|
self.log.exception(e)
|
|
return self.make_error(e.scim_error)
|
|
except ValidationError as e:
|
|
self.log.exception(e)
|
|
return self.make_error(Error(status=400, detail=str(e)))
|
|
except Exception as e:
|
|
self.log.exception(e)
|
|
tb = traceback.format_exc()
|
|
return self.make_error(Error(status=500, detail=str(e) + "\n" + tb))
|