This commit is contained in:
jonathan-caleb-griffin 2025-06-02 14:39:07 +00:00 committed by GitHub
commit e1fdbb1c36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1958 additions and 1260 deletions

1
ee/api/.gitignore vendored
View file

@ -283,4 +283,3 @@ Pipfile.lock
/chalicelib/utils/contextual_validators.py
/routers/subs/product_analytics.py
/schemas/product_analytics.py
/ee/bin/*

View file

@ -26,6 +26,8 @@ xmlsec = "==1.3.14"
python-multipart = "==0.0.20"
redis = "==6.1.0"
azure-storage-blob = "==12.25.1"
scim2-server = "*"
scim2-models = "*"
[dev-packages]

View file

@ -9,6 +9,7 @@ from decouple import config
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.wsgi import WSGIMiddleware
from psycopg import AsyncConnection
from psycopg.rows import dict_row
from starlette import status
@ -21,12 +22,20 @@ from chalicelib.utils import pg_client, ch_client
from crons import core_crons, ee_crons, core_dynamic_crons
from routers import core, core_dynamic
from routers import ee
from routers.subs import insights, metrics, v1_api, health, usability_tests, spot, product_analytics
from routers.subs import (
insights,
metrics,
v1_api,
health,
usability_tests,
spot,
product_analytics,
)
from routers.subs import v1_api_ee
if config("ENABLE_SSO", cast=bool, default=True):
from routers import saml
from routers import scim
from routers.scim import api as scim
loglevel = config("LOGLEVEL", default=logging.WARNING)
print(f">Loglevel set to: {loglevel}")
@ -34,7 +43,6 @@ logging.basicConfig(level=loglevel)
class ORPYAsyncConnection(AsyncConnection):
def __init__(self, *args, **kwargs):
super().__init__(*args, row_factory=dict_row, **kwargs)
@ -43,7 +51,7 @@ class ORPYAsyncConnection(AsyncConnection):
async def lifespan(app: FastAPI):
# Startup
logging.info(">>>>> starting up <<<<<")
ap_logger = logging.getLogger('apscheduler')
ap_logger = logging.getLogger("apscheduler")
ap_logger.setLevel(loglevel)
app.schedule = AsyncIOScheduler()
@ -53,12 +61,23 @@ async def lifespan(app: FastAPI):
await events_queue.init()
app.schedule.start()
for job in core_crons.cron_jobs + core_dynamic_crons.cron_jobs + traces.cron_jobs + ee_crons.ee_cron_jobs:
for job in (
core_crons.cron_jobs
+ core_dynamic_crons.cron_jobs
+ traces.cron_jobs
+ ee_crons.ee_cron_jobs
):
app.schedule.add_job(id=job["func"].__name__, **job)
ap_logger.info(">Scheduled jobs:")
for job in app.schedule.get_jobs():
ap_logger.info({"Name": str(job.id), "Run Frequency": str(job.trigger), "Next Run": str(job.next_run_time)})
ap_logger.info(
{
"Name": str(job.id),
"Run Frequency": str(job.trigger),
"Next Run": str(job.next_run_time),
}
)
database = {
"host": config("pg_host", default="localhost"),
@ -69,9 +88,12 @@ async def lifespan(app: FastAPI):
"application_name": "AIO" + config("APP_NAME", default="PY"),
}
database = psycopg_pool.AsyncConnectionPool(kwargs=database, connection_class=ORPYAsyncConnection,
min_size=config("PG_AIO_MINCONN", cast=int, default=1),
max_size=config("PG_AIO_MAXCONN", cast=int, default=5), )
database = psycopg_pool.AsyncConnectionPool(
kwargs=database,
connection_class=ORPYAsyncConnection,
min_size=config("PG_AIO_MINCONN", cast=int, default=1),
max_size=config("PG_AIO_MAXCONN", cast=int, default=5),
)
app.state.postgresql = database
# App listening
@ -86,16 +108,24 @@ async def lifespan(app: FastAPI):
await pg_client.terminate()
app = FastAPI(root_path=config("root_path", default="/api"), docs_url=config("docs_url", default=""),
redoc_url=config("redoc_url", default=""), lifespan=lifespan)
app = FastAPI(
root_path=config("root_path", default="/api"),
docs_url=config("docs_url", default=""),
redoc_url=config("redoc_url", default=""),
lifespan=lifespan,
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
@app.middleware('http')
@app.middleware("http")
async def or_middleware(request: Request, call_next):
from chalicelib.core import unlock
if not unlock.is_valid():
return JSONResponse(content={"errors": ["expired license"]}, status_code=status.HTTP_403_FORBIDDEN)
return JSONResponse(
content={"errors": ["expired license"]},
status_code=status.HTTP_403_FORBIDDEN,
)
if helper.TRACK_TIME:
now = time.time()
@ -110,8 +140,10 @@ async def or_middleware(request: Request, call_next):
now = time.time() - now
if now > 2:
now = round(now, 2)
logging.warning(f"Execution time: {now} s for {request.method}: {request.url.path}")
response.headers["x-robots-tag"] = 'noindex, nofollow'
logging.warning(
f"Execution time: {now} s for {request.method}: {request.url.path}"
)
response.headers["x-robots-tag"] = "noindex, nofollow"
return response
@ -162,3 +194,4 @@ if config("ENABLE_SSO", cast=bool, default=True):
app.include_router(scim.public_app)
app.include_router(scim.app)
app.include_router(scim.app_apikey)
app.mount("/sso/scim/v2", WSGIMiddleware(scim.scim_app))

View file

@ -1,4 +1,3 @@
import json
from typing import Optional
from fastapi import HTTPException, status
@ -10,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC
def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""SELECT EXISTS(SELECT 1
query = cur.mogrify(
f"""SELECT EXISTS(SELECT 1
FROM public.roles
WHERE tenant_id = %(tenant_id)s
AND name ILIKE %(name)s
AND deleted_at ISNULL
{"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id})
{"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id},
)
cur.execute(query=query)
row = cur.fetchone()
return row["exists"]
@ -28,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
if not admin["admin"] and not admin["superAdmin"]:
return {"errors": ["unauthorized"]}
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
)
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
return {"errors": ["must specify a project or all projects"]}
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
data.projects = projects.is_authorized_batch(
project_ids=data.projects, tenant_id=tenant_id
)
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.roles
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = TRUE
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is protected"]}
query = cur.mogrify("""UPDATE public.roles
query = cur.mogrify(
"""UPDATE public.roles
SET name= %(name)s,
description= %(description)s,
permissions= %(permissions)s,
@ -57,43 +65,36 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
FROM roles_projects
WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()})
{"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()},
)
cur.execute(query=query)
row = cur.fetchone()
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
if not data.all_projects:
d_projects = [i for i in row["projects"] if i not in data.projects]
if len(d_projects) > 0:
query = cur.mogrify("""DELETE FROM roles_projects
query = cur.mogrify(
"""DELETE FROM roles_projects
WHERE role_id=%(role_id)s
AND project_id IN %(project_ids)s""",
{"role_id": role_id, "project_ids": tuple(d_projects)})
{"role_id": role_id, "project_ids": tuple(d_projects)},
)
cur.execute(query=query)
n_projects = [i for i in data.projects if i not in row["projects"]]
if len(n_projects) > 0:
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
query = cur.mogrify(
f"""INSERT INTO roles_projects(role_id, project_id)
VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}})
{
"role_id": role_id,
**{f"project_id_{i}": p for i, p in enumerate(n_projects)},
},
)
cur.execute(query=query)
row["projects"] = data.projects
return helper.dict_to_camel_case(row)
def update_group_name(tenant_id, group_id, name):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""UPDATE public.roles
SET name= %(name)s
WHERE roles.data->>'group_id' = %(group_id)s
AND tenant_id = %(tenant_id)s
AND deleted_at ISNULL
AND protected = FALSE
RETURNING *;""",
{"tenant_id": tenant_id, "group_id": group_id, "name": name })
cur.execute(query=query)
row = cur.fetchone()
return helper.dict_to_camel_case(row)
def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
admin = users.get(user_id=user_id, tenant_id=tenant_id)
@ -102,57 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
return {"errors": ["unauthorized"]}
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
)
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
return {"errors": ["must specify a project or all projects"]}
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
data.projects = projects.is_authorized_batch(
project_ids=data.projects, tenant_id=tenant_id
)
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
query = cur.mogrify(
"""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
RETURNING *;""",
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
"permissions": data.permissions, "all_projects": data.all_projects})
{
"tenant_id": tenant_id,
"name": data.name,
"description": data.description,
"permissions": data.permissions,
"all_projects": data.all_projects,
},
)
cur.execute(query=query)
row = cur.fetchone()
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
row["projects"] = []
if not data.all_projects:
role_id = row["role_id"]
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
query = cur.mogrify(
f"""INSERT INTO roles_projects(role_id, project_id)
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
RETURNING project_id;""",
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
cur.execute(query=query)
row["projects"] = [r["project_id"] for r in cur.fetchall()]
return helper.dict_to_camel_case(row)
def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
if not data.all_projects and (data.projects is None or len(data.projects) == 0):
return {"errors": ["must specify a project or all projects"]}
if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects, data)
VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s, %(data)s)
RETURNING *;""",
{"tenant_id": tenant_id, "name": data.name, "description": data.description,
"permissions": data.permissions, "all_projects": data.all_projects, "data": json.dumps({ "group_id": group_id })})
cur.execute(query=query)
row = cur.fetchone()
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
row["projects"] = []
if not data.all_projects:
role_id = row["role_id"]
query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
RETURNING project_id;""",
{"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
{
"role_id": role_id,
**{f"project_id_{i}": p for i, p in enumerate(data.projects)},
},
)
cur.execute(query=query)
row["projects"] = [r["project_id"] for r in cur.fetchall()]
return helper.dict_to_camel_case(row)
@ -160,7 +148,8 @@ def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
def get_roles(tenant_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
query = cur.mogrify(
"""SELECT roles.*, COALESCE(projects, '{}') AS projects
FROM public.roles
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
FROM roles_projects
@ -171,66 +160,25 @@ def get_roles(tenant_id):
AND deleted_at IS NULL
AND not service_role
ORDER BY role_id;""",
{"tenant_id": tenant_id})
{"tenant_id": tenant_id},
)
cur.execute(query=query)
rows = cur.fetchall()
for r in rows:
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
return helper.list_to_camel_case(rows)
def get_roles_with_uuid(tenant_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
FROM public.roles
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
FROM roles_projects
INNER JOIN projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
WHERE tenant_id =%(tenant_id)s
AND data ? 'group_id'
AND deleted_at IS NULL
AND not service_role
ORDER BY role_id;""",
{"tenant_id": tenant_id})
cur.execute(query=query)
rows = cur.fetchall()
for r in rows:
r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
return helper.list_to_camel_case(rows)
def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
FROM public.roles
LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
FROM roles_projects
INNER JOIN projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
WHERE tenant_id =%(tenant_id)s
AND data ? 'group_id'
AND deleted_at IS NULL
AND not service_role
AND name = COALESCE(%(name)s, name)
ORDER BY role_id
LIMIT %(count)s
OFFSET %(startIndex)s;""",
{"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count})
cur.execute(query=query)
rows = cur.fetchall()
return helper.list_to_camel_case(rows)
def get_role_by_name(tenant_id, name):
### "name" isn't unique in database
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT *
query = cur.mogrify(
"""SELECT *
FROM public.roles
WHERE tenant_id =%(tenant_id)s
AND deleted_at IS NULL
AND name ILIKE %(name)s;""",
{"tenant_id": tenant_id, "name": name})
{"tenant_id": tenant_id, "name": name},
)
cur.execute(query=query)
row = cur.fetchone()
if row is not None:
@ -244,139 +192,55 @@ def delete(tenant_id, user_id, role_id):
if not admin["admin"] and not admin["superAdmin"]:
return {"errors": ["unauthorized"]}
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.roles
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = TRUE
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is protected"]}
query = cur.mogrify("""SELECT 1
query = cur.mogrify(
"""SELECT 1
FROM public.users
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
if cur.fetchone() is not None:
return {"errors": ["this role is already attached to other user(s)"]}
query = cur.mogrify("""UPDATE public.roles
query = cur.mogrify(
"""UPDATE public.roles
SET deleted_at = timezone('utc'::text, now())
WHERE role_id = %(role_id)s
AND tenant_id = %(tenant_id)s
AND protected = FALSE;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
return get_roles(tenant_id=tenant_id)
def delete_scim_group(tenant_id, group_uuid):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT 1
FROM public.roles
WHERE data->>'group_id' = %(group_uuid)s
AND tenant_id = %(tenant_id)s
AND protected = TRUE
LIMIT 1;""",
{"tenant_id": tenant_id, "group_uuid": group_uuid})
cur.execute(query)
if cur.fetchone() is not None:
return {"errors": ["this role is protected"]}
query = cur.mogrify(
f"""DELETE FROM public.roles
WHERE roles.data->>'group_id' = %(group_uuid)s;""", # removed this: AND users.deleted_at IS NOT NULL
{"group_uuid": group_uuid})
cur.execute(query)
return get_roles(tenant_id=tenant_id)
def get_role(tenant_id, role_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*
query = cur.mogrify(
"""SELECT roles.*
FROM public.roles
WHERE tenant_id =%(tenant_id)s
AND deleted_at IS NULL
AND not service_role
AND role_id = %(role_id)s
LIMIT 1;""",
{"tenant_id": tenant_id, "role_id": role_id})
{"tenant_id": tenant_id, "role_id": role_id},
)
cur.execute(query=query)
row = cur.fetchone()
if row is not None:
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
return helper.dict_to_camel_case(row)
def get_role_by_group_id(tenant_id, group_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT roles.*
FROM public.roles
WHERE tenant_id =%(tenant_id)s
AND deleted_at IS NULL
AND not service_role
AND data->>'group_id' = %(group_id)s
LIMIT 1;""",
{"tenant_id": tenant_id, "group_id": group_id})
cur.execute(query=query)
row = cur.fetchone()
if row is not None:
row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
return helper.dict_to_camel_case(row)
def get_users_by_group_uuid(tenant_id, group_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT
u.user_id,
u.name,
u.data
FROM public.roles r
LEFT JOIN public.users u USING (role_id, tenant_id)
WHERE u.tenant_id = %(tenant_id)s
AND u.deleted_at IS NULL
AND r.data->>'group_id' = %(group_id)s
""",
{"tenant_id": tenant_id, "group_id": group_id})
cur.execute(query=query)
rows = cur.fetchall()
return helper.list_to_camel_case(rows)
def get_member_permissions(tenant_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""SELECT
r.permissions
FROM public.roles r
WHERE r.tenant_id = %(tenant_id)s
AND r.name = 'Member'
AND r.deleted_at IS NULL
""",
{"tenant_id": tenant_id})
cur.execute(query=query)
row = cur.fetchone()
return helper.dict_to_camel_case(row)
def remove_group_membership(tenant_id, group_id, user_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify("""WITH r AS (
SELECT role_id
FROM public.roles
WHERE data->>'group_id' = %(group_id)s
LIMIT 1
)
UPDATE public.users u
SET role_id= NULL
FROM r
WHERE u.data->>'user_id' = %(user_id)s
AND u.role_id = r.role_id
AND u.tenant_id = %(tenant_id)s
AND u.deleted_at IS NULL
RETURNING *;""",
{"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id})
cur.execute(query=query)
row = cur.fetchone()
return helper.dict_to_camel_case(row)

File diff suppressed because it is too large Load diff

View file

@ -23,20 +23,18 @@ SAML2 = {
"entityId": config("SITE_URL") + API_PREFIX + "/sso/saml2/metadata/",
"assertionConsumerService": {
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/acs/",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
},
"singleLogoutService": {
"url": config("SITE_URL") + API_PREFIX + "/sso/saml2/sls/",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
},
"NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
"x509cert": config("sp_crt", default=""),
"privateKey": config("sp_key", default=""),
},
"security": {
"requestedAuthnContext": False
},
"idp": None
"security": {"requestedAuthnContext": False},
"idp": None,
}
# in case tenantKey is included in the URL
@ -50,25 +48,29 @@ if config("SAML2_MD_URL", default=None) is not None and len(config("SAML2_MD_URL
print("SAML2_MD_URL provided, getting IdP metadata config")
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(config("SAML2_MD_URL", default=None))
idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(
config("SAML2_MD_URL", default=None)
)
idp = idp_data.get("idp")
if SAML2["idp"] is None:
if len(config("idp_entityId", default="")) > 0 \
and len(config("idp_sso_url", default="")) > 0 \
and len(config("idp_x509cert", default="")) > 0:
if (
len(config("idp_entityId", default="")) > 0
and len(config("idp_sso_url", default="")) > 0
and len(config("idp_x509cert", default="")) > 0
):
idp = {
"entityId": config("idp_entityId"),
"singleSignOnService": {
"url": config("idp_sso_url"),
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
},
"x509cert": config("idp_x509cert")
"x509cert": config("idp_x509cert"),
}
if len(config("idp_sls_url", default="")) > 0:
idp["singleLogoutService"] = {
"url": config("idp_sls_url"),
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
}
if idp is None:
@ -106,8 +108,8 @@ async def prepare_request(request: Request):
session = {}
# If server is behind proxys or balancers use the HTTP_X_FORWARDED fields
headers = request.headers
proto = headers.get('x-forwarded-proto', 'http')
url_data = urlparse('%s://%s' % (proto, headers['host']))
proto = headers.get("x-forwarded-proto", "http")
url_data = urlparse("%s://%s" % (proto, headers["host"]))
path = request.url.path
site_url = urlparse(config("SITE_URL"))
# to support custom port without changing IDP config
@ -117,21 +119,21 @@ async def prepare_request(request: Request):
# add / to /acs
if not path.endswith("/"):
path = path + '/'
path = path + "/"
if len(API_PREFIX) > 0 and not path.startswith(API_PREFIX):
path = API_PREFIX + path
return {
'https': 'on' if proto == 'https' else 'off',
'http_host': request.headers['host'] + host_suffix,
'server_port': url_data.port,
'script_name': path,
'get_data': request.args.copy(),
"https": "on" if proto == "https" else "off",
"http_host": request.headers["host"] + host_suffix,
"server_port": url_data.port,
"script_name": path,
"get_data": request.args.copy(),
# Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144
# 'lowercase_urlencoding': True,
'post_data': request.form.copy(),
'cookie': {"session": session},
'request': request
"post_data": request.form.copy(),
"cookie": {"session": session},
"request": request,
}
@ -140,8 +142,11 @@ def is_saml2_available():
def get_saml2_provider():
return config("idp_name", default="saml2") if is_saml2_available() and len(
config("idp_name", default="saml2")) > 0 else None
return (
config("idp_name", default="saml2")
if is_saml2_available() and len(config("idp_name", default="saml2")) > 0
else None
)
def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
@ -152,11 +157,14 @@ def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
if redirect_to_link2:
if len(config("sso_landing_override", default="")) == 0:
logging.warning("SSO trying to redirect to custom URL, but sso_landing_override env var is empty")
logging.warning(
"SSO trying to redirect to custom URL, but sso_landing_override env var is empty"
)
else:
return config("sso_landing_override") + query_params
return config("SITE_URL") + config("sso_landing", default="/login") + query_params
base_url = config("SITE_URLx") if config("LOCAL_DEV") else config("SITE_URL")
return base_url + config("sso_landing", default="/login") + query_params
environ["hastSAML2"] = str(is_saml2_available())

View file

@ -13,23 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
ALGORITHM = config("SCIM_JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
AUDIENCE="okta_client"
ISSUER=config("JWT_ISSUER"),
AUDIENCE = config("SCIM_AUDIENCE")
ISSUER = (config("JWT_ISSUER"),)
# Simulated Okta Client Credentials
# OKTA_CLIENT_ID = "okta-client"
# OKTA_CLIENT_SECRET = "okta-secret"
# class TokenRequest(BaseModel):
# client_id: str
# client_secret: str
# async def authenticate_client(token_request: TokenRequest):
# """Validate Okta Client Credentials and issue JWT"""
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
# raise HTTPException(status_code=401, detail="Invalid client credentials")
# return {"access_token": create_jwt(), "token_type": "bearer"}
def create_tokens(tenant_id):
curr_time = time.time()
@ -38,7 +24,7 @@ def create_tokens(tenant_id):
"sub": "scim_server",
"aud": AUDIENCE,
"iss": ISSUER,
"exp": ""
"exp": "",
}
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
@ -47,20 +33,26 @@ def create_tokens(tenant_id):
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
return access_token, refresh_token
return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS
def verify_access_token(token: str):
try:
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
payload = jwt.decode(
token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")
def verify_refresh_token(token: str):
try:
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
payload = jwt.decode(
token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
@ -68,10 +60,25 @@ def verify_refresh_token(token: str):
raise HTTPException(status_code=401, detail="Invalid token")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Authentication Dependency
def auth_required(token: str = Depends(oauth2_scheme)):
def auth_required(token: str = Depends(required_oauth2_scheme)):
"""Dependency to check Authorization header."""
if config("SCIM_AUTH_TYPE") == "OAuth2":
payload = verify_access_token(token)
return payload["tenant_id"]
optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def auth_optional(token: str | None = Depends(optional_oauth2_scheme)):
if token is None:
return None
try:
tenant_id = auth_required(token)
return tenant_id
except HTTPException:
return None

View file

@ -1,466 +0,0 @@
import logging
import re
import uuid
from typing import Optional
from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field
import schemas
from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token
from routers.base import get_routers
logger = logging.getLogger(__name__)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
"""Authentication endpoints"""
class RefreshRequest(BaseModel):
refresh_token: str
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Login endpoint to generate tokens
@public_app.post("/token")
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
subdomain = host.split(".")[0]
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
raise HTTPException(status_code=401, detail="Invalid credentials")
subdomain = "Openreplay EE"
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
# Refresh token endpoint
@public_app.post("/refresh")
async def refresh_token(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
return {"access_token": new_access_token, "token_type": "Bearer"}
"""
User endpoints
"""
class Name(BaseModel):
givenName: str
familyName: str
class Email(BaseModel):
primary: bool
value: str
type: str
class UserRequest(BaseModel):
schemas: list[str]
userName: str
name: Name
emails: list[Email]
displayName: str
locale: str
externalId: str
groups: list[dict]
password: str = Field(default=None)
active: bool
class UserResponse(BaseModel):
schemas: list[str]
id: str
userName: str
name: Name
emails: list[Email] # ignore for now
displayName: str
locale: str
externalId: str
active: bool
groups: list[dict]
meta: dict = Field(default={"resourceType": "User"})
class PatchUserRequest(BaseModel):
schemas: list[str]
Operations: list[dict]
@public_app.get("/Users", dependencies=[Depends(auth_required)])
async def get_users(
start_index: int = Query(1, alias="startIndex"),
count: Optional[int] = Query(None, alias="count"),
email: Optional[str] = Query(None, alias="filter"),
):
"""Get SCIM Users"""
if email:
email = email.split(" ")[2].strip('"')
result_users = users.get_users_paginated(start_index, count, email)
serialized_users = []
for user in result_users:
serialized_users.append(
UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = user["email"],
name = Name.model_validate(user["data"]["name"]),
emails = [Email.model_validate(user["data"]["emails"])],
displayName = user["name"],
locale = user["data"]["locale"],
externalId = user["internalId"],
active = True, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
).model_dump(mode='json')
)
return JSONResponse(
status_code=200,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
"totalResults": len(serialized_users),
"startIndex": start_index,
"itemsPerPage": len(serialized_users),
"Resources": serialized_users,
},
)
@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
def get_user(user_id: str):
"""Get SCIM User"""
tenant_id = 1
user = users.get_by_uuid(user_id, tenant_id)
if not user:
return JSONResponse(
status_code=404,
content={
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "User not found",
"status": 404,
}
)
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = user["email"],
name = Name.model_validate(user["data"]["name"]),
emails = [Email.model_validate(user["data"]["emails"])],
displayName = user["name"],
locale = user["data"]["locale"],
externalId = user["internalId"],
active = True, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
)
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
@public_app.post("/Users", dependencies=[Depends(auth_required)])
async def create_user(r: UserRequest):
"""Create SCIM User"""
tenant_id = 1
existing_user = users.get_by_email_only(r.userName)
deleted_user = users.get_deleted_user_by_email(r.userName)
if existing_user:
return JSONResponse(
status_code = 409,
content = {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"detail": "User already exists in the database.",
"status": 409,
}
)
elif deleted_user:
user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id)
user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
else:
try:
user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = r.userName,
name = r.name,
emails = r.emails,
displayName = r.displayName,
locale = r.locale,
externalId = r.externalId,
active = r.active, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
)
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)])
def update_user(user_id: str, r: UserRequest):
"""Update SCIM User"""
tenant_id = 1
user = users.get_by_uuid(user_id, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
changes = r.model_dump(mode='json', exclude={"schemas", "emails", "name", "locale", "groups", "password", "active"}) # some of these should be added later if necessary
nested_changes = r.model_dump(mode='json', include={"name", "emails"})
mapping = {"userName": "email", "displayName": "name", "externalId": "internal_id"} # mapping between scim schema field names and local database model, can be done as config?
for k, v in mapping.items():
if k in changes:
changes[v] = changes.pop(k)
changes["data"] = {}
for k, v in nested_changes.items():
value_to_insert = v[0] if k == "emails" else v
changes["data"][k] = value_to_insert
try:
users.update(tenant_id, user["userId"], changes)
res = UserResponse(
schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
id = user["data"]["userId"],
userName = r.userName,
name = r.name,
emails = r.emails,
displayName = r.displayName,
locale = r.locale,
externalId = r.externalId,
active = r.active, # ignore for now, since, can't insert actual timestamp
groups = [], # ignore
)
return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)])
def deactivate_user(user_id: str, r: PatchUserRequest):
"""Deactivate user, soft-delete"""
tenant_id = 1
active = r.model_dump(mode='json')["Operations"][0]["value"]["active"]
if active:
raise HTTPException(status_code=404, detail="Activating user is not supported")
user = users.get_by_uuid(user_id, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
users.delete_member_as_admin(tenant_id, user["userId"])
return Response(status_code=204, content="")
@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)])
def delete_user(user_uuid: str):
"""Delete user from database, hard-delete"""
tenant_id = 1
user = users.get_by_uuid(user_uuid, tenant_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
users.__hard_delete_user_uuid(user_uuid)
return Response(status_code=204, content="")
"""
Group endpoints
"""
class Operation(BaseModel):
op: str
path: str = Field(default=None)
value: list[dict] | dict = Field(default=None)
class GroupGetResponse(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"])
totalResults: int
startIndex: int
itemsPerPage: int
resources: list = Field(alias="Resources")
class GroupRequest(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
displayName: str = Field(default=None)
members: list = Field(default=None)
operations: list[Operation] = Field(default=None, alias="Operations")
class GroupPatchRequest(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:PatchOp"])
operations: list[Operation] = Field(alias="Operations")
class GroupResponse(BaseModel):
schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
id: str
displayName: str
members: list
meta: dict = Field(default={"resourceType": "Group"})
@public_app.get("/Groups", dependencies=[Depends(auth_required)])
def get_groups(
start_index: int = Query(1, alias="startIndex"),
count: Optional[int] = Query(None, alias="count"),
group_name: Optional[str] = Query(None, alias="filter"),
):
"""Get groups"""
tenant_id = 1
res = []
if group_name:
group_name = group_name.split(" ")[2].strip('"')
groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name)
res = [{
"id": group["data"]["groupId"],
"meta": {
"created": group["createdAt"],
"lastModified": "", # not currently a field
"version": "v1.0"
},
"displayName": group["name"]
} for group in groups
]
return JSONResponse(
status_code=200,
content=GroupGetResponse(
totalResults=len(groups),
startIndex=start_index,
itemsPerPage=len(groups),
Resources=res
).model_dump(mode='json'))
@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def get_group(group_id: str):
"""Get a group by id"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"])
members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members]
return JSONResponse(
status_code=200,
content=GroupResponse(
id=group["data"]["groupId"],
displayName=group["name"],
members=members,
).model_dump(mode='json'))
@public_app.post("/Groups", dependencies=[Depends(auth_required)])
def create_group(r: GroupRequest):
"""Create a group"""
tenant_id = 1
member_role = roles.get_member_permissions(tenant_id)
try:
data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"]) # permissions by default are same as for member role
group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
added_members = []
for member in r.members:
user = users.get_by_uuid(member["value"], tenant_id)
if user:
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
added_members.append({
"value": user["data"]["userId"],
"display": user["name"]
})
return JSONResponse(
status_code=200,
content=GroupResponse(
id=group["data"]["groupId"],
displayName=group["name"],
members=added_members,
).model_dump(mode='json'))
@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_put_group(group_id: str, r: GroupRequest):
"""Update a group or members of the group (not used by anything yet)"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None:
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
return Response(status_code=200, content="")
members = r.members
modified_members = []
for member in members:
user = users.get_by_uuid(member["value"], tenant_id)
if user:
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
modified_members.append({
"value": user["data"]["userId"],
"display": user["name"]
})
return JSONResponse(
status_code=200,
content=GroupResponse(
id=group_id,
displayName=group["name"],
members=modified_members,
).model_dump(mode='json'))
@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_patch_group(group_id: str, r: GroupPatchRequest):
"""Update a group or members of the group, used by AIW"""
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
if r.operations[0].op == "replace" and r.operations[0].path is None:
roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
return Response(status_code=200, content="")
modified_members = []
for op in r.operations:
if op.op == "add" or op.op == "replace":
# Both methods work as "replace"
for u in op.value:
user = users.get_by_uuid(u["value"], tenant_id)
if user:
users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
modified_members.append({
"value": user["data"]["userId"],
"display": user["name"]
})
elif op.op == "remove":
user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1)
roles.remove_group_membership(tenant_id, group_id, user_id)
return JSONResponse(
status_code=200,
content=GroupResponse(
id=group_id,
displayName=group["name"],
members=modified_members,
).model_dump(mode='json'))
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def delete_group(group_id: str):
"""Delete a group, hard-delete"""
# possibly need to set the user's roles to default member role, instead of null
tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id)
if not group:
raise HTTPException(status_code=404, detail="Group not found")
roles.delete_scim_group(tenant_id, group["data"]["groupId"])
return Response(status_code=200, content="")

View file

164
ee/api/routers/scim/api.py Normal file
View file

@ -0,0 +1,164 @@
from scim2_server import utils
from routers.base import get_routers
from routers.scim.providers import MultiTenantProvider
from routers.scim.backends import PostgresBackend
from routers.scim.postgres_resource import PostgresResource
from routers.scim import users, groups, helpers
from urllib.parse import urlencode
from chalicelib.utils import pg_client
from fastapi import HTTPException, Request
from fastapi.responses import RedirectResponse
from chalicelib.utils.scim_auth import (
create_tokens,
verify_refresh_token,
)
b = PostgresBackend()
b.register_postgres_resource(
"User",
PostgresResource(
query_resources=users.query_resources,
get_resource=users.get_resource,
create_resource=users.create_resource,
search_existing=users.search_existing,
restore_resource=users.restore_resource,
delete_resource=users.delete_resource,
update_resource=users.update_resource,
),
)
b.register_postgres_resource(
"Group",
PostgresResource(
query_resources=groups.query_resources,
get_resource=groups.get_resource,
create_resource=groups.create_resource,
search_existing=groups.search_existing,
restore_resource=None,
delete_resource=groups.delete_resource,
update_resource=groups.update_resource,
),
)
scim_app = MultiTenantProvider(b)
for schema in utils.load_default_schemas().values():
scim_app.register_schema(schema)
for schema in helpers.load_custom_schemas().values():
scim_app.register_schema(schema)
for resource_type in helpers.load_custom_resource_types().values():
scim_app.register_resource_type(resource_type)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.post("/token")
async def post_token(r: Request):
form = await r.form()
client_id = form.get("client_id")
client_secret = form.get("client_secret")
with pg_client.PostgresClient() as cur:
try:
cur.execute(
cur.mogrify(
"""
SELECT tenant_id
FROM public.tenants
WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
""",
{"tenant_id": int(client_id), "tenant_key": client_secret},
)
)
except ValueError:
raise HTTPException(status_code=401, detail="Invalid credentials")
tenant = cur.fetchone()
if not tenant:
raise HTTPException(status_code=401, detail="Invalid credentials")
grant_type = form.get("grant_type")
if grant_type == "refresh_token":
refresh_token = form.get("refresh_token")
verify_refresh_token(refresh_token)
else:
code = form.get("code")
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.scim_auth_codes
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
""",
{"auth_code": code, "tenant_id": int(client_id)},
)
)
if cur.fetchone() is None:
raise HTTPException(
status_code=401, detail="Invalid code/client_id pair"
)
cur.execute(
cur.mogrify(
"""
UPDATE public.scim_auth_codes
SET used=TRUE
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
""",
{"auth_code": code, "tenant_id": int(client_id)},
)
)
access_token, refresh_token, expires_in = create_tokens(
tenant_id=tenant["tenant_id"]
)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": expires_in,
"refresh_token": refresh_token,
}
# note(jon): this might be specific to okta. if so, we should probably put specify that in the endpoint
@public_app.get("/authorize")
async def get_authorize(
r: Request,
response_type: str,
client_id: str,
redirect_uri: str,
state: str | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
UPDATE public.scim_auth_codes
SET used=TRUE
WHERE tenant_id=%(tenant_id)s
""",
{"tenant_id": int(client_id)},
)
)
cur.execute(
cur.mogrify(
"""
INSERT INTO public.scim_auth_codes (tenant_id)
VALUES (%(tenant_id)s)
RETURNING auth_code
""",
{"tenant_id": int(client_id)},
)
)
code = cur.fetchone()["auth_code"]
params = {"code": code}
if state:
params["state"] = state
url = f"{redirect_uri}?{urlencode(params)}"
return RedirectResponse(url)

View file

@ -0,0 +1,203 @@
from scim2_server import backend
from scim2_server.filter import evaluate_filter
from scim2_server.utils import SCIMException
from scim2_models import (
SearchRequest,
Resource,
Context,
Error,
)
from scim2_filter_parser import lexer
from scim2_filter_parser.parser import SCIMParser
from routers.scim.postgres_resource import PostgresResource
from scim2_server.operators import ResolveSortOperator
import operator
class PostgresBackend(backend.Backend):
def __init__(self):
super().__init__()
self._postgres_resources = {}
def register_postgres_resource(
self, resource_type_id: str, postgres_resource: PostgresResource
):
self._postgres_resources[resource_type_id] = postgres_resource
def query_resources(
self,
search_request: SearchRequest,
tenant_id: int,
resource_type_id: str | None = None,
) -> tuple[int, list[Resource]]:
"""Query the backend for a set of resources.
:param search_request: SearchRequest instance describing the
query.
:param resource_type_id: ID of the resource type to query. If
None, all resource types are queried.
:return: A tuple of "total results" and a List of found
Resources. The List must contain a copy of resources.
Mutating elements in the List must not modify the data
stored in the backend.
:raises SCIMException: If the backend only supports querying for
one resource type at a time, setting resource_type_id to
None the backend may raise a
SCIMException(Error.make_too_many_error()).
"""
start_index = (search_request.start_index or 1) - 1
tree = None
if search_request.filter is not None:
token_stream = lexer.SCIMLexer().tokenize(search_request.filter)
tree = SCIMParser().parse(token_stream)
# todo(jon): handle the case when resource_type_id is None.
# we're assuming it's never None for now.
# but, this is fine to leave as it doesn't seem to used or reached in
# any of my tests yet.
if not resource_type_id:
raise NotImplementedError
resources = self._postgres_resources[resource_type_id].query_resources(
tenant_id
)
model = self.get_model(resource_type_id)
resources = [
model.model_validate(r, scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
for r in resources
]
resources = [r for r in resources if (tree is None or evaluate_filter(r, tree))]
if search_request.sort_by is not None:
descending = search_request.sort_order == SearchRequest.SortOrder.descending
sort_operator = ResolveSortOperator(search_request.sort_by)
# To ensure that unset attributes are sorted last (when ascending, as defined in the RFC),
# we have to divide the result set into a set and unset subset.
unset_values = []
set_values = []
for resource in resources:
result = sort_operator(resource)
if result is None:
unset_values.append(resource)
else:
set_values.append((resource, result))
set_values.sort(key=operator.itemgetter(1), reverse=descending)
set_values = [value[0] for value in set_values]
if descending:
resources = unset_values + set_values
else:
resources = set_values + unset_values
found_resources = resources[start_index:]
if search_request.count is not None:
found_resources = resources[: search_request.count]
return len(resources), found_resources
def get_resource(
self, tenant_id: int, resource_type_id: str, object_id: str
) -> Resource | None:
"""Query the backend for a resources by its ID.
:param resource_type_id: ID of the resource type to get the
object from.
:param object_id: ID of the object to get.
:return: The resource object if it exists, None otherwise. The
resource must be a copy, modifying it must not change the
data stored in the backend.
"""
resource = self._postgres_resources[resource_type_id].get_resource(
object_id, tenant_id
)
if resource:
model = self.get_model(resource_type_id)
resource = model.model_validate(resource)
return resource
def delete_resource(
self, tenant_id: int, resource_type_id: str, object_id: str
) -> bool:
"""Delete a resource.
:param resource_type_id: ID of the resource type to delete the
object from.
:param object_id: ID of the object to delete.
:return: True if the resource was deleted, False otherwise.
"""
resource = self.get_resource(tenant_id, resource_type_id, object_id)
if resource:
self._postgres_resources[resource_type_id].delete_resource(
object_id, tenant_id
)
return True
return False
def create_resource(
self, tenant_id: int, resource_type_id: str, resource: Resource
) -> Resource | None:
"""Create a resource.
:param resource_type_id: ID of the resource type to create.
:param resource: Resource to create.
:return: The created resource. Creation should set system-
defined attributes (ID, Metadata). May be the same object
that is passed in.
"""
model = self.get_model(resource_type_id)
existing = self._postgres_resources[resource_type_id].search_existing(
tenant_id, resource
)
if existing:
existing = model.model_validate(existing)
if existing.active:
raise SCIMException(Error.make_uniqueness_error())
resource = self._postgres_resources[resource_type_id].restore_resource(
tenant_id, resource
)
else:
resource = self._postgres_resources[resource_type_id].create_resource(
tenant_id, resource
)
resource = model.model_validate(resource)
return resource
def update_resource(
self, tenant_id: int, resource_type_id: str, resource: Resource
) -> Resource | None:
"""Update a resource. The resource is identified by its ID.
:param resource_type_id: ID of the resource type to update.
:param resource: Resource to update.
:return: The updated resource. Updating should update the
"meta.lastModified" data. May be the same object that is
passed in.
"""
model = self.get_model(resource_type_id)
existing = self._postgres_resources[resource_type_id].search_existing(
tenant_id, resource
)
if existing:
existing = model.model_validate(existing)
if existing.active:
if existing.id != resource.id:
raise SCIMException(Error.make_uniqueness_error())
resource = self._postgres_resources[resource_type_id].update_resource(
tenant_id, resource
)
else:
self._postgres_resources[resource_type_id].delete_resource(
existing.id, tenant_id
)
resource = self._postgres_resources[resource_type_id].update_resource(
resource.id, tenant_id, resource
)
else:
resource = self._postgres_resources[resource_type_id].update_resource(
tenant_id, resource
)
resource = model.model_validate(resource)
return resource

View file

@ -0,0 +1,36 @@
[{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
"id": "User",
"name": "User",
"endpoint": "/Users",
"description": "User Account",
"schema": "urn:ietf:params:scim:schemas:core:2.0:User",
"schemaExtensions": [
{
"schema":
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
"required": true
},
{
"schema":
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
"required": true
}
],
"meta": {
"location": "/v2/ResourceTypes/User",
"resourceType": "ResourceType"
}
},
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
"id": "Group",
"name": "Group",
"endpoint": "/Groups",
"description": "Group",
"schema": "urn:ietf:params:scim:schemas:core:2.0:Group",
"meta": {
"location": "/v2/ResourceTypes/Group",
"resourceType": "ResourceType"
}
}]

View file

@ -0,0 +1,32 @@
[
{
"id": "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
"name": "OpenreplayUser",
"description": "Openreplay User Account Extension",
"attributes": [
{
"name": "permissions",
"type": "string",
"multiValued": true,
"description": "A list of permissions for the User that represent a thing the User is capable of doing.",
"required": false,
"canonicalValues": ["SESSION_REPLAY", "DEV_TOOLS", "METRICS", "ASSIST_LIVE", "ASSIST_CALL", "SPOT", "SPOT_PUBLIC"],
"mutability": "readWrite",
"returned": "default"
},
{
"name": "projectKeys",
"type": "string",
"multiValued": true,
"description": "A list of project keys for the User that represent a project the User is allowed to work on.",
"required": false,
"mutability": "readWrite",
"returned": "default"
}
],
"meta": {
"resourceType": "Schema",
"location": "/v2/Schemas/urn:ietf:params:scim:schemas:extension:openreplay:2.0:User"
}
}
]

View file

@ -0,0 +1,203 @@
from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs
from chalicelib.utils import pg_client
from routers.scim import helpers
from scim2_models import Error, Resource
from scim2_server.utils import SCIMException
def convert_provider_resource_to_client_resource(
provider_resource: dict,
) -> dict:
members = provider_resource["users"] or []
return {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"id": str(provider_resource["role_id"]),
"meta": {
"resourceType": "Group",
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": provider_resource["updated_at"].strftime(
"%Y-%m-%dT%H:%M:%SZ"
),
},
"displayName": provider_resource["name"],
"members": [
{
"value": str(member["user_id"]),
"$ref": f"Users/{member['user_id']}",
"type": "User",
}
for member in members
],
}
def query_resources(tenant_id: int) -> list[dict]:
query = _main_select_query(tenant_id)
with pg_client.PostgresClient() as cur:
cur.execute(query)
items = cur.fetchall()
return [convert_provider_resource_to_client_resource(item) for item in items]
def get_resource(resource_id: str, tenant_id: int) -> dict | None:
query = _main_select_query(tenant_id, resource_id)
with pg_client.PostgresClient() as cur:
cur.execute(query)
item = cur.fetchone()
if item:
return convert_provider_resource_to_client_resource(item)
return None
def delete_resource(resource_id: str, tenant_id: int) -> None:
_update_resource_sql(
resource_id=resource_id,
tenant_id=tenant_id,
deleted_at=datetime.now(),
)
def search_existing(tenant_id: int, resource: Resource) -> dict | None:
return None
def create_resource(tenant_id: int, resource: Resource) -> dict:
with pg_client.PostgresClient() as cur:
user_ids = (
[int(x.value) for x in resource.members] if resource.members else None
)
user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
try:
cur.execute(
cur.mogrify(
"""
INSERT INTO public.roles (
name,
tenant_id
)
VALUES (
%(name)s,
%(tenant_id)s
)
RETURNING role_id
""",
{
"name": resource.display_name,
"tenant_id": tenant_id,
},
)
)
except Exception:
raise SCIMException(Error.make_invalid_value_error())
role_id = cur.fetchone()["role_id"]
cur.execute(
f"""
UPDATE public.users
SET
updated_at = now(),
role_id = {role_id}
WHERE users.user_id = ANY({user_id_clause})
"""
)
cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1")
item = cur.fetchone()
return convert_provider_resource_to_client_resource(item)
def update_resource(tenant_id: int, resource: Resource) -> dict | None:
item = _update_resource_sql(
resource_id=resource.id,
tenant_id=tenant_id,
name=resource.display_name,
user_ids=[int(x.value) for x in resource.members],
deleted_at=None,
)
return convert_provider_resource_to_client_resource(item)
def _main_select_query(tenant_id: int, resource_id: str | None = None) -> str:
where_and_clauses = [
f"roles.tenant_id = {tenant_id}",
"roles.deleted_at IS NULL",
]
if resource_id is not None:
where_and_clauses.append(f"roles.role_id = {resource_id}")
where_clause = " AND ".join(where_and_clauses)
return f"""
SELECT
roles.*,
COALESCE(
(
SELECT json_agg(users)
FROM public.users
WHERE users.role_id = roles.role_id
),
'[]'
) AS users,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM public.roles
WHERE {where_clause}
"""
def _update_resource_sql(
resource_id: int,
tenant_id: int,
user_ids: list[int] | None = None,
**kwargs: dict[str, Any],
) -> dict[str, Any]:
with pg_client.PostgresClient() as cur:
kwargs["updated_at"] = datetime.now()
set_fragments = [
cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8")
for k, v in kwargs.items()
]
set_clause = ", ".join(set_fragments)
user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
cur.execute(
f"""
UPDATE public.users
SET
updated_at = now(),
role_id = NULL
WHERE
users.role_id = {resource_id}
AND users.user_id != ALL({user_id_clause})
RETURNING *
"""
)
cur.execute(
f"""
UPDATE public.users
SET
updated_at = now(),
role_id = {resource_id}
WHERE
(users.role_id != {resource_id} OR users.role_id IS NULL)
AND users.user_id = ANY({user_id_clause})
RETURNING *
"""
)
cur.execute(
f"""
UPDATE public.roles
SET {set_clause}
WHERE
roles.role_id = {resource_id}
AND roles.tenant_id = {tenant_id}
"""
)
cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1")
return cur.fetchone()

View file

@ -0,0 +1,44 @@
from typing import Any, Literal
from chalicelib.utils import pg_client
from scim2_models import Schema, Resource, ResourceType
import os
import json
def safe_mogrify_array(
items: list[Any] | None,
array_type: Literal["varchar", "int"],
cursor: pg_client.PostgresClient,
) -> str:
items = items or []
fragments = [cursor.mogrify("%s", (item,)).decode("utf-8") for item in items]
result = f"ARRAY[{', '.join(fragments)}]::{array_type}[]"
return result
def load_json_resource(json_name: str) -> dict:
with open(json_name) as f:
return json.load(f)
def load_scim_resource(
json_name: str, type_: type[Resource]
) -> dict[str, type[Resource]]:
ret = {}
definitions = load_json_resource(json_name)
for d in definitions:
model = type_.model_validate(d)
ret[model.id] = model
return ret
def load_custom_schemas() -> dict[str, Schema]:
json_name = os.path.join("routers", "scim", "fixtures", "custom_schemas.json")
return load_scim_resource(json_name, Schema)
def load_custom_resource_types() -> dict[str, ResourceType]:
json_name = os.path.join(
"routers", "scim", "fixtures", "custom_resource_types.json"
)
return load_scim_resource(json_name, ResourceType)

View file

@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Callable
from scim2_models import Resource
@dataclass
class PostgresResource:
query_resources: Callable[[int], list[dict]]
get_resource: Callable[[str, int], dict | None]
create_resource: Callable[[int, Resource], dict]
search_existing: Callable[[int, Resource], dict | None]
restore_resource: Callable[[int, Resource], dict] | None
delete_resource: Callable[[str, int], None]
update_resource: Callable[[int, Resource], dict]

View file

@ -0,0 +1,280 @@
import traceback
from typing import Union
from scim2_server import provider
from scim2_models import (
AuthenticationScheme,
ServiceProviderConfig,
Patch,
Bulk,
Filter,
Sort,
ETag,
Meta,
ChangePassword,
Error,
ResourceType,
Context,
ListResponse,
PatchOp,
)
from werkzeug import Request, Response
from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed
from pydantic import ValidationError
from werkzeug.routing.exceptions import RequestRedirect
from scim2_server.utils import SCIMException, merge_resources
from chalicelib.utils.scim_auth import verify_access_token
class MultiTenantProvider(provider.SCIMProvider):
def check_auth(self, request: Request):
auth = request.headers.get("Authorization")
if not auth or not auth.startswith("Bearer "):
return None
token = auth[len("Bearer ") :]
if not token:
return Response(
"Missing or invalid Authorization header",
status=401,
headers={"WWW-Authenticate": 'Bearer realm="login required"'},
)
payload = verify_access_token(token)
tenant_id = payload["tenant_id"]
return tenant_id
def get_service_provider_config(self):
auth_schemes = [
AuthenticationScheme(
type="oauthbearertoken",
name="OAuth Bearer Token",
description="Authentication scheme using the OAuth Bearer Token Standard. The access token should be sent in the 'Authorization' header using the Bearer schema.",
spec_uri="https://datatracker.ietf.org/doc/html/rfc6750",
)
]
return ServiceProviderConfig(
# todo(jon): write correct documentation uri
documentation_uri="https://www.example.com/",
patch=Patch(supported=True),
bulk=Bulk(supported=False),
filter=Filter(supported=True, max_results=1000),
change_password=ChangePassword(supported=False),
sort=Sort(supported=True),
etag=ETag(supported=False),
authentication_schemes=auth_schemes,
meta=Meta(resource_type="ServiceProviderConfig"),
)
def query_resource(
self, request: Request, tenant_id: int, resource: ResourceType | None
):
search_request = self.build_search_request(request)
kwargs = {}
if resource is not None:
kwargs["resource_type_id"] = resource.id
total_results, results = self.backend.query_resources(
search_request=search_request, tenant_id=tenant_id, **kwargs
)
for r in results:
self.adjust_location(request, r)
resources = [
s.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
attributes=search_request.attributes,
excluded_attributes=search_request.excluded_attributes,
)
for s in results
]
return ListResponse[Union[tuple(self.backend.get_models())]]( # noqa: UP007
total_results=total_results,
items_per_page=search_request.count,
start_index=search_request.start_index,
resources=resources,
)
def call_resource(
self, request: Request, resource_endpoint: str, **kwargs
) -> Response:
resource_type = self.backend.get_resource_type_by_endpoint(
"/" + resource_endpoint
)
if not resource_type:
raise NotFound
if "tenant_id" not in kwargs:
raise Exception
tenant_id = kwargs["tenant_id"]
match request.method:
case "GET":
return self.make_response(
self.query_resource(request, tenant_id, resource_type).model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
)
)
case _: # "POST"
payload = request.json
resource = self.backend.get_model(resource_type.id).model_validate(
payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST
)
created_resource = self.backend.create_resource(
tenant_id,
resource_type.id,
resource,
)
self.adjust_location(request, created_resource)
return self.make_response(
created_resource.model_dump(
scim_ctx=Context.RESOURCE_CREATION_RESPONSE
),
status=201,
headers={"Location": created_resource.meta.location},
)
def call_single_resource(
self, request: Request, resource_endpoint: str, resource_id: str, **kwargs
) -> Response:
find_endpoint = "/" + resource_endpoint
resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint)
if not resource_type:
raise NotFound
if "tenant_id" not in kwargs:
raise Exception
tenant_id = kwargs["tenant_id"]
match request.method:
case "GET":
if resource := self.backend.get_resource(
tenant_id, resource_type.id, resource_id
):
if self.continue_etag(request, resource):
response_args = self.get_attrs_from_request(request)
self.adjust_location(request, resource)
return self.make_response(
resource.model_dump(
scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
**response_args,
)
)
else:
return self.make_response(None, status=304)
raise NotFound
case "DELETE":
if self.backend.delete_resource(
tenant_id, resource_type.id, resource_id
):
return self.make_response(None, 204)
else:
raise NotFound
case "PUT":
response_args = self.get_attrs_from_request(request)
resource = self.backend.get_resource(
tenant_id, resource_type.id, resource_id
)
if resource is None:
raise NotFound
if not self.continue_etag(request, resource):
raise PreconditionFailed
updated_attributes = self.backend.get_model(
resource_type.id
).model_validate(request.json)
merge_resources(resource, updated_attributes)
updated = self.backend.update_resource(
tenant_id, resource_type.id, resource
)
self.adjust_location(request, updated)
return self.make_response(
updated.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
**response_args,
)
)
case _: # "PATCH"
payload = request.json
# MS Entra sometimes passes a "id" attribute
if "id" in payload:
del payload["id"]
operations = payload.get("Operations", [])
for operation in operations:
if "name" in operation:
# MS Entra sometimes passes a "name" attribute
del operation["name"]
patch_operation = PatchOp.model_validate(payload)
response_args = self.get_attrs_from_request(request)
resource = self.backend.get_resource(
tenant_id, resource_type.id, resource_id
)
if resource is None:
raise NotFound
if not self.continue_etag(request, resource):
raise PreconditionFailed
self.apply_patch_operation(resource, patch_operation)
updated = self.backend.update_resource(
tenant_id, resource_type.id, resource
)
if response_args:
self.adjust_location(request, updated)
return self.make_response(
updated.model_dump(
scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
**response_args,
)
)
else:
# RFC 7644, section 3.5.2:
# A PATCH operation MAY return a 204 (no content)
# if no attributes were requested
return self.make_response(
None, 204, headers={"ETag": updated.meta.version}
)
def wsgi_app(self, request: Request, environ):
try:
if environ.get("PATH_INFO", "").endswith(".scim"):
# RFC 7644, Section 3.8
# Just strip .scim suffix, the provider always returns application/scim+json
environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim")
urls = self.url_map.bind_to_environ(environ)
endpoint, args = urls.match()
tenant_id = None
if endpoint != "service_provider_config":
# RFC7643, Section 5: skip authentication for ServiceProviderConfig
tenant_id = self.check_auth(request)
# Wrap the entire call in a transaction. Should probably be optimized (use transaction only when necessary).
with self.backend:
if endpoint == "service_provider_config" or endpoint == "schema":
response = getattr(self, f"call_{endpoint}")(request, **args)
else:
response = getattr(self, f"call_{endpoint}")(
request, **args, tenant_id=tenant_id
)
return response
except RequestRedirect as e:
# urls.match may cause a redirect, handle it as a special case of HTTPException
self.log.exception(e)
return e.get_response(environ)
except HTTPException as e:
self.log.exception(e)
return self.make_error(Error(status=e.code, detail=e.description))
except SCIMException as e:
self.log.exception(e)
return self.make_error(e.scim_error)
except ValidationError as e:
self.log.exception(e)
return self.make_error(Error(status=400, detail=str(e)))
except Exception as e:
self.log.exception(e)
tb = traceback.format_exc()
return self.make_error(Error(status=500, detail=str(e) + "\n" + tb))

View file

@ -0,0 +1,371 @@
from routers.scim import helpers
from chalicelib.utils import pg_client
from scim2_models import Resource
def convert_provider_resource_to_client_resource(
provider_resource: dict,
) -> dict:
groups = []
if provider_resource["role_id"]:
groups.append(
{
"value": str(provider_resource["role_id"]),
"$ref": f"Groups/{provider_resource['role_id']}",
}
)
return {
"id": str(provider_resource["user_id"]),
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
],
"meta": {
"resourceType": "User",
"created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
"lastModified": provider_resource["updated_at"].strftime(
"%Y-%m-%dT%H:%M:%SZ"
),
},
"userName": provider_resource["email"],
"externalId": provider_resource["internal_id"],
"name": {
"formatted": provider_resource["name"],
},
"displayName": provider_resource["name"] or provider_resource["email"],
"active": provider_resource["deleted_at"] is None,
"groups": groups,
"urn:ietf:params:scim:schemas:extension:openreplay:2.0:User": {
"permissions": provider_resource.get("permissions") or [],
"projectKeys": provider_resource.get("project_keys") or [],
},
}
def query_resources(tenant_id: int) -> list[dict]:
with pg_client.PostgresClient() as cur:
cur.execute(
f"""
SELECT
users.*,
roles.permissions AS permissions,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM public.users
LEFT JOIN public.roles ON roles.role_id = users.role_id
WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL
"""
)
items = cur.fetchall()
return [convert_provider_resource_to_client_resource(item) for item in items]
def get_resource(resource_id: str, tenant_id: int) -> dict | None:
with pg_client.PostgresClient() as cur:
cur.execute(
f"""
SELECT
users.*,
roles.permissions AS permissions,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM public.users
LEFT JOIN public.roles ON roles.role_id = users.role_id
WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL AND users.user_id = {resource_id}
"""
)
item = cur.fetchone()
if item:
return convert_provider_resource_to_client_resource(item)
return None
def delete_resource(resource_id: str, tenatn_id: int) -> None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
UPDATE public.users
SET
deleted_at = NULL,
updated_at = now()
WHERE users.user_id = %(user_id)s
""",
{"user_id": resource_id},
)
)
def search_existing(tenant_id: int, resource: Resource) -> dict | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.users
WHERE email = %(email)s
""",
{"email": resource.user_name},
)
)
item = cur.fetchone()
if item:
return convert_provider_resource_to_client_resource(item)
return None
def restore_resource(tenant_id: int, resource: Resource) -> dict | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT role_id
FROM public.users
WHERE user_id = %(user_id)s
""",
{"user_id": resource.id},
)
)
item = cur.fetchone()
if item and item["role_id"] is not None:
_update_role_projects_and_permissions(
item["role_id"],
resource.OpenreplayUser.project_keys,
resource.OpenreplayUser.permissions,
cur,
)
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
deleted_at = NULL,
created_at = now(),
updated_at = now(),
api_key = default,
jwt_iat = NULL,
weekly_report = default
WHERE users.email = %(email)s
RETURNING *
)
SELECT
u.*,
roles.permissions AS permissions,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM u
LEFT JOIN public.roles ON roles.role_id = u.role_id
""",
{
"tenant_id": tenant_id,
"email": resource.user_name,
"name": " ".join(
[
x
for x in [
resource.name.honorific_prefix,
resource.name.given_name,
resource.name.middle_name,
resource.name.family_name,
resource.name.honorific_suffix,
]
if x
]
)
if resource.name
else "",
"internal_id": resource.external_id,
},
)
)
item = cur.fetchone()
return convert_provider_resource_to_client_resource(item)
def create_resource(tenant_id: int, resource: Resource) -> dict:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
WITH u AS (
INSERT INTO public.users (
tenant_id,
email,
name,
internal_id
)
VALUES (
%(tenant_id)s,
%(email)s,
%(name)s,
%(internal_id)s
)
RETURNING *
)
SELECT *
FROM u
""",
{
"tenant_id": tenant_id,
"email": resource.user_name,
"name": " ".join(
[
x
for x in [
resource.name.honorific_prefix,
resource.name.given_name,
resource.name.middle_name,
resource.name.family_name,
resource.name.honorific_suffix,
]
if x
]
)
if resource.name
else "",
"internal_id": resource.external_id,
},
)
)
item = cur.fetchone()
return convert_provider_resource_to_client_resource(item)
def update_resource(tenant_id: int, resource: Resource) -> dict | None:
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT role_id
FROM public.users
WHERE user_id = %(user_id)s
""",
{"user_id": resource.id},
)
)
item = cur.fetchone()
if item and item["role_id"] is not None:
_update_role_projects_and_permissions(
item["role_id"],
resource.OpenreplayUser.project_keys,
resource.OpenreplayUser.permissions,
cur,
)
cur.execute(
cur.mogrify(
"""
WITH u AS (
UPDATE public.users
SET
tenant_id = %(tenant_id)s,
email = %(email)s,
name = %(name)s,
internal_id = %(internal_id)s,
updated_at = now()
WHERE user_id = %(user_id)s
RETURNING *
)
SELECT
u.*,
roles.permissions AS permissions,
COALESCE(
(
SELECT json_agg(projects.project_key)
FROM public.projects
LEFT JOIN public.roles_projects USING (project_id)
WHERE roles_projects.role_id = roles.role_id
),
'[]'
) AS project_keys
FROM u
LEFT JOIN public.roles ON roles.role_id = u.role_id
""",
{
"user_id": resource.id,
"tenant_id": tenant_id,
"email": resource.user_name,
"name": " ".join(
[
x
for x in [
resource.name.honorific_prefix,
resource.name.given_name,
resource.name.middle_name,
resource.name.family_name,
resource.name.honorific_suffix,
]
if x
]
)
if resource.name
else "",
"internal_id": resource.external_id,
},
)
)
item = cur.fetchone()
return convert_provider_resource_to_client_resource(item)
def _update_role_projects_and_permissions(
role_id: int,
project_keys: list[str] | None,
permissions: list[str] | None,
cur: pg_client.PostgresClient,
) -> None:
all_projects = "true" if not project_keys else "false"
project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur)
permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur)
cur.execute(
f"""
UPDATE public.roles
SET
updated_at = now(),
all_projects = {all_projects},
permissions = {permission_clause}
WHERE role_id = {role_id}
RETURNING *
"""
)
cur.execute(
f"""
DELETE FROM public.roles_projects
WHERE roles_projects.role_id = {role_id}
"""
)
cur.execute(
f"""
INSERT INTO public.roles_projects (role_id, project_id)
SELECT {role_id}, projects.project_id
FROM public.projects
WHERE projects.project_key = ANY({project_key_clause})
"""
)

View file

@ -108,6 +108,16 @@ CREATE TABLE public.tenants
);
CREATE TABLE public.scim_auth_codes
(
auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE,
auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20),
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
used bool NOT NULL DEFAULT FALSE
);
CREATE TABLE public.roles
(
role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
@ -118,6 +128,7 @@ CREATE TABLE public.roles
protected bool NOT NULL DEFAULT FALSE,
all_projects bool NOT NULL DEFAULT TRUE,
created_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
updated_at timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
deleted_at timestamp NULL DEFAULT NULL,
service_role bool NOT NULL DEFAULT FALSE
);
@ -132,6 +143,7 @@ CREATE TABLE public.users
role user_role NOT NULL DEFAULT 'member',
name text NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
deleted_at timestamp without time zone NULL DEFAULT NULL,
api_key text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
jwt_iat timestamp without time zone NULL DEFAULT NULL,