Merge 7324283c29 into cd70633d1f
Commit e1fdbb1c36
19 changed files with 1958 additions and 1260 deletions
ee/api/.gitignore (vendored, 1 change)
@@ -283,4 +283,3 @@ Pipfile.lock
 /chalicelib/utils/contextual_validators.py
 /routers/subs/product_analytics.py
 /schemas/product_analytics.py
 /ee/bin/*
ee/api/Pipfile

@@ -26,6 +26,8 @@ xmlsec = "==1.3.14"
 python-multipart = "==0.0.20"
 redis = "==6.1.0"
 azure-storage-blob = "==12.25.1"
+scim2-server = "*"
+scim2-models = "*"
 
 [dev-packages]
 
ee/api/app.py

@@ -9,6 +9,7 @@ from decouple import config
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
+from fastapi.middleware.wsgi import WSGIMiddleware
 from psycopg import AsyncConnection
 from psycopg.rows import dict_row
 from starlette import status
@@ -21,12 +22,20 @@ from chalicelib.utils import pg_client, ch_client
 from crons import core_crons, ee_crons, core_dynamic_crons
 from routers import core, core_dynamic
 from routers import ee
-from routers.subs import insights, metrics, v1_api, health, usability_tests, spot, product_analytics
+from routers.subs import (
+    insights,
+    metrics,
+    v1_api,
+    health,
+    usability_tests,
+    spot,
+    product_analytics,
+)
 from routers.subs import v1_api_ee
 
 if config("ENABLE_SSO", cast=bool, default=True):
     from routers import saml
-    from routers import scim
+    from routers.scim import api as scim
 
 loglevel = config("LOGLEVEL", default=logging.WARNING)
 print(f">Loglevel set to: {loglevel}")
@@ -34,7 +43,6 @@ logging.basicConfig(level=loglevel)
 
 
 class ORPYAsyncConnection(AsyncConnection):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, row_factory=dict_row, **kwargs)
 
@@ -43,7 +51,7 @@ class ORPYAsyncConnection(AsyncConnection):
 async def lifespan(app: FastAPI):
     # Startup
     logging.info(">>>>> starting up <<<<<")
-    ap_logger = logging.getLogger('apscheduler')
+    ap_logger = logging.getLogger("apscheduler")
     ap_logger.setLevel(loglevel)
 
     app.schedule = AsyncIOScheduler()
@@ -53,12 +61,23 @@ async def lifespan(app: FastAPI):
     await events_queue.init()
     app.schedule.start()
 
-    for job in core_crons.cron_jobs + core_dynamic_crons.cron_jobs + traces.cron_jobs + ee_crons.ee_cron_jobs:
+    for job in (
+        core_crons.cron_jobs
+        + core_dynamic_crons.cron_jobs
+        + traces.cron_jobs
+        + ee_crons.ee_cron_jobs
+    ):
         app.schedule.add_job(id=job["func"].__name__, **job)
 
     ap_logger.info(">Scheduled jobs:")
     for job in app.schedule.get_jobs():
-        ap_logger.info({"Name": str(job.id), "Run Frequency": str(job.trigger), "Next Run": str(job.next_run_time)})
+        ap_logger.info(
+            {
+                "Name": str(job.id),
+                "Run Frequency": str(job.trigger),
+                "Next Run": str(job.next_run_time),
+            }
+        )
 
     database = {
         "host": config("pg_host", default="localhost"),
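For orientation, the loop above splats each entry of the crons lists straight into APScheduler's add_job, so every job must be a dict whose keys match add_job's keyword arguments. A minimal self-contained sketch of that contract (the job name and trigger here are illustrative, not taken from this diff):

    from apscheduler.schedulers.asyncio import AsyncIOScheduler
    from apscheduler.triggers.interval import IntervalTrigger


    async def telemetry_cron():
        # Hypothetical job body; the real jobs live in the crons modules.
        pass


    cron_jobs = [
        # "func" doubles as the job id in the loop above.
        {"func": telemetry_cron, "trigger": IntervalTrigger(hours=4)},
    ]

    schedule = AsyncIOScheduler()
    for job in cron_jobs:
        schedule.add_job(id=job["func"].__name__, **job)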
@@ -69,9 +88,12 @@ async def lifespan(app: FastAPI):
         "application_name": "AIO" + config("APP_NAME", default="PY"),
     }
 
-    database = psycopg_pool.AsyncConnectionPool(kwargs=database, connection_class=ORPYAsyncConnection,
-                                                min_size=config("PG_AIO_MINCONN", cast=int, default=1),
-                                                max_size=config("PG_AIO_MAXCONN", cast=int, default=5), )
+    database = psycopg_pool.AsyncConnectionPool(
+        kwargs=database,
+        connection_class=ORPYAsyncConnection,
+        min_size=config("PG_AIO_MINCONN", cast=int, default=1),
+        max_size=config("PG_AIO_MAXCONN", cast=int, default=5),
+    )
     app.state.postgresql = database
 
     # App listening
@@ -86,16 +108,24 @@ async def lifespan(app: FastAPI):
     await pg_client.terminate()
 
 
-app = FastAPI(root_path=config("root_path", default="/api"), docs_url=config("docs_url", default=""),
-              redoc_url=config("redoc_url", default=""), lifespan=lifespan)
+app = FastAPI(
+    root_path=config("root_path", default="/api"),
+    docs_url=config("docs_url", default=""),
+    redoc_url=config("redoc_url", default=""),
+    lifespan=lifespan,
+)
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
 
-@app.middleware('http')
+@app.middleware("http")
 async def or_middleware(request: Request, call_next):
     from chalicelib.core import unlock
+
     if not unlock.is_valid():
-        return JSONResponse(content={"errors": ["expired license"]}, status_code=status.HTTP_403_FORBIDDEN)
+        return JSONResponse(
+            content={"errors": ["expired license"]},
+            status_code=status.HTTP_403_FORBIDDEN,
+        )
 
     if helper.TRACK_TIME:
         now = time.time()
@@ -110,8 +140,10 @@ async def or_middleware(request: Request, call_next):
         now = time.time() - now
         if now > 2:
             now = round(now, 2)
-            logging.warning(f"Execution time: {now} s for {request.method}: {request.url.path}")
-    response.headers["x-robots-tag"] = 'noindex, nofollow'
+            logging.warning(
+                f"Execution time: {now} s for {request.method}: {request.url.path}"
+            )
+    response.headers["x-robots-tag"] = "noindex, nofollow"
     return response
 
 
@@ -162,3 +194,4 @@ if config("ENABLE_SSO", cast=bool, default=True):
     app.include_router(scim.public_app)
     app.include_router(scim.app)
     app.include_router(scim.app_apikey)
+    app.mount("/sso/scim/v2", WSGIMiddleware(scim.scim_app))
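The new mount line is what bridges the two worlds: scim2-server exposes a WSGI callable, and FastAPI serves it under the SCIM prefix through Starlette's WSGIMiddleware. A minimal sketch of the same pattern with a stand-in WSGI app (the callable below is hypothetical, not from this diff):

    from fastapi import FastAPI
    from fastapi.middleware.wsgi import WSGIMiddleware


    def tiny_wsgi_app(environ, start_response):
        # Stand-in for scim.scim_app; any WSGI callable is mounted the same way.
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"handled by WSGI"]


    app = FastAPI()
    # Everything under /sso/scim/v2 is delegated to the WSGI callable.
    app.mount("/sso/scim/v2", WSGIMiddleware(tiny_wsgi_app))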
ee/api/chalicelib/core/roles.py

@@ -1,4 +1,3 @@
-import json
 from typing import Optional
 
 from fastapi import HTTPException, status
@@ -10,13 +9,15 @@ from chalicelib.utils.TimeUTC import TimeUTC
 
 def __exists_by_name(tenant_id: int, name: str, exclude_id: Optional[int]) -> bool:
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify(f"""SELECT EXISTS(SELECT 1
+        query = cur.mogrify(
+            f"""SELECT EXISTS(SELECT 1
                                 FROM public.roles
                                 WHERE tenant_id = %(tenant_id)s
                                   AND name ILIKE %(name)s
                                   AND deleted_at ISNULL
                                   {"AND role_id!=%(exclude_id)s" if exclude_id else ""}) AS exists;""",
-                            {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id})
+            {"tenant_id": tenant_id, "name": name, "exclude_id": exclude_id},
+        )
         cur.execute(query=query)
         row = cur.fetchone()
         return row["exists"]
@@ -28,24 +29,31 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
     if not admin["admin"] and not admin["superAdmin"]:
         return {"errors": ["unauthorized"]}
     if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=role_id):
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
+        )
 
     if not data.all_projects and (data.projects is None or len(data.projects) == 0):
         return {"errors": ["must specify a project or all projects"]}
     if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
-        data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
+        data.projects = projects.is_authorized_batch(
+            project_ids=data.projects, tenant_id=tenant_id
+        )
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT 1
+        query = cur.mogrify(
+            """SELECT 1
                FROM public.roles
                WHERE role_id = %(role_id)s
                  AND tenant_id = %(tenant_id)s
                  AND protected = TRUE
                LIMIT 1;""",
-                            {"tenant_id": tenant_id, "role_id": role_id})
+            {"tenant_id": tenant_id, "role_id": role_id},
+        )
         cur.execute(query=query)
         if cur.fetchone() is not None:
             return {"errors": ["this role is protected"]}
-        query = cur.mogrify("""UPDATE public.roles
+        query = cur.mogrify(
+            """UPDATE public.roles
                SET name= %(name)s,
                    description= %(description)s,
                    permissions= %(permissions)s,
@@ -57,43 +65,36 @@ def update(tenant_id, user_id, role_id, data: schemas.RolePayloadSchema):
                RETURNING *, COALESCE((SELECT ARRAY_AGG(project_id)
                                       FROM roles_projects
                                       WHERE roles_projects.role_id=%(role_id)s),'{}') AS projects;""",
-                            {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()})
+            {"tenant_id": tenant_id, "role_id": role_id, **data.model_dump()},
+        )
         cur.execute(query=query)
         row = cur.fetchone()
         row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
         if not data.all_projects:
             d_projects = [i for i in row["projects"] if i not in data.projects]
             if len(d_projects) > 0:
-                query = cur.mogrify("""DELETE FROM roles_projects
+                query = cur.mogrify(
+                    """DELETE FROM roles_projects
                        WHERE role_id=%(role_id)s
                          AND project_id IN %(project_ids)s""",
-                                    {"role_id": role_id, "project_ids": tuple(d_projects)})
+                    {"role_id": role_id, "project_ids": tuple(d_projects)},
+                )
                 cur.execute(query=query)
             n_projects = [i for i in data.projects if i not in row["projects"]]
             if len(n_projects) > 0:
-                query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
+                query = cur.mogrify(
+                    f"""INSERT INTO roles_projects(role_id, project_id)
                         VALUES {",".join([f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(n_projects))])}""",
-                                    {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(n_projects)}})
+                    {
+                        "role_id": role_id,
+                        **{f"project_id_{i}": p for i, p in enumerate(n_projects)},
+                    },
+                )
                 cur.execute(query=query)
             row["projects"] = data.projects
 
     return helper.dict_to_camel_case(row)
 
-
-def update_group_name(tenant_id, group_id, name):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""UPDATE public.roles
-               SET name= %(name)s
-               WHERE roles.data->>'group_id' = %(group_id)s
-                 AND tenant_id = %(tenant_id)s
-                 AND deleted_at ISNULL
-                 AND protected = FALSE
-               RETURNING *;""",
-                            {"tenant_id": tenant_id, "group_id": group_id, "name": name })
-        cur.execute(query=query)
-        row = cur.fetchone()
-
-    return helper.dict_to_camel_case(row)
 
 
 def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
     admin = users.get(user_id=user_id, tenant_id=tenant_id)
@@ -102,57 +103,44 @@ def create(tenant_id, user_id, data: schemas.RolePayloadSchema):
         return {"errors": ["unauthorized"]}
 
     if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="name already exists."
+        )
 
     if not data.all_projects and (data.projects is None or len(data.projects) == 0):
         return {"errors": ["must specify a project or all projects"]}
     if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
-        data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
+        data.projects = projects.is_authorized_batch(
+            project_ids=data.projects, tenant_id=tenant_id
+        )
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
+        query = cur.mogrify(
+            """INSERT INTO roles(tenant_id, name, description, permissions, all_projects)
                VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s)
                RETURNING *;""",
-                            {"tenant_id": tenant_id, "name": data.name, "description": data.description,
-                             "permissions": data.permissions, "all_projects": data.all_projects})
+            {
+                "tenant_id": tenant_id,
+                "name": data.name,
+                "description": data.description,
+                "permissions": data.permissions,
+                "all_projects": data.all_projects,
+            },
+        )
         cur.execute(query=query)
         row = cur.fetchone()
         row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
         row["projects"] = []
         if not data.all_projects:
             role_id = row["role_id"]
-            query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
+            query = cur.mogrify(
+                f"""INSERT INTO roles_projects(role_id, project_id)
                     VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
                     RETURNING project_id;""",
-                                {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
-            cur.execute(query=query)
-            row["projects"] = [r["project_id"] for r in cur.fetchall()]
-    return helper.dict_to_camel_case(row)
-
-
-def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
-
-    if __exists_by_name(tenant_id=tenant_id, name=data.name, exclude_id=None):
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"name already exists.")
-
-    if not data.all_projects and (data.projects is None or len(data.projects) == 0):
-        return {"errors": ["must specify a project or all projects"]}
-    if data.projects is not None and len(data.projects) > 0 and not data.all_projects:
-        data.projects = projects.is_authorized_batch(project_ids=data.projects, tenant_id=tenant_id)
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""INSERT INTO roles(tenant_id, name, description, permissions, all_projects, data)
-               VALUES (%(tenant_id)s, %(name)s, %(description)s, %(permissions)s::text[], %(all_projects)s, %(data)s)
-               RETURNING *;""",
-                            {"tenant_id": tenant_id, "name": data.name, "description": data.description,
-                             "permissions": data.permissions, "all_projects": data.all_projects, "data": json.dumps({ "group_id": group_id })})
-        cur.execute(query=query)
-        row = cur.fetchone()
-        row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
-        row["projects"] = []
-        if not data.all_projects:
-            role_id = row["role_id"]
-            query = cur.mogrify(f"""INSERT INTO roles_projects(role_id, project_id)
-                    VALUES {",".join(f"(%(role_id)s,%(project_id_{i})s)" for i in range(len(data.projects)))}
-                    RETURNING project_id;""",
-                                {"role_id": role_id, **{f"project_id_{i}": p for i, p in enumerate(data.projects)}})
+                {
+                    "role_id": role_id,
+                    **{f"project_id_{i}": p for i, p in enumerate(data.projects)},
+                },
+            )
             cur.execute(query=query)
             row["projects"] = [r["project_id"] for r in cur.fetchall()]
     return helper.dict_to_camel_case(row)
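For orientation, the payload these helpers consume carries the fields referenced above (name, description, permissions, all_projects, projects). A sketch of a call, with purely illustrative values (the permission name is an assumption, not taken from this diff):

    import schemas

    payload = schemas.RolePayloadSchema(
        name="Support",
        description="Read-only support role",
        permissions=["SESSION_REPLAY"],  # illustrative permission name
        all_projects=False,
        projects=[1, 2],
    )
    created = create(tenant_id=1, user_id=42, data=payload)
    # dict_to_camel_case means the returned keys are camelCase.
    updated = update(tenant_id=1, user_id=42, role_id=created["roleId"], data=payload)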
@@ -160,7 +148,8 @@ def create_as_admin(tenant_id, group_id, data: schemas.RolePayloadSchema):
 
 def get_roles(tenant_id):
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
+        query = cur.mogrify(
+            """SELECT roles.*, COALESCE(projects, '{}') AS projects
                FROM public.roles
                LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
                                   FROM roles_projects
@@ -171,66 +160,25 @@ def get_roles(tenant_id):
                  AND deleted_at IS NULL
                  AND not service_role
                ORDER BY role_id;""",
-                            {"tenant_id": tenant_id})
+            {"tenant_id": tenant_id},
+        )
         cur.execute(query=query)
         rows = cur.fetchall()
         for r in rows:
             r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
         return helper.list_to_camel_case(rows)
 
-
-def get_roles_with_uuid(tenant_id):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
-               FROM public.roles
-               LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
-                                  FROM roles_projects
-                                  INNER JOIN projects USING (project_id)
-                                  WHERE roles_projects.role_id = roles.role_id
-                                    AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
-               WHERE tenant_id =%(tenant_id)s
-                 AND data ? 'group_id'
-                 AND deleted_at IS NULL
-                 AND not service_role
-               ORDER BY role_id;""",
-                            {"tenant_id": tenant_id})
-        cur.execute(query=query)
-        rows = cur.fetchall()
-        for r in rows:
-            r["created_at"] = TimeUTC.datetime_to_timestamp(r["created_at"])
-        return helper.list_to_camel_case(rows)
-
-
-def get_roles_with_uuid_paginated(tenant_id, start_index, count=None, name=None):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT roles.*, COALESCE(projects, '{}') AS projects
-               FROM public.roles
-               LEFT JOIN LATERAL (SELECT array_agg(project_id) AS projects
-                                  FROM roles_projects
-                                  INNER JOIN projects USING (project_id)
-                                  WHERE roles_projects.role_id = roles.role_id
-                                    AND projects.deleted_at ISNULL ) AS role_projects ON (TRUE)
-               WHERE tenant_id =%(tenant_id)s
-                 AND data ? 'group_id'
-                 AND deleted_at IS NULL
-                 AND not service_role
-                 AND name = COALESCE(%(name)s, name)
-               ORDER BY role_id
-               LIMIT %(count)s
-               OFFSET %(startIndex)s;""",
-                            {"tenant_id": tenant_id, "name": name, "startIndex": start_index - 1, "count": count})
-        cur.execute(query=query)
-        rows = cur.fetchall()
-        return helper.list_to_camel_case(rows)
-
 
 def get_role_by_name(tenant_id, name):
     ### "name" isn't unique in database
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT *
+        query = cur.mogrify(
+            """SELECT *
                FROM public.roles
                WHERE tenant_id =%(tenant_id)s
                  AND deleted_at IS NULL
                  AND name ILIKE %(name)s;""",
-                            {"tenant_id": tenant_id, "name": name})
+            {"tenant_id": tenant_id, "name": name},
+        )
         cur.execute(query=query)
         row = cur.fetchone()
         if row is not None:
@@ -244,139 +192,55 @@ def delete(tenant_id, user_id, role_id):
     if not admin["admin"] and not admin["superAdmin"]:
         return {"errors": ["unauthorized"]}
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT 1
+        query = cur.mogrify(
+            """SELECT 1
                FROM public.roles
                WHERE role_id = %(role_id)s
                  AND tenant_id = %(tenant_id)s
                  AND protected = TRUE
                LIMIT 1;""",
-                            {"tenant_id": tenant_id, "role_id": role_id})
+            {"tenant_id": tenant_id, "role_id": role_id},
+        )
         cur.execute(query=query)
         if cur.fetchone() is not None:
             return {"errors": ["this role is protected"]}
-        query = cur.mogrify("""SELECT 1
+        query = cur.mogrify(
+            """SELECT 1
                FROM public.users
                WHERE role_id = %(role_id)s
                  AND tenant_id = %(tenant_id)s
                LIMIT 1;""",
-                            {"tenant_id": tenant_id, "role_id": role_id})
+            {"tenant_id": tenant_id, "role_id": role_id},
+        )
         cur.execute(query=query)
         if cur.fetchone() is not None:
             return {"errors": ["this role is already attached to other user(s)"]}
-        query = cur.mogrify("""UPDATE public.roles
+        query = cur.mogrify(
+            """UPDATE public.roles
                SET deleted_at = timezone('utc'::text, now())
                WHERE role_id = %(role_id)s
                  AND tenant_id = %(tenant_id)s
                  AND protected = FALSE;""",
-                            {"tenant_id": tenant_id, "role_id": role_id})
+            {"tenant_id": tenant_id, "role_id": role_id},
+        )
         cur.execute(query=query)
     return get_roles(tenant_id=tenant_id)
 
-
-def delete_scim_group(tenant_id, group_uuid):
-
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT 1
-               FROM public.roles
-               WHERE data->>'group_id' = %(group_uuid)s
-                 AND tenant_id = %(tenant_id)s
-                 AND protected = TRUE
-               LIMIT 1;""",
-                            {"tenant_id": tenant_id, "group_uuid": group_uuid})
-        cur.execute(query)
-        if cur.fetchone() is not None:
-            return {"errors": ["this role is protected"]}
-
-        query = cur.mogrify(
-            f"""DELETE FROM public.roles
-                WHERE roles.data->>'group_id' = %(group_uuid)s;""",  # removed this: AND users.deleted_at IS NOT NULL
-            {"group_uuid": group_uuid})
-        cur.execute(query)
-
-    return get_roles(tenant_id=tenant_id)
-
 
 def get_role(tenant_id, role_id):
     with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT roles.*
+        query = cur.mogrify(
+            """SELECT roles.*
                FROM public.roles
                WHERE tenant_id =%(tenant_id)s
                  AND deleted_at IS NULL
                  AND not service_role
                  AND role_id = %(role_id)s
                LIMIT 1;""",
-                            {"tenant_id": tenant_id, "role_id": role_id})
+            {"tenant_id": tenant_id, "role_id": role_id},
+        )
         cur.execute(query=query)
         row = cur.fetchone()
         if row is not None:
             row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
         return helper.dict_to_camel_case(row)
 
-
-def get_role_by_group_id(tenant_id, group_id):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT roles.*
-               FROM public.roles
-               WHERE tenant_id =%(tenant_id)s
-                 AND deleted_at IS NULL
-                 AND not service_role
-                 AND data->>'group_id' = %(group_id)s
-               LIMIT 1;""",
-                            {"tenant_id": tenant_id, "group_id": group_id})
-        cur.execute(query=query)
-        row = cur.fetchone()
-        if row is not None:
-            row["created_at"] = TimeUTC.datetime_to_timestamp(row["created_at"])
-        return helper.dict_to_camel_case(row)
-
-
-def get_users_by_group_uuid(tenant_id, group_id):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT
-                                    u.user_id,
-                                    u.name,
-                                    u.data
-               FROM public.roles r
-               LEFT JOIN public.users u USING (role_id, tenant_id)
-               WHERE u.tenant_id = %(tenant_id)s
-                 AND u.deleted_at IS NULL
-                 AND r.data->>'group_id' = %(group_id)s
-               """,
-                            {"tenant_id": tenant_id, "group_id": group_id})
-        cur.execute(query=query)
-        rows = cur.fetchall()
-        return helper.list_to_camel_case(rows)
-
-
-def get_member_permissions(tenant_id):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""SELECT
-                                    r.permissions
-               FROM public.roles r
-               WHERE r.tenant_id = %(tenant_id)s
-                 AND r.name = 'Member'
-                 AND r.deleted_at IS NULL
-               """,
-                            {"tenant_id": tenant_id})
-        cur.execute(query=query)
-        row = cur.fetchone()
-        return helper.dict_to_camel_case(row)
-
-
-def remove_group_membership(tenant_id, group_id, user_id):
-    with pg_client.PostgresClient() as cur:
-        query = cur.mogrify("""WITH r AS (
-                                    SELECT role_id
-                                    FROM public.roles
-                                    WHERE data->>'group_id' = %(group_id)s
-                                    LIMIT 1
-                               )
-                               UPDATE public.users u
-                               SET role_id= NULL
-                               FROM r
-                               WHERE u.data->>'user_id' = %(user_id)s
-                                 AND u.role_id = r.role_id
-                                 AND u.tenant_id = %(tenant_id)s
-                                 AND u.deleted_at IS NULL
-                               RETURNING *;""",
-                            {"tenant_id": tenant_id, "group_id": group_id, "user_id": user_id})
-        cur.execute(query=query)
-        row = cur.fetchone()
-
-    return helper.dict_to_camel_case(row)
(File diff suppressed because it is too large.)
ee/api/chalicelib/utils/SAML2_helper.py

@@ -23,20 +23,18 @@ SAML2 = {
         "entityId": config("SITE_URL") + API_PREFIX + "/sso/saml2/metadata/",
         "assertionConsumerService": {
             "url": config("SITE_URL") + API_PREFIX + "/sso/saml2/acs/",
-            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
         },
         "singleLogoutService": {
             "url": config("SITE_URL") + API_PREFIX + "/sso/saml2/sls/",
-            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
         },
         "NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
         "x509cert": config("sp_crt", default=""),
         "privateKey": config("sp_key", default=""),
     },
-    "security": {
-        "requestedAuthnContext": False
-    },
-    "idp": None
+    "security": {"requestedAuthnContext": False},
+    "idp": None,
 }
 
 # in case tenantKey is included in the URL
@@ -50,25 +48,29 @@ if config("SAML2_MD_URL", default=None) is not None and len(config("SAML2_MD_URL
     print("SAML2_MD_URL provided, getting IdP metadata config")
     from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
 
-    idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(config("SAML2_MD_URL", default=None))
+    idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(
+        config("SAML2_MD_URL", default=None)
+    )
     idp = idp_data.get("idp")
 
 if SAML2["idp"] is None:
-    if len(config("idp_entityId", default="")) > 0 \
-            and len(config("idp_sso_url", default="")) > 0 \
-            and len(config("idp_x509cert", default="")) > 0:
+    if (
+        len(config("idp_entityId", default="")) > 0
+        and len(config("idp_sso_url", default="")) > 0
+        and len(config("idp_x509cert", default="")) > 0
+    ):
         idp = {
             "entityId": config("idp_entityId"),
             "singleSignOnService": {
                 "url": config("idp_sso_url"),
-                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
+                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
             },
-            "x509cert": config("idp_x509cert")
+            "x509cert": config("idp_x509cert"),
         }
         if len(config("idp_sls_url", default="")) > 0:
             idp["singleLogoutService"] = {
                 "url": config("idp_sls_url"),
-                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
+                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
             }
 
 if idp is None:
@@ -106,8 +108,8 @@ async def prepare_request(request: Request):
     session = {}
     # If server is behind proxys or balancers use the HTTP_X_FORWARDED fields
     headers = request.headers
-    proto = headers.get('x-forwarded-proto', 'http')
-    url_data = urlparse('%s://%s' % (proto, headers['host']))
+    proto = headers.get("x-forwarded-proto", "http")
+    url_data = urlparse("%s://%s" % (proto, headers["host"]))
     path = request.url.path
     site_url = urlparse(config("SITE_URL"))
     # to support custom port without changing IDP config
@@ -117,21 +119,21 @@ async def prepare_request(request: Request):
 
     # add / to /acs
     if not path.endswith("/"):
-        path = path + '/'
+        path = path + "/"
     if len(API_PREFIX) > 0 and not path.startswith(API_PREFIX):
         path = API_PREFIX + path
 
     return {
-        'https': 'on' if proto == 'https' else 'off',
-        'http_host': request.headers['host'] + host_suffix,
-        'server_port': url_data.port,
-        'script_name': path,
-        'get_data': request.args.copy(),
+        "https": "on" if proto == "https" else "off",
+        "http_host": request.headers["host"] + host_suffix,
+        "server_port": url_data.port,
+        "script_name": path,
+        "get_data": request.args.copy(),
         # Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144
         # 'lowercase_urlencoding': True,
-        'post_data': request.form.copy(),
-        'cookie': {"session": session},
-        'request': request
+        "post_data": request.form.copy(),
+        "cookie": {"session": session},
+        "request": request,
     }
 
 
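The dict built by prepare_request is shaped for python3-saml. A hedged sketch of a caller (OneLogin_Saml2_Auth usage is standard python3-saml; this exact caller is illustrative and not part of the diff):

    from onelogin.saml2.auth import OneLogin_Saml2_Auth


    async def start_sso_login(request):
        # prepare_request adapts the incoming request to the shape
        # python3-saml expects; SAML2 is the settings dict built above.
        req = await prepare_request(request)
        auth = OneLogin_Saml2_Auth(req, SAML2)
        return auth.login()  # URL to redirect the browser to the IdP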
@@ -140,8 +142,11 @@ def is_saml2_available():
 
 
 def get_saml2_provider():
-    return config("idp_name", default="saml2") if is_saml2_available() and len(
-        config("idp_name", default="saml2")) > 0 else None
+    return (
+        config("idp_name", default="saml2")
+        if is_saml2_available() and len(config("idp_name", default="saml2")) > 0
+        else None
+    )
 
 
 def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
@@ -152,11 +157,14 @@ def get_landing_URL(query_params: dict = None, redirect_to_link2=False):
 
     if redirect_to_link2:
         if len(config("sso_landing_override", default="")) == 0:
-            logging.warning("SSO trying to redirect to custom URL, but sso_landing_override env var is empty")
+            logging.warning(
+                "SSO trying to redirect to custom URL, but sso_landing_override env var is empty"
+            )
         else:
             return config("sso_landing_override") + query_params
 
-    return config("SITE_URL") + config("sso_landing", default="/login") + query_params
+    base_url = config("SITE_URLx") if config("LOCAL_DEV") else config("SITE_URL")
+    return base_url + config("sso_landing", default="/login") + query_params
 
 
 environ["hastSAML2"] = str(is_saml2_available())
ee/api/chalicelib/utils/scim_auth.py

@@ -13,23 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
 ALGORITHM = config("SCIM_JWT_ALGORITHM")
 ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
 REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
-AUDIENCE="okta_client"
-ISSUER=config("JWT_ISSUER"),
+AUDIENCE = config("SCIM_AUDIENCE")
+ISSUER = (config("JWT_ISSUER"),)
 
-# Simulated Okta Client Credentials
-# OKTA_CLIENT_ID = "okta-client"
-# OKTA_CLIENT_SECRET = "okta-secret"
-
-# class TokenRequest(BaseModel):
-#     client_id: str
-#     client_secret: str
-
-# async def authenticate_client(token_request: TokenRequest):
-#     """Validate Okta Client Credentials and issue JWT"""
-#     if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
-#         raise HTTPException(status_code=401, detail="Invalid client credentials")
-
-#     return {"access_token": create_jwt(), "token_type": "bearer"}
 
 def create_tokens(tenant_id):
     curr_time = time.time()
@@ -38,7 +24,7 @@ def create_tokens(tenant_id):
         "sub": "scim_server",
         "aud": AUDIENCE,
         "iss": ISSUER,
-        "exp": ""
+        "exp": "",
     }
     access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
     access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
@@ -47,20 +33,26 @@ def create_tokens(tenant_id):
     refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
     refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
 
-    return access_token, refresh_token
+    return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS
 
 
 def verify_access_token(token: str):
     try:
-        payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
+        payload = jwt.decode(
+            token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
+        )
         return payload
     except jwt.ExpiredSignatureError:
         raise HTTPException(status_code=401, detail="Token expired")
     except jwt.InvalidTokenError:
         raise HTTPException(status_code=401, detail="Invalid token")
 
 
 def verify_refresh_token(token: str):
     try:
-        payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
+        payload = jwt.decode(
+            token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE
+        )
         return payload
     except jwt.ExpiredSignatureError:
         raise HTTPException(status_code=401, detail="Token expired")
@@ -68,10 +60,25 @@ def verify_refresh_token(token: str):
         raise HTTPException(status_code=401, detail="Invalid token")
 
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+required_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
 
 # Authentication Dependency
-def auth_required(token: str = Depends(oauth2_scheme)):
+def auth_required(token: str = Depends(required_oauth2_scheme)):
     """Dependency to check Authorization header."""
     if config("SCIM_AUTH_TYPE") == "OAuth2":
         payload = verify_access_token(token)
         return payload["tenant_id"]
+
+
+optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
+
+
+def auth_optional(token: str | None = Depends(optional_oauth2_scheme)):
+    if token is None:
+        return None
+    try:
+        tenant_id = auth_required(token)
+        return tenant_id
+    except HTTPException:
+        return None
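Taken together, the helpers above implement a small JWT round trip. A sketch of how they compose (assuming the SCIM_* settings are configured; the tenant id is illustrative):

    # create_tokens now also returns the access-token lifetime, which the
    # /token endpoint reports to clients as expires_in.
    access_token, refresh_token, expires_in = create_tokens(tenant_id=1)

    # Each SCIM request presents the access token; auth_required decodes
    # it and hands the tenant id to the route.
    payload = verify_access_token(access_token)
    tenant_id = payload["tenant_id"]

    # Once the access token expires, the refresh token is exchanged for a
    # new pair (grant_type=refresh_token on the /token endpoint below).
    verify_refresh_token(refresh_token)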
ee/api/routers/scim.py (deleted)

@@ -1,466 +0,0 @@
import logging
import re
import uuid
from typing import Optional

from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field

import schemas
from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token
from routers.base import get_routers


logger = logging.getLogger(__name__)

public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")


"""Authentication endpoints"""

class RefreshRequest(BaseModel):
    refresh_token: str

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Login endpoint to generate tokens
@public_app.post("/token")
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
    subdomain = host.split(".")[0]

    # Missing authentication part, to add
    if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    subdomain = "Openreplay EE"
    tenant = tenants.get_by_name(subdomain)
    access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])

    return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}

# Refresh token endpoint
@public_app.post("/refresh")
async def refresh_token(r: RefreshRequest):

    payload = verify_refresh_token(r.refresh_token)
    new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])

    return {"access_token": new_access_token, "token_type": "Bearer"}

"""
User endpoints
"""

class Name(BaseModel):
    givenName: str
    familyName: str

class Email(BaseModel):
    primary: bool
    value: str
    type: str

class UserRequest(BaseModel):
    schemas: list[str]
    userName: str
    name: Name
    emails: list[Email]
    displayName: str
    locale: str
    externalId: str
    groups: list[dict]
    password: str = Field(default=None)
    active: bool

class UserResponse(BaseModel):
    schemas: list[str]
    id: str
    userName: str
    name: Name
    emails: list[Email]  # ignore for now
    displayName: str
    locale: str
    externalId: str
    active: bool
    groups: list[dict]
    meta: dict = Field(default={"resourceType": "User"})

class PatchUserRequest(BaseModel):
    schemas: list[str]
    Operations: list[dict]


@public_app.get("/Users", dependencies=[Depends(auth_required)])
async def get_users(
    start_index: int = Query(1, alias="startIndex"),
    count: Optional[int] = Query(None, alias="count"),
    email: Optional[str] = Query(None, alias="filter"),
):
    """Get SCIM Users"""
    if email:
        email = email.split(" ")[2].strip('"')
    result_users = users.get_users_paginated(start_index, count, email)

    serialized_users = []
    for user in result_users:
        serialized_users.append(
            UserResponse(
                schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
                id = user["data"]["userId"],
                userName = user["email"],
                name = Name.model_validate(user["data"]["name"]),
                emails = [Email.model_validate(user["data"]["emails"])],
                displayName = user["name"],
                locale = user["data"]["locale"],
                externalId = user["internalId"],
                active = True,  # ignore for now, since, can't insert actual timestamp
                groups = [],  # ignore
            ).model_dump(mode='json')
        )
    return JSONResponse(
        status_code=200,
        content={
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
            "totalResults": len(serialized_users),
            "startIndex": start_index,
            "itemsPerPage": len(serialized_users),
            "Resources": serialized_users,
        },
    )

@public_app.get("/Users/{user_id}", dependencies=[Depends(auth_required)])
def get_user(user_id: str):
    """Get SCIM User"""
    tenant_id = 1
    user = users.get_by_uuid(user_id, tenant_id)
    if not user:
        return JSONResponse(
            status_code=404,
            content={
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
                "detail": "User not found",
                "status": 404,
            }
        )

    res = UserResponse(
        schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
        id = user["data"]["userId"],
        userName = user["email"],
        name = Name.model_validate(user["data"]["name"]),
        emails = [Email.model_validate(user["data"]["emails"])],
        displayName = user["name"],
        locale = user["data"]["locale"],
        externalId = user["internalId"],
        active = True,  # ignore for now, since, can't insert actual timestamp
        groups = [],  # ignore
    )
    return JSONResponse(status_code=201, content=res.model_dump(mode='json'))


@public_app.post("/Users", dependencies=[Depends(auth_required)])
async def create_user(r: UserRequest):
    """Create SCIM User"""
    tenant_id = 1
    existing_user = users.get_by_email_only(r.userName)
    deleted_user = users.get_deleted_user_by_email(r.userName)

    if existing_user:
        return JSONResponse(
            status_code = 409,
            content = {
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
                "detail": "User already exists in the database.",
                "status": 409,
            }
        )
    elif deleted_user:
        user_id = users.get_deleted_by_uuid(deleted_user["data"]["userId"], tenant_id)
        user = users.restore_scim_user(user_id=user_id["userId"], tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
                                       display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
                                       origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
    else:
        try:
            user = users.create_scim_user(tenant_id=tenant_id, user_uuid=uuid.uuid4().hex, email=r.emails[0].value, admin=False,
                                          display_name=r.displayName, full_name=r.name.model_dump(mode='json'), emails=r.emails[0].model_dump(mode='json'),
                                          origin="okta", locale=r.locale, role_id=None, internal_id=r.externalId)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    res = UserResponse(
        schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
        id = user["data"]["userId"],
        userName = r.userName,
        name = r.name,
        emails = r.emails,
        displayName = r.displayName,
        locale = r.locale,
        externalId = r.externalId,
        active = r.active,  # ignore for now, since, can't insert actual timestamp
        groups = [],  # ignore
    )
    return JSONResponse(status_code=201, content=res.model_dump(mode='json'))


@public_app.put("/Users/{user_id}", dependencies=[Depends(auth_required)])
def update_user(user_id: str, r: UserRequest):
    """Update SCIM User"""
    tenant_id = 1
    user = users.get_by_uuid(user_id, tenant_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    changes = r.model_dump(mode='json', exclude={"schemas", "emails", "name", "locale", "groups", "password", "active"})  # some of these should be added later if necessary
    nested_changes = r.model_dump(mode='json', include={"name", "emails"})
    mapping = {"userName": "email", "displayName": "name", "externalId": "internal_id"}  # mapping between scim schema field names and local database model, can be done as config?
    for k, v in mapping.items():
        if k in changes:
            changes[v] = changes.pop(k)
    changes["data"] = {}
    for k, v in nested_changes.items():
        value_to_insert = v[0] if k == "emails" else v
        changes["data"][k] = value_to_insert
    try:
        users.update(tenant_id, user["userId"], changes)
        res = UserResponse(
            schemas = ["urn:ietf:params:scim:schemas:core:2.0:User"],
            id = user["data"]["userId"],
            userName = r.userName,
            name = r.name,
            emails = r.emails,
            displayName = r.displayName,
            locale = r.locale,
            externalId = r.externalId,
            active = r.active,  # ignore for now, since, can't insert actual timestamp
            groups = [],  # ignore
        )

        return JSONResponse(status_code=201, content=res.model_dump(mode='json'))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@public_app.patch("/Users/{user_id}", dependencies=[Depends(auth_required)])
def deactivate_user(user_id: str, r: PatchUserRequest):
    """Deactivate user, soft-delete"""
    tenant_id = 1
    active = r.model_dump(mode='json')["Operations"][0]["value"]["active"]
    if active:
        raise HTTPException(status_code=404, detail="Activating user is not supported")
    user = users.get_by_uuid(user_id, tenant_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    users.delete_member_as_admin(tenant_id, user["userId"])

    return Response(status_code=204, content="")

@public_app.delete("/Users/{user_uuid}", dependencies=[Depends(auth_required)])
def delete_user(user_uuid: str):
    """Delete user from database, hard-delete"""
    tenant_id = 1
    user = users.get_by_uuid(user_uuid, tenant_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    users.__hard_delete_user_uuid(user_uuid)
    return Response(status_code=204, content="")


"""
Group endpoints
"""

class Operation(BaseModel):
    op: str
    path: str = Field(default=None)
    value: list[dict] | dict = Field(default=None)

class GroupGetResponse(BaseModel):
    schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:ListResponse"])
    totalResults: int
    startIndex: int
    itemsPerPage: int
    resources: list = Field(alias="Resources")

class GroupRequest(BaseModel):
    schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
    displayName: str = Field(default=None)
    members: list = Field(default=None)
    operations: list[Operation] = Field(default=None, alias="Operations")

class GroupPatchRequest(BaseModel):
    schemas: list[str] = Field(default=["urn:ietf:params:scim:api:messages:2.0:PatchOp"])
    operations: list[Operation] = Field(alias="Operations")

class GroupResponse(BaseModel):
    schemas: list[str] = Field(default=["urn:ietf:params:scim:schemas:core:2.0:Group"])
    id: str
    displayName: str
    members: list
    meta: dict = Field(default={"resourceType": "Group"})


@public_app.get("/Groups", dependencies=[Depends(auth_required)])
def get_groups(
    start_index: int = Query(1, alias="startIndex"),
    count: Optional[int] = Query(None, alias="count"),
    group_name: Optional[str] = Query(None, alias="filter"),
):
    """Get groups"""
    tenant_id = 1
    res = []
    if group_name:
        group_name = group_name.split(" ")[2].strip('"')

    groups = roles.get_roles_with_uuid_paginated(tenant_id, start_index, count, group_name)
    res = [{
        "id": group["data"]["groupId"],
        "meta": {
            "created": group["createdAt"],
            "lastModified": "",  # not currently a field
            "version": "v1.0"
        },
        "displayName": group["name"]
    } for group in groups
    ]
    return JSONResponse(
        status_code=200,
        content=GroupGetResponse(
            totalResults=len(groups),
            startIndex=start_index,
            itemsPerPage=len(groups),
            Resources=res
        ).model_dump(mode='json'))

@public_app.get("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def get_group(group_id: str):
    """Get a group by id"""
    tenant_id = 1
    group = roles.get_role_by_group_id(tenant_id, group_id)
    if not group:
        raise HTTPException(status_code=404, detail="Group not found")
    members = roles.get_users_by_group_uuid(tenant_id, group["data"]["groupId"])
    members = [{"value": member["data"]["userId"], "display": member["name"]} for member in members]

    return JSONResponse(
        status_code=200,
        content=GroupResponse(
            id=group["data"]["groupId"],
            displayName=group["name"],
            members=members,
        ).model_dump(mode='json'))

@public_app.post("/Groups", dependencies=[Depends(auth_required)])
def create_group(r: GroupRequest):
    """Create a group"""
    tenant_id = 1
    member_role = roles.get_member_permissions(tenant_id)
    try:
        data = schemas.RolePayloadSchema(name=r.displayName, permissions=member_role["permissions"])  # permissions by default are same as for member role
        group = roles.create_as_admin(tenant_id, uuid.uuid4().hex, data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    added_members = []
    for member in r.members:
        user = users.get_by_uuid(member["value"], tenant_id)
        if user:
            users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
            added_members.append({
                "value": user["data"]["userId"],
                "display": user["name"]
            })

    return JSONResponse(
        status_code=200,
        content=GroupResponse(
            id=group["data"]["groupId"],
            displayName=group["name"],
            members=added_members,
        ).model_dump(mode='json'))


@public_app.put("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_put_group(group_id: str, r: GroupRequest):
    """Update a group or members of the group (not used by anything yet)"""
    tenant_id = 1
    group = roles.get_role_by_group_id(tenant_id, group_id)
    if not group:
        raise HTTPException(status_code=404, detail="Group not found")

    if r.operations and r.operations[0].op == "replace" and r.operations[0].path is None:
        roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
        return Response(status_code=200, content="")

    members = r.members
    modified_members = []
    for member in members:
        user = users.get_by_uuid(member["value"], tenant_id)
        if user:
            users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
            modified_members.append({
                "value": user["data"]["userId"],
                "display": user["name"]
            })

    return JSONResponse(
        status_code=200,
        content=GroupResponse(
            id=group_id,
            displayName=group["name"],
            members=modified_members,
        ).model_dump(mode='json'))


@public_app.patch("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def update_patch_group(group_id: str, r: GroupPatchRequest):
    """Update a group or members of the group, used by AIW"""
    tenant_id = 1
    group = roles.get_role_by_group_id(tenant_id, group_id)
    if not group:
        raise HTTPException(status_code=404, detail="Group not found")
    if r.operations[0].op == "replace" and r.operations[0].path is None:
        roles.update_group_name(tenant_id, group["data"]["groupId"], r.operations[0].value["displayName"])
        return Response(status_code=200, content="")

    modified_members = []
    for op in r.operations:
        if op.op == "add" or op.op == "replace":
            # Both methods work as "replace"
            for u in op.value:
                user = users.get_by_uuid(u["value"], tenant_id)
                if user:
                    users.update(tenant_id, user["userId"], {"role_id": group["roleId"]})
                    modified_members.append({
                        "value": user["data"]["userId"],
                        "display": user["name"]
                    })
        elif op.op == "remove":
            user_id = re.search(r'\[value eq \"([a-f0-9]+)\"\]', op.path).group(1)
            roles.remove_group_membership(tenant_id, group_id, user_id)
    return JSONResponse(
        status_code=200,
        content=GroupResponse(
            id=group_id,
            displayName=group["name"],
            members=modified_members,
        ).model_dump(mode='json'))


@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def delete_group(group_id: str):
    """Delete a group, hard-delete"""
    # possibly need to set the user's roles to default member role, instead of null
    tenant_id = 1
    group = roles.get_role_by_group_id(tenant_id, group_id)
    if not group:
        raise HTTPException(status_code=404, detail="Group not found")
    roles.delete_scim_group(tenant_id, group["data"]["groupId"])

    return Response(status_code=200, content="")
ee/api/routers/scim/__init__.py (new, empty file)

ee/api/routers/scim/api.py (new file, 164 lines)
@@ -0,0 +1,164 @@
from scim2_server import utils


from routers.base import get_routers
from routers.scim.providers import MultiTenantProvider
from routers.scim.backends import PostgresBackend
from routers.scim.postgres_resource import PostgresResource
from routers.scim import users, groups, helpers
from urllib.parse import urlencode

from chalicelib.utils import pg_client

from fastapi import HTTPException, Request
from fastapi.responses import RedirectResponse
from chalicelib.utils.scim_auth import (
    create_tokens,
    verify_refresh_token,
)


b = PostgresBackend()
b.register_postgres_resource(
    "User",
    PostgresResource(
        query_resources=users.query_resources,
        get_resource=users.get_resource,
        create_resource=users.create_resource,
        search_existing=users.search_existing,
        restore_resource=users.restore_resource,
        delete_resource=users.delete_resource,
        update_resource=users.update_resource,
    ),
)
b.register_postgres_resource(
    "Group",
    PostgresResource(
        query_resources=groups.query_resources,
        get_resource=groups.get_resource,
        create_resource=groups.create_resource,
        search_existing=groups.search_existing,
        restore_resource=None,
        delete_resource=groups.delete_resource,
        update_resource=groups.update_resource,
    ),
)

scim_app = MultiTenantProvider(b)

for schema in utils.load_default_schemas().values():
    scim_app.register_schema(schema)
for schema in helpers.load_custom_schemas().values():
    scim_app.register_schema(schema)
for resource_type in helpers.load_custom_resource_types().values():
    scim_app.register_resource_type(resource_type)


public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")


@public_app.post("/token")
async def post_token(r: Request):
    form = await r.form()

    client_id = form.get("client_id")
    client_secret = form.get("client_secret")
    with pg_client.PostgresClient() as cur:
        try:
            cur.execute(
                cur.mogrify(
                    """
                    SELECT tenant_id
                    FROM public.tenants
                    WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
                    """,
                    {"tenant_id": int(client_id), "tenant_key": client_secret},
                )
            )
        except ValueError:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        tenant = cur.fetchone()
        if not tenant:
            raise HTTPException(status_code=401, detail="Invalid credentials")

    grant_type = form.get("grant_type")
    if grant_type == "refresh_token":
        refresh_token = form.get("refresh_token")
        verify_refresh_token(refresh_token)
    else:
        code = form.get("code")
        with pg_client.PostgresClient() as cur:
            cur.execute(
                cur.mogrify(
                    """
                    SELECT *
                    FROM public.scim_auth_codes
                    WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
                    """,
                    {"auth_code": code, "tenant_id": int(client_id)},
                )
            )
            if cur.fetchone() is None:
                raise HTTPException(
                    status_code=401, detail="Invalid code/client_id pair"
                )
            cur.execute(
                cur.mogrify(
                    """
                    UPDATE public.scim_auth_codes
                    SET used=TRUE
                    WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
                    """,
                    {"auth_code": code, "tenant_id": int(client_id)},
                )
            )

    access_token, refresh_token, expires_in = create_tokens(
        tenant_id=tenant["tenant_id"]
    )

    return {
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": expires_in,
        "refresh_token": refresh_token,
    }
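From a SCIM client's point of view, this endpoint is a small OAuth2 token exchange keyed on the tenant: client_id is the tenant_id and client_secret the tenant_key checked against public.tenants above. A hedged sketch of driving it (httpx, the base URL, and the placeholder values are assumptions for illustration):

    import httpx

    BASE = "https://openreplay.example.com/api/sso/scim/v2"  # hypothetical

    resp = httpx.post(
        BASE + "/token",
        data={
            "grant_type": "authorization_code",
            "client_id": "1",                      # tenant_id
            "client_secret": "<tenant_key>",       # from public.tenants
            "code": "<code from GET /authorize>",  # single-use auth code
        },
    )
    tokens = resp.json()

    # The bearer token then authorizes SCIM resource calls.
    users_resp = httpx.get(
        BASE + "/Users",
        headers={"Authorization": f"Bearer {tokens['access_token']}"},
    )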
# note(jon): this might be specific to okta. if so, we should probably specify that in the endpoint
@public_app.get("/authorize")
async def get_authorize(
    r: Request,
    response_type: str,
    client_id: str,
    redirect_uri: str,
    state: str | None = None,
):
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                UPDATE public.scim_auth_codes
                SET used=TRUE
                WHERE tenant_id=%(tenant_id)s
                """,
                {"tenant_id": int(client_id)},
            )
        )
        cur.execute(
            cur.mogrify(
                """
                INSERT INTO public.scim_auth_codes (tenant_id)
                VALUES (%(tenant_id)s)
                RETURNING auth_code
                """,
                {"tenant_id": int(client_id)},
            )
        )
        code = cur.fetchone()["auth_code"]
    params = {"code": code}
    if state:
        params["state"] = state
    url = f"{redirect_uri}?{urlencode(params)}"
    return RedirectResponse(url)
ee/api/routers/scim/backends.py (new file, 203 lines)
@@ -0,0 +1,203 @@
from scim2_server import backend
from scim2_server.filter import evaluate_filter
from scim2_server.utils import SCIMException

from scim2_models import (
    SearchRequest,
    Resource,
    Context,
    Error,
)
from scim2_filter_parser import lexer
from scim2_filter_parser.parser import SCIMParser
from routers.scim.postgres_resource import PostgresResource
from scim2_server.operators import ResolveSortOperator
import operator


class PostgresBackend(backend.Backend):
    def __init__(self):
        super().__init__()
        self._postgres_resources = {}

    def register_postgres_resource(
        self, resource_type_id: str, postgres_resource: PostgresResource
    ):
        self._postgres_resources[resource_type_id] = postgres_resource

    def query_resources(
        self,
        search_request: SearchRequest,
        tenant_id: int,
        resource_type_id: str | None = None,
    ) -> tuple[int, list[Resource]]:
        """Query the backend for a set of resources.

        :param search_request: SearchRequest instance describing the
            query.
        :param resource_type_id: ID of the resource type to query. If
            None, all resource types are queried.
        :return: A tuple of "total results" and a List of found
            Resources. The List must contain a copy of resources.
            Mutating elements in the List must not modify the data
            stored in the backend.
        :raises SCIMException: If the backend only supports querying for
            one resource type at a time, setting resource_type_id to
            None the backend may raise a
            SCIMException(Error.make_too_many_error()).
        """
        start_index = (search_request.start_index or 1) - 1

        tree = None
        if search_request.filter is not None:
            token_stream = lexer.SCIMLexer().tokenize(search_request.filter)
            tree = SCIMParser().parse(token_stream)

        # todo(jon): handle the case when resource_type_id is None.
        # we're assuming it's never None for now.
        # but, this is fine to leave as it doesn't seem to be used or reached
        # in any of my tests yet.
        if not resource_type_id:
            raise NotImplementedError

        resources = self._postgres_resources[resource_type_id].query_resources(
            tenant_id
        )
        model = self.get_model(resource_type_id)
        resources = [
            model.model_validate(r, scim_ctx=Context.RESOURCE_QUERY_RESPONSE)
            for r in resources
        ]
        resources = [r for r in resources if (tree is None or evaluate_filter(r, tree))]

        if search_request.sort_by is not None:
            descending = search_request.sort_order == SearchRequest.SortOrder.descending
            sort_operator = ResolveSortOperator(search_request.sort_by)

            # To ensure that unset attributes are sorted last (when ascending, as defined in the RFC),
            # we have to divide the result set into a set and unset subset.
            unset_values = []
            set_values = []
            for resource in resources:
                result = sort_operator(resource)
                if result is None:
                    unset_values.append(resource)
                else:
                    set_values.append((resource, result))

            set_values.sort(key=operator.itemgetter(1), reverse=descending)
            set_values = [value[0] for value in set_values]
            if descending:
                resources = unset_values + set_values
            else:
                resources = set_values + unset_values

        found_resources = resources[start_index:]
        if search_request.count is not None:
found_resources = resources[: search_request.count]
|
||||
|
||||
return len(resources), found_resources
|
||||
|
||||
def get_resource(
|
||||
self, tenant_id: int, resource_type_id: str, object_id: str
|
||||
) -> Resource | None:
|
||||
"""Query the backend for a resources by its ID.
|
||||
|
||||
:param resource_type_id: ID of the resource type to get the
|
||||
object from.
|
||||
:param object_id: ID of the object to get.
|
||||
:return: The resource object if it exists, None otherwise. The
|
||||
resource must be a copy, modifying it must not change the
|
||||
data stored in the backend.
|
||||
"""
|
||||
resource = self._postgres_resources[resource_type_id].get_resource(
|
||||
object_id, tenant_id
|
||||
)
|
||||
if resource:
|
||||
model = self.get_model(resource_type_id)
|
||||
resource = model.model_validate(resource)
|
||||
return resource
|
||||
|
||||
def delete_resource(
|
||||
self, tenant_id: int, resource_type_id: str, object_id: str
|
||||
) -> bool:
|
||||
"""Delete a resource.
|
||||
|
||||
:param resource_type_id: ID of the resource type to delete the
|
||||
object from.
|
||||
:param object_id: ID of the object to delete.
|
||||
:return: True if the resource was deleted, False otherwise.
|
||||
"""
|
||||
resource = self.get_resource(tenant_id, resource_type_id, object_id)
|
||||
if resource:
|
||||
self._postgres_resources[resource_type_id].delete_resource(
|
||||
object_id, tenant_id
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def create_resource(
|
||||
self, tenant_id: int, resource_type_id: str, resource: Resource
|
||||
) -> Resource | None:
|
||||
"""Create a resource.
|
||||
|
||||
:param resource_type_id: ID of the resource type to create.
|
||||
:param resource: Resource to create.
|
||||
:return: The created resource. Creation should set system-
|
||||
defined attributes (ID, Metadata). May be the same object
|
||||
that is passed in.
|
||||
"""
|
||||
model = self.get_model(resource_type_id)
|
||||
existing = self._postgres_resources[resource_type_id].search_existing(
|
||||
tenant_id, resource
|
||||
)
|
||||
if existing:
|
||||
existing = model.model_validate(existing)
|
||||
if existing.active:
|
||||
raise SCIMException(Error.make_uniqueness_error())
|
||||
resource = self._postgres_resources[resource_type_id].restore_resource(
|
||||
tenant_id, resource
|
||||
)
|
||||
else:
|
||||
resource = self._postgres_resources[resource_type_id].create_resource(
|
||||
tenant_id, resource
|
||||
)
|
||||
resource = model.model_validate(resource)
|
||||
return resource
|
||||
|
||||
def update_resource(
|
||||
self, tenant_id: int, resource_type_id: str, resource: Resource
|
||||
) -> Resource | None:
|
||||
"""Update a resource. The resource is identified by its ID.
|
||||
|
||||
:param resource_type_id: ID of the resource type to update.
|
||||
:param resource: Resource to update.
|
||||
:return: The updated resource. Updating should update the
|
||||
"meta.lastModified" data. May be the same object that is
|
||||
passed in.
|
||||
"""
|
||||
model = self.get_model(resource_type_id)
|
||||
existing = self._postgres_resources[resource_type_id].search_existing(
|
||||
tenant_id, resource
|
||||
)
|
||||
if existing:
|
||||
existing = model.model_validate(existing)
|
||||
if existing.active:
|
||||
if existing.id != resource.id:
|
||||
raise SCIMException(Error.make_uniqueness_error())
|
||||
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||
tenant_id, resource
|
||||
)
|
||||
else:
|
||||
self._postgres_resources[resource_type_id].delete_resource(
|
||||
existing.id, tenant_id
|
||||
)
|
||||
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||
resource.id, tenant_id, resource
|
||||
)
|
||||
else:
|
||||
resource = self._postgres_resources[resource_type_id].update_resource(
|
||||
tenant_id, resource
|
||||
)
|
||||
resource = model.model_validate(resource)
|
||||
return resource
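
To illustrate the query path, a minimal sketch of driving query_resources with a filtered, sorted SearchRequest. The scim_backend instance and tenant id are assumptions (see the wiring sketch after postgres_resource.py below), and the SearchRequest keyword names are assumed to match scim2_models' snake_case attributes used above:

# Hypothetical usage sketch for PostgresBackend.query_resources.
from scim2_models import SearchRequest

search = SearchRequest(
    filter='active eq true',
    sort_by="userName",
    start_index=1,
    count=50,
)
total, users = scim_backend.query_resources(
    search_request=search, tenant_id=1, resource_type_id="User"
)
# total counts all matches; users holds at most `count` validated models.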
36
ee/api/routers/scim/fixtures/custom_resource_types.json
Normal file
@ -0,0 +1,36 @@
[{
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
    "id": "User",
    "name": "User",
    "endpoint": "/Users",
    "description": "User Account",
    "schema": "urn:ietf:params:scim:schemas:core:2.0:User",
    "schemaExtensions": [
        {
            "schema": "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
            "required": true
        },
        {
            "schema": "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
            "required": true
        }
    ],
    "meta": {
        "location": "/v2/ResourceTypes/User",
        "resourceType": "ResourceType"
    }
},
{
    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
    "id": "Group",
    "name": "Group",
    "endpoint": "/Groups",
    "description": "Group",
    "schema": "urn:ietf:params:scim:schemas:core:2.0:Group",
    "meta": {
        "location": "/v2/ResourceTypes/Group",
        "resourceType": "ResourceType"
    }
}]
32
ee/api/routers/scim/fixtures/custom_schemas.json
Normal file
@ -0,0 +1,32 @@
[
    {
        "id": "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
        "name": "OpenreplayUser",
        "description": "Openreplay User Account Extension",
        "attributes": [
            {
                "name": "permissions",
                "type": "string",
                "multiValued": true,
                "description": "A list of permissions for the User that represent a thing the User is capable of doing.",
                "required": false,
                "canonicalValues": ["SESSION_REPLAY", "DEV_TOOLS", "METRICS", "ASSIST_LIVE", "ASSIST_CALL", "SPOT", "SPOT_PUBLIC"],
                "mutability": "readWrite",
                "returned": "default"
            },
            {
                "name": "projectKeys",
                "type": "string",
                "multiValued": true,
                "description": "A list of project keys for the User that represent a project the User is allowed to work on.",
                "required": false,
                "mutability": "readWrite",
                "returned": "default"
            }
        ],
        "meta": {
            "resourceType": "Schema",
            "location": "/v2/Schemas/urn:ietf:params:scim:schemas:extension:openreplay:2.0:User"
        }
    }
]
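
To make the extension concrete, a hypothetical User payload carrying both custom attributes (the e-mail address and project key are invented for illustration):

# Illustrative payload only: a SCIM User using the OpenReplay extension.
example_user = {
    "schemas": [
        "urn:ietf:params:scim:schemas:core:2.0:User",
        "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
    ],
    "userName": "jane@example.com",  # assumed value
    "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User": {
        "permissions": ["SESSION_REPLAY", "METRICS"],
        "projectKeys": ["demo-project-key"],  # assumed key
    },
}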
203
ee/api/routers/scim/groups.py
Normal file
@ -0,0 +1,203 @@
from typing import Any
from datetime import datetime
from psycopg2.extensions import AsIs

from chalicelib.utils import pg_client
from routers.scim import helpers

from scim2_models import Error, Resource
from scim2_server.utils import SCIMException


def convert_provider_resource_to_client_resource(
    provider_resource: dict,
) -> dict:
    members = provider_resource["users"] or []
    return {
        "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
        "id": str(provider_resource["role_id"]),
        "meta": {
            "resourceType": "Group",
            "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
            "lastModified": provider_resource["updated_at"].strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            ),
        },
        "displayName": provider_resource["name"],
        "members": [
            {
                "value": str(member["user_id"]),
                "$ref": f"Users/{member['user_id']}",
                "type": "User",
            }
            for member in members
        ],
    }


def query_resources(tenant_id: int) -> list[dict]:
    query = _main_select_query(tenant_id)
    with pg_client.PostgresClient() as cur:
        cur.execute(query)
        items = cur.fetchall()
    return [convert_provider_resource_to_client_resource(item) for item in items]


def get_resource(resource_id: str, tenant_id: int) -> dict | None:
    query = _main_select_query(tenant_id, resource_id)
    with pg_client.PostgresClient() as cur:
        cur.execute(query)
        item = cur.fetchone()
    if item:
        return convert_provider_resource_to_client_resource(item)
    return None


def delete_resource(resource_id: str, tenant_id: int) -> None:
    _update_resource_sql(
        resource_id=resource_id,
        tenant_id=tenant_id,
        deleted_at=datetime.now(),
    )


def search_existing(tenant_id: int, resource: Resource) -> dict | None:
    return None


def create_resource(tenant_id: int, resource: Resource) -> dict:
    with pg_client.PostgresClient() as cur:
        user_ids = (
            [int(x.value) for x in resource.members] if resource.members else None
        )
        user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
        try:
            cur.execute(
                cur.mogrify(
                    """
                    INSERT INTO public.roles (
                        name,
                        tenant_id
                    )
                    VALUES (
                        %(name)s,
                        %(tenant_id)s
                    )
                    RETURNING role_id
                    """,
                    {
                        "name": resource.display_name,
                        "tenant_id": tenant_id,
                    },
                )
            )
        except Exception:
            raise SCIMException(Error.make_invalid_value_error())
        role_id = cur.fetchone()["role_id"]
        cur.execute(
            f"""
            UPDATE public.users
            SET
                updated_at = now(),
                role_id = {role_id}
            WHERE users.user_id = ANY({user_id_clause})
            """
        )
        cur.execute(f"{_main_select_query(tenant_id, role_id)} LIMIT 1")
        item = cur.fetchone()
    return convert_provider_resource_to_client_resource(item)


def update_resource(tenant_id: int, resource: Resource) -> dict | None:
    item = _update_resource_sql(
        resource_id=resource.id,
        tenant_id=tenant_id,
        name=resource.display_name,
        user_ids=[int(x.value) for x in resource.members],
        deleted_at=None,
    )
    return convert_provider_resource_to_client_resource(item)


def _main_select_query(tenant_id: int, resource_id: str | None = None) -> str:
    where_and_clauses = [
        f"roles.tenant_id = {tenant_id}",
        "roles.deleted_at IS NULL",
    ]
    if resource_id is not None:
        where_and_clauses.append(f"roles.role_id = {resource_id}")
    where_clause = " AND ".join(where_and_clauses)
    return f"""
        SELECT
            roles.*,
            COALESCE(
                (
                    SELECT json_agg(users)
                    FROM public.users
                    WHERE users.role_id = roles.role_id
                ),
                '[]'
            ) AS users,
            COALESCE(
                (
                    SELECT json_agg(projects.project_key)
                    FROM public.projects
                    LEFT JOIN public.roles_projects USING (project_id)
                    WHERE roles_projects.role_id = roles.role_id
                ),
                '[]'
            ) AS project_keys
        FROM public.roles
        WHERE {where_clause}
        """


def _update_resource_sql(
    resource_id: int,
    tenant_id: int,
    user_ids: list[int] | None = None,
    **kwargs: dict[str, Any],
) -> dict[str, Any]:
    with pg_client.PostgresClient() as cur:
        kwargs["updated_at"] = datetime.now()
        set_fragments = [
            cur.mogrify("%s = %s", (AsIs(k), v)).decode("utf-8")
            for k, v in kwargs.items()
        ]
        set_clause = ", ".join(set_fragments)
        user_id_clause = helpers.safe_mogrify_array(user_ids, "int", cur)
        cur.execute(
            f"""
            UPDATE public.users
            SET
                updated_at = now(),
                role_id = NULL
            WHERE
                users.role_id = {resource_id}
                AND users.user_id != ALL({user_id_clause})
            RETURNING *
            """
        )
        cur.execute(
            f"""
            UPDATE public.users
            SET
                updated_at = now(),
                role_id = {resource_id}
            WHERE
                (users.role_id != {resource_id} OR users.role_id IS NULL)
                AND users.user_id = ANY({user_id_clause})
            RETURNING *
            """
        )
        cur.execute(
            f"""
            UPDATE public.roles
            SET {set_clause}
            WHERE
                roles.role_id = {resource_id}
                AND roles.tenant_id = {tenant_id}
            """
        )
        cur.execute(f"{_main_select_query(tenant_id, resource_id)} LIMIT 1")
        return cur.fetchone()
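
For orientation, a sketch of the row-to-Group mapping this module performs (all row values are invented; the users entry stands in for the rows aggregated by json_agg(users)):

# Illustrative only: a roles row and the SCIM Group dict it maps to.
from datetime import datetime

row = {
    "role_id": 7,
    "name": "Engineering",          # assumed role name
    "created_at": datetime(2024, 1, 1),
    "updated_at": datetime(2024, 1, 2),
    "users": [{"user_id": 42}],
}
group = convert_provider_resource_to_client_resource(row)
# group["members"] == [{"value": "42", "$ref": "Users/42", "type": "User"}]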
44
ee/api/routers/scim/helpers.py
Normal file
@ -0,0 +1,44 @@
from typing import Any, Literal
from chalicelib.utils import pg_client
from scim2_models import Schema, Resource, ResourceType
import os
import json


def safe_mogrify_array(
    items: list[Any] | None,
    array_type: Literal["varchar", "int"],
    cursor: pg_client.PostgresClient,
) -> str:
    items = items or []
    fragments = [cursor.mogrify("%s", (item,)).decode("utf-8") for item in items]
    result = f"ARRAY[{', '.join(fragments)}]::{array_type}[]"
    return result


def load_json_resource(json_name: str) -> dict:
    with open(json_name) as f:
        return json.load(f)


def load_scim_resource(
    json_name: str, type_: type[Resource]
) -> dict[str, type[Resource]]:
    ret = {}
    definitions = load_json_resource(json_name)
    for d in definitions:
        model = type_.model_validate(d)
        ret[model.id] = model
    return ret


def load_custom_schemas() -> dict[str, Schema]:
    json_name = os.path.join("routers", "scim", "fixtures", "custom_schemas.json")
    return load_scim_resource(json_name, Schema)


def load_custom_resource_types() -> dict[str, ResourceType]:
    json_name = os.path.join(
        "routers", "scim", "fixtures", "custom_resource_types.json"
    )
    return load_scim_resource(json_name, ResourceType)
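
A quick sketch of what safe_mogrify_array produces (any open pg_client.PostgresClient works as the cursor; the values are examples):

# Illustrative only: building typed SQL array literals for inlining.
with pg_client.PostgresClient() as cur:
    clause = safe_mogrify_array([1, 2, 3], "int", cur)
    # clause == "ARRAY[1, 2, 3]::int[]"
    empty = safe_mogrify_array(None, "varchar", cur)
    # empty == "ARRAY[]::varchar[]"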
14
ee/api/routers/scim/postgres_resource.py
Normal file
@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Callable
from scim2_models import Resource


@dataclass
class PostgresResource:
    query_resources: Callable[[int], list[dict]]
    get_resource: Callable[[str, int], dict | None]
    create_resource: Callable[[int, Resource], dict]
    search_existing: Callable[[int, Resource], dict | None]
    restore_resource: Callable[[int, Resource], dict] | None
    delete_resource: Callable[[str, int], None]
    update_resource: Callable[[int, Resource], dict]
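
A minimal sketch of how this dataclass could tie the per-table modules to the backend. How the api module actually wires it may differ; groups has no restore_resource, hence None there:

# Sketch: wiring the users/groups modules into a PostgresBackend.
from routers.scim import users, groups
from routers.scim.backends import PostgresBackend
from routers.scim.postgres_resource import PostgresResource

scim_backend = PostgresBackend()
scim_backend.register_postgres_resource(
    "User",
    PostgresResource(
        query_resources=users.query_resources,
        get_resource=users.get_resource,
        create_resource=users.create_resource,
        search_existing=users.search_existing,
        restore_resource=users.restore_resource,
        delete_resource=users.delete_resource,
        update_resource=users.update_resource,
    ),
)
scim_backend.register_postgres_resource(
    "Group",
    PostgresResource(
        query_resources=groups.query_resources,
        get_resource=groups.get_resource,
        create_resource=groups.create_resource,
        search_existing=groups.search_existing,
        restore_resource=None,  # groups module defines no restore path
        delete_resource=groups.delete_resource,
        update_resource=groups.update_resource,
    ),
)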
280
ee/api/routers/scim/providers.py
Normal file
@ -0,0 +1,280 @@
import traceback
from typing import Union

from scim2_server import provider

from scim2_models import (
    AuthenticationScheme,
    ServiceProviderConfig,
    Patch,
    Bulk,
    Filter,
    Sort,
    ETag,
    Meta,
    ChangePassword,
    Error,
    ResourceType,
    Context,
    ListResponse,
    PatchOp,
)

from werkzeug import Request, Response
from werkzeug.exceptions import HTTPException, NotFound, PreconditionFailed
from pydantic import ValidationError
from werkzeug.routing.exceptions import RequestRedirect
from scim2_server.utils import SCIMException, merge_resources

from chalicelib.utils.scim_auth import verify_access_token


class MultiTenantProvider(provider.SCIMProvider):
    def check_auth(self, request: Request):
        auth = request.headers.get("Authorization")
        token = auth[len("Bearer ") :] if auth and auth.startswith("Bearer ") else None
        if not token:
            # Missing or malformed header: return a 401 instead of a bare None,
            # so callers always get either a tenant_id or a Response.
            return Response(
                "Missing or invalid Authorization header",
                status=401,
                headers={"WWW-Authenticate": 'Bearer realm="login required"'},
            )
        payload = verify_access_token(token)
        tenant_id = payload["tenant_id"]
        return tenant_id

    def get_service_provider_config(self):
        auth_schemes = [
            AuthenticationScheme(
                type="oauthbearertoken",
                name="OAuth Bearer Token",
                description="Authentication scheme using the OAuth Bearer Token Standard. The access token should be sent in the 'Authorization' header using the Bearer schema.",
                spec_uri="https://datatracker.ietf.org/doc/html/rfc6750",
            )
        ]
        return ServiceProviderConfig(
            # todo(jon): write correct documentation uri
            documentation_uri="https://www.example.com/",
            patch=Patch(supported=True),
            bulk=Bulk(supported=False),
            filter=Filter(supported=True, max_results=1000),
            change_password=ChangePassword(supported=False),
            sort=Sort(supported=True),
            etag=ETag(supported=False),
            authentication_schemes=auth_schemes,
            meta=Meta(resource_type="ServiceProviderConfig"),
        )

    def query_resource(
        self, request: Request, tenant_id: int, resource: ResourceType | None
    ):
        search_request = self.build_search_request(request)

        kwargs = {}
        if resource is not None:
            kwargs["resource_type_id"] = resource.id
        total_results, results = self.backend.query_resources(
            search_request=search_request, tenant_id=tenant_id, **kwargs
        )
        for r in results:
            self.adjust_location(request, r)

        resources = [
            s.model_dump(
                scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
                attributes=search_request.attributes,
                excluded_attributes=search_request.excluded_attributes,
            )
            for s in results
        ]

        return ListResponse[Union[tuple(self.backend.get_models())]](  # noqa: UP007
            total_results=total_results,
            items_per_page=search_request.count,
            start_index=search_request.start_index,
            resources=resources,
        )

    def call_resource(
        self, request: Request, resource_endpoint: str, **kwargs
    ) -> Response:
        resource_type = self.backend.get_resource_type_by_endpoint(
            "/" + resource_endpoint
        )
        if not resource_type:
            raise NotFound

        if "tenant_id" not in kwargs:
            raise Exception
        tenant_id = kwargs["tenant_id"]

        match request.method:
            case "GET":
                return self.make_response(
                    self.query_resource(request, tenant_id, resource_type).model_dump(
                        scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
                    )
                )
            case _:  # "POST"
                payload = request.json
                resource = self.backend.get_model(resource_type.id).model_validate(
                    payload, scim_ctx=Context.RESOURCE_CREATION_REQUEST
                )
                created_resource = self.backend.create_resource(
                    tenant_id,
                    resource_type.id,
                    resource,
                )
                self.adjust_location(request, created_resource)
                return self.make_response(
                    created_resource.model_dump(
                        scim_ctx=Context.RESOURCE_CREATION_RESPONSE
                    ),
                    status=201,
                    headers={"Location": created_resource.meta.location},
                )

    def call_single_resource(
        self, request: Request, resource_endpoint: str, resource_id: str, **kwargs
    ) -> Response:
        find_endpoint = "/" + resource_endpoint
        resource_type = self.backend.get_resource_type_by_endpoint(find_endpoint)
        if not resource_type:
            raise NotFound

        if "tenant_id" not in kwargs:
            raise Exception
        tenant_id = kwargs["tenant_id"]

        match request.method:
            case "GET":
                if resource := self.backend.get_resource(
                    tenant_id, resource_type.id, resource_id
                ):
                    if self.continue_etag(request, resource):
                        response_args = self.get_attrs_from_request(request)
                        self.adjust_location(request, resource)
                        return self.make_response(
                            resource.model_dump(
                                scim_ctx=Context.RESOURCE_QUERY_RESPONSE,
                                **response_args,
                            )
                        )
                    else:
                        return self.make_response(None, status=304)
                raise NotFound
            case "DELETE":
                if self.backend.delete_resource(
                    tenant_id, resource_type.id, resource_id
                ):
                    return self.make_response(None, 204)
                else:
                    raise NotFound
            case "PUT":
                response_args = self.get_attrs_from_request(request)
                resource = self.backend.get_resource(
                    tenant_id, resource_type.id, resource_id
                )
                if resource is None:
                    raise NotFound
                if not self.continue_etag(request, resource):
                    raise PreconditionFailed

                updated_attributes = self.backend.get_model(
                    resource_type.id
                ).model_validate(request.json)
                merge_resources(resource, updated_attributes)
                updated = self.backend.update_resource(
                    tenant_id, resource_type.id, resource
                )
                self.adjust_location(request, updated)
                return self.make_response(
                    updated.model_dump(
                        scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
                        **response_args,
                    )
                )
            case _:  # "PATCH"
                payload = request.json
                # MS Entra sometimes passes an "id" attribute
                if "id" in payload:
                    del payload["id"]
                operations = payload.get("Operations", [])
                for operation in operations:
                    if "name" in operation:
                        # MS Entra sometimes passes a "name" attribute
                        del operation["name"]

                patch_operation = PatchOp.model_validate(payload)
                response_args = self.get_attrs_from_request(request)
                resource = self.backend.get_resource(
                    tenant_id, resource_type.id, resource_id
                )
                if resource is None:
                    raise NotFound
                if not self.continue_etag(request, resource):
                    raise PreconditionFailed

                self.apply_patch_operation(resource, patch_operation)
                updated = self.backend.update_resource(
                    tenant_id, resource_type.id, resource
                )

                if response_args:
                    self.adjust_location(request, updated)
                    return self.make_response(
                        updated.model_dump(
                            scim_ctx=Context.RESOURCE_REPLACEMENT_RESPONSE,
                            **response_args,
                        )
                    )
                else:
                    # RFC 7644, Section 3.5.2:
                    # a PATCH operation MAY return 204 (No Content)
                    # if no attributes were requested
                    return self.make_response(
                        None, 204, headers={"ETag": updated.meta.version}
                    )

    def wsgi_app(self, request: Request, environ):
        try:
            if environ.get("PATH_INFO", "").endswith(".scim"):
                # RFC 7644, Section 3.8:
                # just strip the .scim suffix; the provider always returns application/scim+json
                environ["PATH_INFO"], _, _ = environ["PATH_INFO"].rpartition(".scim")
            urls = self.url_map.bind_to_environ(environ)
            endpoint, args = urls.match()

            tenant_id = None
            if endpoint != "service_provider_config":
                # RFC 7643, Section 5: skip authentication for ServiceProviderConfig
                tenant_id = self.check_auth(request)
                if isinstance(tenant_id, Response):
                    # Authentication failed; return the 401 response directly.
                    return tenant_id

            # Wrap the entire call in a transaction. Should probably be optimized (use a transaction only when necessary).
            with self.backend:
                if endpoint == "service_provider_config" or endpoint == "schema":
                    response = getattr(self, f"call_{endpoint}")(request, **args)
                else:
                    response = getattr(self, f"call_{endpoint}")(
                        request, **args, tenant_id=tenant_id
                    )
                return response
        except RequestRedirect as e:
            # urls.match may cause a redirect; handle it as a special case of HTTPException
            self.log.exception(e)
            return e.get_response(environ)
        except HTTPException as e:
            self.log.exception(e)
            return self.make_error(Error(status=e.code, detail=e.description))
        except SCIMException as e:
            self.log.exception(e)
            return self.make_error(e.scim_error)
        except ValidationError as e:
            self.log.exception(e)
            return self.make_error(Error(status=400, detail=str(e)))
        except Exception as e:
            self.log.exception(e)
            tb = traceback.format_exc()
            return self.make_error(Error(status=500, detail=str(e) + "\n" + tb))
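
End to end, a provisioning client would then speak SCIM through this provider. A hypothetical sketch, reusing the tokens dict from the /token example above (the base URL is assumed):

# Hypothetical SCIM client sketch.
import requests

BASE_URL = "https://openreplay.example.com/sso/scim/v2"  # assumed host
headers = {"Authorization": f"Bearer {tokens['access_token']}"}

# List users, filtered and paginated (RFC 7644, Section 3.4.2).
resp = requests.get(
    f"{BASE_URL}/Users",
    params={"filter": 'userName eq "jane@example.com"', "count": 10},
    headers=headers,
)
listing = resp.json()
# listing["totalResults"], listing["Resources"], ...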
371
ee/api/routers/scim/users.py
Normal file
@ -0,0 +1,371 @@
from routers.scim import helpers

from chalicelib.utils import pg_client
from scim2_models import Resource


def convert_provider_resource_to_client_resource(
    provider_resource: dict,
) -> dict:
    groups = []
    if provider_resource["role_id"]:
        groups.append(
            {
                "value": str(provider_resource["role_id"]),
                "$ref": f"Groups/{provider_resource['role_id']}",
            }
        )
    return {
        "id": str(provider_resource["user_id"]),
        "schemas": [
            "urn:ietf:params:scim:schemas:core:2.0:User",
            "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
            "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User",
        ],
        "meta": {
            "resourceType": "User",
            "created": provider_resource["created_at"].strftime("%Y-%m-%dT%H:%M:%SZ"),
            "lastModified": provider_resource["updated_at"].strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            ),
        },
        "userName": provider_resource["email"],
        "externalId": provider_resource["internal_id"],
        "name": {
            "formatted": provider_resource["name"],
        },
        "displayName": provider_resource["name"] or provider_resource["email"],
        "active": provider_resource["deleted_at"] is None,
        "groups": groups,
        "urn:ietf:params:scim:schemas:extension:openreplay:2.0:User": {
            "permissions": provider_resource.get("permissions") or [],
            "projectKeys": provider_resource.get("project_keys") or [],
        },
    }


def _format_name(resource: Resource) -> str:
    # Join whichever SCIM name components are present into one string.
    if not resource.name:
        return ""
    parts = [
        resource.name.honorific_prefix,
        resource.name.given_name,
        resource.name.middle_name,
        resource.name.family_name,
        resource.name.honorific_suffix,
    ]
    return " ".join(x for x in parts if x)


def query_resources(tenant_id: int) -> list[dict]:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            f"""
            SELECT
                users.*,
                roles.permissions AS permissions,
                COALESCE(
                    (
                        SELECT json_agg(projects.project_key)
                        FROM public.projects
                        LEFT JOIN public.roles_projects USING (project_id)
                        WHERE roles_projects.role_id = roles.role_id
                    ),
                    '[]'
                ) AS project_keys
            FROM public.users
            LEFT JOIN public.roles ON roles.role_id = users.role_id
            WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL
            """
        )
        items = cur.fetchall()
    return [convert_provider_resource_to_client_resource(item) for item in items]


def get_resource(resource_id: str, tenant_id: int) -> dict | None:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            f"""
            SELECT
                users.*,
                roles.permissions AS permissions,
                COALESCE(
                    (
                        SELECT json_agg(projects.project_key)
                        FROM public.projects
                        LEFT JOIN public.roles_projects USING (project_id)
                        WHERE roles_projects.role_id = roles.role_id
                    ),
                    '[]'
                ) AS project_keys
            FROM public.users
            LEFT JOIN public.roles ON roles.role_id = users.role_id
            WHERE users.tenant_id = {tenant_id} AND users.deleted_at IS NULL AND users.user_id = {resource_id}
            """
        )
        item = cur.fetchone()
    if item:
        return convert_provider_resource_to_client_resource(item)
    return None


def delete_resource(resource_id: str, tenant_id: int) -> None:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                UPDATE public.users
                SET
                    deleted_at = now(),
                    updated_at = now()
                WHERE users.user_id = %(user_id)s
                """,
                {"user_id": resource_id},
            )
        )


def search_existing(tenant_id: int, resource: Resource) -> dict | None:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                SELECT *
                FROM public.users
                WHERE email = %(email)s
                """,
                {"email": resource.user_name},
            )
        )
        item = cur.fetchone()
    if item:
        return convert_provider_resource_to_client_resource(item)
    return None


def restore_resource(tenant_id: int, resource: Resource) -> dict | None:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                SELECT role_id
                FROM public.users
                WHERE user_id = %(user_id)s
                """,
                {"user_id": resource.id},
            )
        )
        item = cur.fetchone()
        if item and item["role_id"] is not None:
            _update_role_projects_and_permissions(
                item["role_id"],
                resource.OpenreplayUser.project_keys,
                resource.OpenreplayUser.permissions,
                cur,
            )
        cur.execute(
            cur.mogrify(
                """
                WITH u AS (
                    UPDATE public.users
                    SET
                        tenant_id = %(tenant_id)s,
                        email = %(email)s,
                        name = %(name)s,
                        internal_id = %(internal_id)s,
                        deleted_at = NULL,
                        created_at = now(),
                        updated_at = now(),
                        api_key = default,
                        jwt_iat = NULL,
                        weekly_report = default
                    WHERE users.email = %(email)s
                    RETURNING *
                )
                SELECT
                    u.*,
                    roles.permissions AS permissions,
                    COALESCE(
                        (
                            SELECT json_agg(projects.project_key)
                            FROM public.projects
                            LEFT JOIN public.roles_projects USING (project_id)
                            WHERE roles_projects.role_id = roles.role_id
                        ),
                        '[]'
                    ) AS project_keys
                FROM u
                LEFT JOIN public.roles ON roles.role_id = u.role_id
                """,
                {
                    "tenant_id": tenant_id,
                    "email": resource.user_name,
                    "name": _format_name(resource),
                    "internal_id": resource.external_id,
                },
            )
        )
        item = cur.fetchone()
    return convert_provider_resource_to_client_resource(item)


def create_resource(tenant_id: int, resource: Resource) -> dict:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                WITH u AS (
                    INSERT INTO public.users (
                        tenant_id,
                        email,
                        name,
                        internal_id
                    )
                    VALUES (
                        %(tenant_id)s,
                        %(email)s,
                        %(name)s,
                        %(internal_id)s
                    )
                    RETURNING *
                )
                SELECT *
                FROM u
                """,
                {
                    "tenant_id": tenant_id,
                    "email": resource.user_name,
                    "name": _format_name(resource),
                    "internal_id": resource.external_id,
                },
            )
        )
        item = cur.fetchone()
    return convert_provider_resource_to_client_resource(item)


def update_resource(tenant_id: int, resource: Resource) -> dict | None:
    with pg_client.PostgresClient() as cur:
        cur.execute(
            cur.mogrify(
                """
                SELECT role_id
                FROM public.users
                WHERE user_id = %(user_id)s
                """,
                {"user_id": resource.id},
            )
        )
        item = cur.fetchone()
        if item and item["role_id"] is not None:
            _update_role_projects_and_permissions(
                item["role_id"],
                resource.OpenreplayUser.project_keys,
                resource.OpenreplayUser.permissions,
                cur,
            )
        cur.execute(
            cur.mogrify(
                """
                WITH u AS (
                    UPDATE public.users
                    SET
                        tenant_id = %(tenant_id)s,
                        email = %(email)s,
                        name = %(name)s,
                        internal_id = %(internal_id)s,
                        updated_at = now()
                    WHERE user_id = %(user_id)s
                    RETURNING *
                )
                SELECT
                    u.*,
                    roles.permissions AS permissions,
                    COALESCE(
                        (
                            SELECT json_agg(projects.project_key)
                            FROM public.projects
                            LEFT JOIN public.roles_projects USING (project_id)
                            WHERE roles_projects.role_id = roles.role_id
                        ),
                        '[]'
                    ) AS project_keys
                FROM u
                LEFT JOIN public.roles ON roles.role_id = u.role_id
                """,
                {
                    "user_id": resource.id,
                    "tenant_id": tenant_id,
                    "email": resource.user_name,
                    "name": _format_name(resource),
                    "internal_id": resource.external_id,
                },
            )
        )
        item = cur.fetchone()
    return convert_provider_resource_to_client_resource(item)


def _update_role_projects_and_permissions(
    role_id: int,
    project_keys: list[str] | None,
    permissions: list[str] | None,
    cur: pg_client.PostgresClient,
) -> None:
    all_projects = "true" if not project_keys else "false"
    project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur)
    permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur)
    cur.execute(
        f"""
        UPDATE public.roles
        SET
            updated_at = now(),
            all_projects = {all_projects},
            permissions = {permission_clause}
        WHERE role_id = {role_id}
        RETURNING *
        """
    )
    cur.execute(
        f"""
        DELETE FROM public.roles_projects
        WHERE roles_projects.role_id = {role_id}
        """
    )
    cur.execute(
        f"""
        INSERT INTO public.roles_projects (role_id, project_id)
        SELECT {role_id}, projects.project_id
        FROM public.projects
        WHERE projects.project_key = ANY({project_key_clause})
        """
    )
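
Putting the pieces together, a sketch of the create path through the backend layer (the scim_backend wiring, tenant id, and example_user payload are the illustrative ones from the earlier sketches):

# Hypothetical sketch: create a user through PostgresBackend.
user_model = scim_backend.get_model("User")           # composed model from the fixtures
new_user = user_model.model_validate(example_user)    # payload sketched earlier
created = scim_backend.create_resource(1, "User", new_user)  # tenant_id=1 assumed
# created.id is the new user_id; created.active should be True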
@ -108,6 +108,16 @@ CREATE TABLE public.tenants
);


CREATE TABLE public.scim_auth_codes
(
    auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
    tenant_id    integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE,
    auth_code    text    NOT NULL UNIQUE DEFAULT generate_api_key(20),
    created_at   timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
    used         bool    NOT NULL DEFAULT FALSE
);


CREATE TABLE public.roles
(
    role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
@ -118,6 +128,7 @@ CREATE TABLE public.roles
    protected    bool      NOT NULL DEFAULT FALSE,
    all_projects bool      NOT NULL DEFAULT TRUE,
    created_at   timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
    updated_at   timestamp NOT NULL DEFAULT timezone('utc'::text, now()),
    deleted_at   timestamp NULL DEFAULT NULL,
    service_role bool      NOT NULL DEFAULT FALSE
);
@ -132,6 +143,7 @@ CREATE TABLE public.users
    role       user_role NOT NULL DEFAULT 'member',
    name       text      NOT NULL,
    created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
    updated_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
    deleted_at timestamp without time zone NULL DEFAULT NULL,
    api_key    text UNIQUE DEFAULT generate_api_key(20) NOT NULL,
    jwt_iat    timestamp without time zone NULL DEFAULT NULL,