update token generation for automatic refresh with okta
This commit is contained in:
parent
7dd7389d3b
commit
34bbde13c1
4 changed files with 115 additions and 45 deletions
|
|
@ -13,24 +13,9 @@ REFRESH_SECRET_KEY = config("SCIM_REFRESH_SECRET_KEY")
|
|||
ALGORITHM = config("SCIM_JWT_ALGORITHM")
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS = int(config("SCIM_ACCESS_TOKEN_EXPIRE_SECONDS"))
|
||||
REFRESH_TOKEN_EXPIRE_SECONDS = int(config("SCIM_REFRESH_TOKEN_EXPIRE_SECONDS"))
|
||||
AUDIENCE = "okta_client"
|
||||
AUDIENCE = config("SCIM_AUDIENCE")
|
||||
ISSUER = (config("JWT_ISSUER"),)
|
||||
|
||||
# Simulated Okta Client Credentials
|
||||
# OKTA_CLIENT_ID = "okta-client"
|
||||
# OKTA_CLIENT_SECRET = "okta-secret"
|
||||
|
||||
# class TokenRequest(BaseModel):
|
||||
# client_id: str
|
||||
# client_secret: str
|
||||
|
||||
# async def authenticate_client(token_request: TokenRequest):
|
||||
# """Validate Okta Client Credentials and issue JWT"""
|
||||
# if token_request.client_id != OKTA_CLIENT_ID or token_request.client_secret != OKTA_CLIENT_SECRET:
|
||||
# raise HTTPException(status_code=401, detail="Invalid client credentials")
|
||||
|
||||
# return {"access_token": create_jwt(), "token_type": "bearer"}
|
||||
|
||||
|
||||
def create_tokens(tenant_id):
|
||||
curr_time = time.time()
|
||||
|
|
@ -48,7 +33,7 @@ def create_tokens(tenant_id):
|
|||
refresh_payload.update({"exp": curr_time + REFRESH_TOKEN_EXPIRE_SECONDS})
|
||||
refresh_token = jwt.encode(refresh_payload, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
return access_token, refresh_token
|
||||
return access_token, refresh_token, ACCESS_TOKEN_EXPIRE_SECONDS
|
||||
|
||||
|
||||
def verify_access_token(token: str):
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from decouple import config
|
||||
from fastapi import Depends, HTTPException, Header, Query, Response, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
from chalicelib.utils import pg_client
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Response, Request
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from psycopg2 import errors
|
||||
|
||||
from chalicelib.core import roles, tenants
|
||||
from chalicelib.core import roles
|
||||
from chalicelib.utils.scim_auth import (
|
||||
auth_optional,
|
||||
auth_required,
|
||||
|
|
@ -33,37 +33,110 @@ public_app, app, app_apikey = get_routers(prefix="/sso/scim/v2")
|
|||
|
||||
|
||||
@public_app.post("/token")
|
||||
async def post_token(
|
||||
host: str = Header(..., alias="Host"),
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
subdomain = host.split(".")[0]
|
||||
async def post_token(r: Request):
|
||||
form = await r.form()
|
||||
|
||||
# Missing authentication part, to add
|
||||
if form_data.username != config("SCIM_USER") or form_data.password != config(
|
||||
"SCIM_PASSWORD"
|
||||
):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
client_id = form.get("client_id")
|
||||
client_secret = form.get("client_secret")
|
||||
with pg_client.PostgresClient() as cur:
|
||||
try:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT tenant_id
|
||||
FROM public.tenants
|
||||
WHERE tenant_id=%(tenant_id)s AND tenant_key=%(tenant_key)s
|
||||
""",
|
||||
{"tenant_id": int(client_id), "tenant_key": client_secret},
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
tenant = tenants.get_by_name(subdomain)
|
||||
access_token, refresh_token = create_tokens(tenant_id=tenant["tenantId"])
|
||||
tenant = cur.fetchone()
|
||||
if not tenant:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
grant_type = form.get("grant_type")
|
||||
if grant_type == "refresh_token":
|
||||
refresh_token = form.get("refresh_token")
|
||||
verify_refresh_token(refresh_token)
|
||||
else:
|
||||
code = form.get("code")
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
SELECT *
|
||||
FROM public.scim_auth_codes
|
||||
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
|
||||
""",
|
||||
{"auth_code": code, "tenant_id": int(client_id)},
|
||||
)
|
||||
)
|
||||
if cur.fetchone() is None:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid code/client_id pair"
|
||||
)
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
UPDATE public.scim_auth_codes
|
||||
SET used=TRUE
|
||||
WHERE auth_code=%(auth_code)s AND tenant_id=%(tenant_id)s AND used IS FALSE
|
||||
""",
|
||||
{"auth_code": code, "tenant_id": int(client_id)},
|
||||
)
|
||||
)
|
||||
|
||||
access_token, refresh_token, expires_in = create_tokens(
|
||||
tenant_id=tenant["tenant_id"]
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": expires_in,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
@public_app.post("/refresh")
|
||||
async def post_refresh(r: RefreshRequest):
|
||||
payload = verify_refresh_token(r.refresh_token)
|
||||
new_access_token, _ = create_tokens(tenant_id=payload["tenant_id"])
|
||||
return {"access_token": new_access_token, "token_type": "Bearer"}
|
||||
# note(jon): this might be specific to okta. if so, we should probably put specify that in the endpoint
|
||||
@public_app.get("/authorize")
|
||||
async def get_authorize(
|
||||
r: Request,
|
||||
response_type: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
state: str | None = None,
|
||||
):
|
||||
with pg_client.PostgresClient() as cur:
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
UPDATE public.scim_auth_codes
|
||||
SET used=TRUE
|
||||
WHERE tenant_id=%(tenant_id)s
|
||||
""",
|
||||
{"tenant_id": int(client_id)},
|
||||
)
|
||||
)
|
||||
cur.execute(
|
||||
cur.mogrify(
|
||||
"""
|
||||
INSERT INTO public.scim_auth_codes (tenant_id)
|
||||
VALUES (%(tenant_id)s)
|
||||
RETURNING auth_code
|
||||
""",
|
||||
{"tenant_id": int(client_id)},
|
||||
)
|
||||
)
|
||||
code = cur.fetchone()["auth_code"]
|
||||
params = {"code": code}
|
||||
if state:
|
||||
params["state"] = state
|
||||
url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return RedirectResponse(url)
|
||||
|
||||
|
||||
def _not_found_error_response(resource_id: int):
|
||||
|
|
|
|||
|
|
@ -255,6 +255,8 @@ def _update_role_projects_and_permissions(
|
|||
permissions: list[str] | None,
|
||||
cur: pg_client.PostgresClient,
|
||||
) -> None:
|
||||
if role_id is None:
|
||||
return
|
||||
all_projects = "true" if not project_keys else "false"
|
||||
project_key_clause = helpers.safe_mogrify_array(project_keys, "varchar", cur)
|
||||
permission_clause = helpers.safe_mogrify_array(permissions, "varchar", cur)
|
||||
|
|
|
|||
|
|
@ -108,6 +108,16 @@ CREATE TABLE public.tenants
|
|||
);
|
||||
|
||||
|
||||
CREATE TABLE public.scim_auth_codes
|
||||
(
|
||||
auth_code_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||
tenant_id integer NOT NULL REFERENCES public.tenants (tenant_id) ON DELETE CASCADE,
|
||||
auth_code text NOT NULL UNIQUE DEFAULT generate_api_key(20),
|
||||
created_at timestamp without time zone NOT NULL DEFAULT (now() at time zone 'utc'),
|
||||
used bool NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE public.roles
|
||||
(
|
||||
role_id integer generated BY DEFAULT AS IDENTITY PRIMARY KEY,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue