From cd70633d1fe7bd5daab23e6fdcd2c16c9f44e403 Mon Sep 17 00:00:00 2001 From: Pavel Kim Date: Thu, 20 Feb 2025 14:31:32 +0100 Subject: [PATCH] [Draft] add auth flow with JWT --- ee/api/chalicelib/core/tenants.py | 14 +++++ ee/api/chalicelib/utils/scim_auth.py | 77 ++++++++++++++++++++++++++++ ee/api/routers/scim.py | 44 ++++++++++++---- 3 files changed, 126 insertions(+), 9 deletions(-) create mode 100644 ee/api/chalicelib/utils/scim_auth.py diff --git a/ee/api/chalicelib/core/tenants.py b/ee/api/chalicelib/core/tenants.py index ca2d59dde..84f2c6d3a 100644 --- a/ee/api/chalicelib/core/tenants.py +++ b/ee/api/chalicelib/core/tenants.py @@ -56,6 +56,20 @@ def get_by_api_key(api_key): return helper.dict_to_camel_case(cur.fetchone()) +def get_by_name(name): + with pg_client.PostgresClient() as cur: + query = cur.mogrify(f"""SELECT tenants.tenant_id, + tenants.name, + tenants.created_at + FROM public.tenants + WHERE tenants.name = %(name)s + AND tenants.deleted_at ISNULL + LIMIT 1;""", + {"name": name}) + cur.execute(query=query) + return helper.dict_to_camel_case(cur.fetchone()) + + def generate_new_api_key(tenant_id): with pg_client.PostgresClient() as cur: query = cur.mogrify(f"""UPDATE public.tenants diff --git a/ee/api/chalicelib/utils/scim_auth.py b/ee/api/chalicelib/utils/scim_auth.py new file mode 100644 index 000000000..83e779c40 --- /dev/null +++ b/ee/api/chalicelib/utils/scim_auth.py @@ -0,0 +1,77 @@ +import logging +import time +import jwt + +from decouple import config +from fastapi import HTTPException, Depends +from fastapi.security import OAuth2PasswordBearer + +logger = logging.getLogger(__name__) + +ACCESS_SECRET_KEY = config("SCIM_ACCESS_SECRET_KEY") +REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY") +ALGORITHM = config("SCIM_JWT_ALGORITHM") +ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS")) +REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS")) +AUDIENCE="okta_client" +ISSUER=config("JWT_ISSUER"), + +# Simulated Okta Client Credentials +# OKTA_CLIENT_ID = "okta-client" +# OKTA_CLIENT_SECRET = "okta-secret" + +# class TokenRequest(BaseModel): +# client_id: str +# client_secret: str + +# async def authenticate_client(token_request: TokenRequest): +# """Validate Okta Client Credentials and issue JWT""" +# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET: +# raise HTTPException(status_code=401, detail="Invalid client credentials") + +# return {"access_token": create_jwt(), "token_type": "bearer"} + +def create_tokens(tenant_id): + curr_time = time.time() + access_payload = { + "tenant_id": tenant_id, + "sub": "scim_server", + "aud": AUDIENCE, + "iss": ISSUER, + "exp": "" + } + access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS}) + access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM) + + refresh_payload = access_payload.copy() + refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS}) + refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM) + + return access_token, refresh_token + +def verify_access_token(token: str): + try: + payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + return payload + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + +def verify_refresh_token(token: str): + try: + payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE) + return payload + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +# Authentication Dependency +def auth_required(token: str = Depends(oauth2_scheme)): + """Dependency to check Authorization header.""" + if config("SCIM_AUTH_TYPE") == "OAuth2": + payload = verify_access_token(token) + return payload["tenant_id"] diff --git a/ee/api/routers/scim.py b/ee/api/routers/scim.py index 85a777c90..8a0975492 100644 --- a/ee/api/routers/scim.py +++ b/ee/api/routers/scim.py @@ -6,25 +6,50 @@ from typing import Optional from decouple import config from fastapi import Depends, HTTPException, Header, Query, Response from fastapi.responses import JSONResponse +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel, Field import schemas -from chalicelib.core import users, roles +from chalicelib.core import users, roles, tenants +from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token from routers.base import get_routers + logger = logging.getLogger(__name__) - public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") -# Authentication Dependency -def auth_required(authorization: str = Header(..., alias="Authorization")): - """Dependency to check Authorization header.""" - token = authorization.replace("Bearer ", "") - if token != config("OCTA_TOKEN"): - raise HTTPException(status_code=403, detail="Unauthorized") - return token +"""Authentication endpoints""" + +class RefreshRequest(BaseModel): + refresh_token: str + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +# Login endpoint to generate tokens +@public_app.post("/token") +async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()): + subdomain = host.split(".")[0] + + # Missing authentication part, to add + if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"): + raise HTTPException(status_code=401, detail="Invalid credentials") + + subdomain = "Openreplay EE" + tenant = tenants.get_by_name(subdomain) + access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) + + return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} + +# Refresh token endpoint +@public_app.post("/refresh") +async def refresh_token(r: RefreshRequest): + + payload = verify_refresh_token(r.refresh_token) + new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) + + return {"access_token": new_access_token, "token_type": "Bearer"} """ User endpoints @@ -431,6 +456,7 @@ def update_patch_group(group_id: str, r: GroupPatchRequest): @public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) def delete_group(group_id: str): """Delete a group, hard-delete""" + # possibly need to set the user's roles to default member role, instead of null tenant_id = 1 group = roles.get_role_by_group_id(tenant_id, group_id) if not group: