From 34bbde13c1f5a0bb19ad0081712f0a1f7e87cbda Mon Sep 17 00:00:00 2001 From: Jonathan Griffin Date: Thu, 15 May 2025 18:01:33 +0200 Subject: [PATCH] update token generation for automatic refresh with okta --- ee/api/chalicelib/utils/scim_auth.py | 19 +-- ee/api/routers/scim/api.py | 129 ++++++++++++++---- ee/api/routers/scim/users.py | 2 + .../db/init_dbs/postgresql/init_schema.sql | 10 ++ 4 files changed, 115 insertions(+), 45 deletions(-) diff --git a/ee/api/chalicelib/utils/scim_auth.py b/ee/api/chalicelib/utils/scim_auth.py index a8deaa136..c31dcd058 100644 --- a/ee/api/chalicelib/utils/scim_auth.py +++ b/ee/api/chalicelib/utils/scim_auth.py @@ -13,24 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY") ALGORITHM = config("SCIM_JWT_ALGORITHM") ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS")) REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS")) -AUDIENCE = "okta_client" +AUDIENCE = config("SCIM_AUDIENCE") ISSUER = (config("JWT_ISSUER"),) -# Simulated Okta Client Credentials -# OKTA_CLIENT_ID = "okta-client" -# OKTA_CLIENT_SECRET = "okta-secret" - -# class TokenRequest(BaseModel): -# client_id: str -# client_secret: str - -# async def authenticate_client(token_request: TokenRequest): -# """Validate Okta Client Credentials and issue JWT""" -# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET: -# raise HTTPException(status_code=401, detail="Invalid client credentials") - -# return {"access_token": create_jwt(), "token_type": "bearer"} - def create_tokens(tenant_id): curr_time = time.time() @@ -48,7 +33,7 @@ def create_tokens(tenant_id): refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS}) refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM) - return access_token, refresh_token + return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS def verify_access_token(token: str): diff --git a/ee/api/routers/scim/api.py b/ee/api/routers/scim/api.py index 2508fc7bf..07482f10c 100644 --- a/ee/api/routers/scim/api.py +++ b/ee/api/routers/scim/api.py @@ -1,15 +1,15 @@ import logging from copy import deepcopy from enum import Enum +from urllib.parse import urlencode -from decouple import config -from fastapi import Depends, HTTPException, Header, Query, Response, Request -from fastapi.responses import JSONResponse -from fastapi.security import OAuth2PasswordRequestForm -from pydantic import BaseModel +from chalicelib.utils import pg_client + +from fastapi import Depends, HTTPException, Query, Response, Request +from fastapi.responses import JSONResponse, RedirectResponse from psycopg2 import errors -from chalicelib.core import roles, tenants +from chalicelib.core import roles from chalicelib.utils.scim_auth import ( auth_optional, auth_required, @@ -33,37 +33,110 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") @public_app.post("/token") -async def post_token( - host: str = Header(..., alias="Host"), - form_data: OAuth2PasswordRequestForm = Depends(), -): - subdomain = host.split(".")[0] +async def post_token(r: Request): + form = await r.form() - # Missing authentication part, to add - if form_data.username != config("SCIM_USER") or form_data.password != config( - "SCIM_PASSWORD" - ): - raise HTTPException(status_code=401, detail="Invalid credentials") + client_id = form.get("client_id") + client_secret = form.get("client_secret") + with pg_client.PostgresClient() as cur: + try: + cur.execute( + cur.mogrify( + """ + SELECT tenant_id + FROM public.tenants + WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s + """, + {"tenant_id": int(client_id), "tenant_key": client_secret}, + ) + ) + except ValueError: + raise HTTPException(status_code=401, detail="Invalid credentials") - tenant = tenants.get_by_name(subdomain) - access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"]) + tenant = cur.fetchone() + if not tenant: + raise HTTPException(status_code=401, detail="Invalid credentials") + + grant_type = form.get("grant_type") + if grant_type == "refresh_token": + refresh_token = form.get("refresh_token") + verify_refresh_token(refresh_token) + else: + code = form.get("code") + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + SELECT * + FROM public.scim_auth_codes + WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE + """, + {"auth_code": code, "tenant_id": int(client_id)}, + ) + ) + if cur.fetchone() is None: + raise HTTPException( + status_code=401, detail="Invalid code/client_id pair" + ) + cur.execute( + cur.mogrify( + """ + UPDATE public.scim_auth_codes + SET used=TRUE + WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE + """, + {"auth_code": code, "tenant_id": int(client_id)}, + ) + ) + + access_token, refresh_token, expires_in = create_tokens( + tenant_id=tenant["tenant_id"] + ) return { "access_token": access_token, - "refresh_token": refresh_token, "token_type": "Bearer", + "expires_in": expires_in, + "refresh_token": refresh_token, } -class RefreshRequest(BaseModel): - refresh_token: str - - -@public_app.post("/refresh") -async def post_refresh(r: RefreshRequest): - payload = verify_refresh_token(r.refresh_token) - new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"]) - return {"access_token": new_access_token, "token_type": "Bearer"} +# note(jon): this might be specific to okta. if so, we should probably put specify that in the endpoint +@public_app.get("/authorize") +async def get_authorize( + r: Request, + response_type: str, + client_id: str, + redirect_uri: str, + state: str | None = None, +): + with pg_client.PostgresClient() as cur: + cur.execute( + cur.mogrify( + """ + UPDATE public.scim_auth_codes + SET used=TRUE + WHERE tenant_id=%(tenant_id)s + """, + {"tenant_id": int(client_id)}, + ) + ) + cur.execute( + cur.mogrify( + """ + INSERT INTO public.scim_auth_codes (tenant_id) + VALUES (%(tenant_id)s) + RETURNING auth_code + """, + {"tenant_id": int(client_id)}, + ) + ) + code = cur.fetchone()["auth_code"] + params = {"code": code} + if state: + params["state"] = state + url = f"{redirect_uri}?{urlencode(params)}" + return RedirectResponse(url) def _not_found_error_response(resource_id: int): diff --git a/ee/api/routers/scim/users.py b/ee/api/routers/scim/users.py index 14639a083..9e8623188 100644 --- a/ee/api/routers/scim/users.py +++ b/ee/api/routers/scim/users.py @@ -255,6 +255,8 @@ def _update_role_projects_and_permissions( permissions: list[str] | None, cur: pg_client.PostgresClient, ) -> None: + if role_id is None: + return all_projects = "true" if not project_keys else "false" project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur) permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur) diff --git a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql index dc231c2a0..f0ea95b84 100644 --- a/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql +++ b/ee/scripts/schema/db/init_dbs/postgresql/init_schema.sql @@ -108,6 +108,16 @@ CREATE TABLE public.tenants ); +CREATE TABLE public.scim_auth_codes +( + auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY, + tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE, + auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20), + created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'), + used bool NOT NULL DEFAULT FALSE +); + + CREATE TABLE public.roles ( role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,