update token generation for automatic refresh with okta

This commit is contained in:
Jonathan Griffin 2025-05-15 18:01:33 +02:00
parent 7dd7389d3b
commit 34bbde13c1
4 changed files with 115 additions and 45 deletions

View file

@ -13,24 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
ALGORITHM = config("SCIM_JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
AUDIENCE = "okta_client"
AUDIENCE = config("SCIM_AUDIENCE")
ISSUER = (config("JWT_ISSUER"),)
# Simulated Okta Client Credentials
# OKTA_CLIENT_ID = "okta-client"
# OKTA_CLIENT_SECRET = "okta-secret"
# class TokenRequest(BaseModel):
# client_id: str
# client_secret: str
# async def authenticate_client(token_request: TokenRequest):
# """Validate Okta Client Credentials and issue JWT"""
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
# raise HTTPException(status_code=401, detail="Invalid client credentials")
# return {"access_token": create_jwt(), "token_type": "bearer"}
def create_tokens(tenant_id):
curr_time = time.time()
@ -48,7 +33,7 @@ def create_tokens(tenant_id):
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
return access_token, refresh_token
return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS
def verify_access_token(token: str):

View file

@ -1,15 +1,15 @@
import logging
from copy import deepcopy
from enum import Enum
from urllib.parse import urlencode
from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from chalicelib.utils import pg_client
from fastapi import Depends, HTTPException, Query, Response, Request
from fastapi.responses import JSONResponse, RedirectResponse
from psycopg2 import errors
from chalicelib.core import roles, tenants
from chalicelib.core import roles
from chalicelib.utils.scim_auth import (
auth_optional,
auth_required,
@ -33,37 +33,110 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
@public_app.post("/token")
async def post_token(
host: str = Header(..., alias="Host"),
form_data: OAuth2PasswordRequestForm = Depends(),
):
subdomain = host.split(".")[0]
async def post_token(r: Request):
form = await r.form()
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config(
"SCIM_PASSWORD"
):
raise HTTPException(status_code=401, detail="Invalid credentials")
client_id = form.get("client_id")
client_secret = form.get("client_secret")
with pg_client.PostgresClient() as cur:
try:
cur.execute(
cur.mogrify(
"""
SELECT tenant_id
FROM public.tenants
WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
""",
{"tenant_id": int(client_id), "tenant_key": client_secret},
)
)
except ValueError:
raise HTTPException(status_code=401, detail="Invalid credentials")
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
tenant = cur.fetchone()
if not tenant:
raise HTTPException(status_code=401, detail="Invalid credentials")
grant_type = form.get("grant_type")
if grant_type == "refresh_token":
refresh_token = form.get("refresh_token")
verify_refresh_token(refresh_token)
else:
code = form.get("code")
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
SELECT *
FROM public.scim_auth_codes
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
""",
{"auth_code": code, "tenant_id": int(client_id)},
)
)
if cur.fetchone() is None:
raise HTTPException(
status_code=401, detail="Invalid code/client_id pair"
)
cur.execute(
cur.mogrify(
"""
UPDATE public.scim_auth_codes
SET used=TRUE
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
""",
{"auth_code": code, "tenant_id": int(client_id)},
)
)
access_token, refresh_token, expires_in = create_tokens(
tenant_id=tenant["tenant_id"]
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "Bearer",
"expires_in": expires_in,
"refresh_token": refresh_token,
}
class RefreshRequest(BaseModel):
refresh_token: str
@public_app.post("/refresh")
async def post_refresh(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
return {"access_token": new_access_token, "token_type": "Bearer"}
# note(jon): this might be specific to okta. if so, we should probably put specify that in the endpoint
@public_app.get("/authorize")
async def get_authorize(
r: Request,
response_type: str,
client_id: str,
redirect_uri: str,
state: str | None = None,
):
with pg_client.PostgresClient() as cur:
cur.execute(
cur.mogrify(
"""
UPDATE public.scim_auth_codes
SET used=TRUE
WHERE tenant_id=%(tenant_id)s
""",
{"tenant_id": int(client_id)},
)
)
cur.execute(
cur.mogrify(
"""
INSERT INTO public.scim_auth_codes (tenant_id)
VALUES (%(tenant_id)s)
RETURNING auth_code
""",
{"tenant_id": int(client_id)},
)
)
code = cur.fetchone()["auth_code"]
params = {"code": code}
if state:
params["state"] = state
url = f"{redirect_uri}?{urlencode(params)}"
return RedirectResponse(url)
def _not_found_error_response(resource_id: int):

View file

@ -255,6 +255,8 @@ def _update_role_projects_and_permissions(
permissions: list[str] | None,
cur: pg_client.PostgresClient,
) -> None:
if role_id is None:
return
all_projects = "true" if not project_keys else "false"
project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur)
permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur)

View file

@ -108,6 +108,16 @@ CREATE TABLE public.tenants
);
CREATE TABLE public.scim_auth_codes
(
auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE,
auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20),
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
used bool NOT NULL DEFAULT FALSE
);
CREATE TABLE public.roles
(
role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,