[Draft] add auth flow with JWT
This commit is contained in:
parent
937e4d244c
commit
cd70633d1f
3 changed files with 126 additions and 9 deletions
|
|
@ -56,6 +56,20 @@ def get_by_api_key(api_key):
|
||||||
return helper.dict_to_camel_case(cur.fetchone())
|
return helper.dict_to_camel_case(cur.fetchone())
|
||||||
|
|
||||||
|
|
||||||
|
def get_by_name(name):
|
||||||
|
with pg_client.PostgresClient() as cur:
|
||||||
|
query = cur.mogrify(f"""SELECT tenants.tenant_id,
|
||||||
|
tenants.name,
|
||||||
|
tenants.created_at
|
||||||
|
FROM public.tenants
|
||||||
|
WHERE tenants.name = %(name)s
|
||||||
|
AND tenants.deleted_at ISNULL
|
||||||
|
LIMIT 1;""",
|
||||||
|
{"name": name})
|
||||||
|
cur.execute(query=query)
|
||||||
|
return helper.dict_to_camel_case(cur.fetchone())
|
||||||
|
|
||||||
|
|
||||||
def generate_new_api_key(tenant_id):
|
def generate_new_api_key(tenant_id):
|
||||||
with pg_client.PostgresClient() as cur:
|
with pg_client.PostgresClient() as cur:
|
||||||
query = cur.mogrify(f"""UPDATE public.tenants
|
query = cur.mogrify(f"""UPDATE public.tenants
|
||||||
|
|
|
||||||
77
ee/api/chalicelib/utils/scim_auth.py
Normal file
77
ee/api/chalicelib/utils/scim_auth.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from decouple import config
|
||||||
|
from fastapi import HTTPException, Depends
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ACCESS_SECRET_KEY = config("SCIM_ACCESS_SECRET_KEY")
|
||||||
|
REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
|
||||||
|
ALGORITHM = config("SCIM_JWT_ALGORITHM")
|
||||||
|
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
|
||||||
|
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
|
||||||
|
AUDIENCE="okta_client"
|
||||||
|
ISSUER=config("JWT_ISSUER"),
|
||||||
|
|
||||||
|
# Simulated Okta Client Credentials
|
||||||
|
# OKTA_CLIENT_ID = "okta-client"
|
||||||
|
# OKTA_CLIENT_SECRET = "okta-secret"
|
||||||
|
|
||||||
|
# class TokenRequest(BaseModel):
|
||||||
|
# client_id: str
|
||||||
|
# client_secret: str
|
||||||
|
|
||||||
|
# async def authenticate_client(token_request: TokenRequest):
|
||||||
|
# """Validate Okta Client Credentials and issue JWT"""
|
||||||
|
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
|
||||||
|
# raise HTTPException(status_code=401, detail="Invalid client credentials")
|
||||||
|
|
||||||
|
# return {"access_token": create_jwt(), "token_type": "bearer"}
|
||||||
|
|
||||||
|
def create_tokens(tenant_id):
|
||||||
|
curr_time = time.time()
|
||||||
|
access_payload = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"sub": "scim_server",
|
||||||
|
"aud": AUDIENCE,
|
||||||
|
"iss": ISSUER,
|
||||||
|
"exp": ""
|
||||||
|
}
|
||||||
|
access_payload.update({"exp": curr_time + ACCESS_TOKEN_EXPIRE_SECONDS})
|
||||||
|
access_token = jwt.encode(access_payload, ACCESS_SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
refresh_payload = access_payload.copy()
|
||||||
|
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
|
||||||
|
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
return access_token, refresh_token
|
||||||
|
|
||||||
|
def verify_access_token(token: str):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, ACCESS_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
||||||
|
return payload
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
def verify_refresh_token(token: str):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, REFRESH_SECRET_KEY, algorithms=[ALGORITHM], audience=AUDIENCE)
|
||||||
|
return payload
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
# Authentication Dependency
|
||||||
|
def auth_required(token: str = Depends(oauth2_scheme)):
|
||||||
|
"""Dependency to check Authorization header."""
|
||||||
|
if config("SCIM_AUTH_TYPE") == "OAuth2":
|
||||||
|
payload = verify_access_token(token)
|
||||||
|
return payload["tenant_id"]
|
||||||
|
|
@ -6,25 +6,50 @@ from typing import Optional
|
||||||
from decouple import config
|
from decouple import config
|
||||||
from fastapi import Depends, HTTPException, Header, Query, Response
|
from fastapi import Depends, HTTPException, Header, Query, Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import schemas
|
import schemas
|
||||||
from chalicelib.core import users, roles
|
from chalicelib.core import users, roles, tenants
|
||||||
|
from chalicelib.utils.scim_auth import auth_required, create_tokens, verify_refresh_token
|
||||||
from routers.base import get_routers
|
from routers.base import get_routers
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
||||||
|
|
||||||
# Authentication Dependency
|
|
||||||
def auth_required(authorization: str = Header(..., alias="Authorization")):
|
|
||||||
"""Dependency to check Authorization header."""
|
|
||||||
token = authorization.replace("Bearer ", "")
|
|
||||||
if token != config("OCTA_TOKEN"):
|
|
||||||
raise HTTPException(status_code=403, detail="Unauthorized")
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
"""Authentication endpoints"""
|
||||||
|
|
||||||
|
class RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
# Login endpoint to generate tokens
|
||||||
|
@public_app.post("/token")
|
||||||
|
async def login(host: str = Header(..., alias="Host"), form_data: OAuth2PasswordRequestForm = Depends()):
|
||||||
|
subdomain = host.split(".")[0]
|
||||||
|
|
||||||
|
# Missing authentication part, to add
|
||||||
|
if form_data.username != config("SCIM_USER") or form_data.password != config("SCIM_PASSWORD"):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
|
|
||||||
|
subdomain = "Openreplay EE"
|
||||||
|
tenant = tenants.get_by_name(subdomain)
|
||||||
|
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
|
||||||
|
|
||||||
|
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
# Refresh token endpoint
|
||||||
|
@public_app.post("/refresh")
|
||||||
|
async def refresh_token(r: RefreshRequest):
|
||||||
|
|
||||||
|
payload = verify_refresh_token(r.refresh_token)
|
||||||
|
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
|
||||||
|
|
||||||
|
return {"access_token": new_access_token, "token_type": "Bearer"}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
User endpoints
|
User endpoints
|
||||||
|
|
@ -431,6 +456,7 @@ def update_patch_group(group_id: str, r: GroupPatchRequest):
|
||||||
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
@public_app.delete("/Groups/{group_id}", dependencies=[Depends(auth_required)])
|
||||||
def delete_group(group_id: str):
|
def delete_group(group_id: str):
|
||||||
"""Delete a group, hard-delete"""
|
"""Delete a group, hard-delete"""
|
||||||
|
# possibly need to set the user's roles to default member role, instead of null
|
||||||
tenant_id = 1
|
tenant_id = 1
|
||||||
group = roles.get_role_by_group_id(tenant_id, group_id)
|
group = roles.get_role_by_group_id(tenant_id, group_id)
|
||||||
if not group:
|
if not group:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue