[Draft] add auth flow with JWT

This commit is contained in:
Pavel Kim 2025-02-20 14:31:32 +01:00 committed by Jonathan Griffin
parent 937e4d244c
commit cd70633d1f
3 changed files with 126 additions and 9 deletions

View file

@ -56,6 +56,20 @@ def get_by_api_key(api_key):
return helper.dict_to_camel_case(cur.fetchone()) return helper.dict_to_camel_case(cur.fetchone())
def get_by_name(name):
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""SELECT tenants.tenant_id,
tenants.name,
tenants.created_at
FROM public.tenants
WHERE tenants.name = %(name)s
AND tenants.deleted_at ISNULL
LIMIT 1;""",
{"name": name})
cur.execute(query=query)
return helper.dict_to_camel_case(cur.fetchone())
def generate_new_api_key(tenant_id): def generate_new_api_key(tenant_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""UPDATE public.tenants query = cur.mogrify(f"""UPDATE public.tenants

View file

@ -0,0 +1,77 @@
import logging
import time
import jwt
from decouple import config
from fastapi import HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer
logger = logging.getLogger(__name__)
ACCESS_SECRET_KEY = config("SCIM_ACCESS_SECRET_KEY")
REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
ALGORITHM = config("SCIM_JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
AUDIENCE="okta_client"
ISSUER=config("JWT_ISSUER"),
# Simulated Okta Client Credentials
# OKTA_CLIENT_ID = "okta-client"
# OKTA_CLIENT_SECRET = "okta-secret"
# class TokenRequest(BaseModel):
# client_id: str
# client_secret: str
# async def authenticate_client(token_request: TokenRequest):
# """Validate Okta Client Credentials and issue JWT"""
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
# raise HTTPException(status_code=401, detail="Invalid client credentials")
# return {"access_token": create_jwt(), "token_type": "bearer"}
def create_tokens(tenant_id):
curr_time = time.time()
access_payload = {
"tenant_id": tenant_id,
"sub": "scim_server",
"aud": AUDIENCE,
"iss": ISSUER,
"exp": ""
}
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
refresh_payload = access_payload.copy()
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
return access_token, refresh_token
def verify_access_token(token: str):
try:
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")
def verify_refresh_token(token: str):
try:
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Authentication Dependency
def auth_required(token: str = Depends(oauth2_scheme)):
"""Dependency to check Authorization header."""
if config("SCIM_AUTH_TYPE") == "OAuth2":
payload = verify_access_token(token)
return payload["tenant_id"]

View file

@ -6,25 +6,50 @@ from typing import Optional
from decouple import config from decouple import config
from fastapi import Depends, HTTPException, Header, Query, Response from fastapi import Depends, HTTPException, Header, Query, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import schemas import schemas
from chalicelib.core import users, roles from chalicelib.core import users, roles, tenants
from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token
from routers.base import get_routers from routers.base import get_routers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2") public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
# Authentication Dependency
def auth_required(authorization: str = Header(..., alias="Authorization")):
"""Dependency to check Authorization header."""
token = authorization.replace("Bearer ", "")
if token != config("OCTA_TOKEN"):
raise HTTPException(status_code=403, detail="Unauthorized")
return token
"""Authentication endpoints"""
class RefreshRequest(BaseModel):
refresh_token: str
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Login endpoint to generate tokens
@public_app.post("/token")
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
subdomain = host.split(".")[0]
# Missing authentication part, to add
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
raise HTTPException(status_code=401, detail="Invalid credentials")
subdomain = "Openreplay EE"
tenant = tenants.get_by_name(subdomain)
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
# Refresh token endpoint
@public_app.post("/refresh")
async def refresh_token(r: RefreshRequest):
payload = verify_refresh_token(r.refresh_token)
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
return {"access_token": new_access_token, "token_type": "Bearer"}
""" """
User endpoints User endpoints
@ -431,6 +456,7 @@ def update_patch_group(group_id: str, r: GroupPatchRequest):
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)]) @public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
def delete_group(group_id: str): def delete_group(group_id: str):
"""Delete a group, hard-delete""" """Delete a group, hard-delete"""
# possibly need to set the user's roles to default member role, instead of null
tenant_id = 1 tenant_id = 1
group = roles.get_role_by_group_id(tenant_id, group_id) group = roles.get_role_by_group_id(tenant_id, group_id)
if not group: if not group: