fix(chalice): fix multi-refresh token

fix(chalice): fix spot multi-refresh token
This commit is contained in:
Taha Yassine Kraiem 2025-03-10 13:14:10 +01:00 committed by Kraiem Taha Yassine
parent c40e32d624
commit af4a344c85
6 changed files with 60 additions and 30 deletions

View file

@ -28,9 +28,6 @@ def jwt_authorizer(scheme: str, token: str, leeway=0) -> dict | None:
if scheme.lower() != "bearer": if scheme.lower() != "bearer":
return None return None
try: try:
logger.warning("Checking JWT token: %s", token)
logger.warning("Against: %s", config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"))
logger.warning(get_supported_audience())
payload = jwt.decode(jwt=token, payload = jwt.decode(jwt=token,
key=config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"), key=config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"),
algorithms=config("JWT_ALGORITHM"), algorithms=config("JWT_ALGORITHM"),

View file

@ -18,7 +18,7 @@ def refresh_spot_jwt_iat_jti(user_id):
{"user_id": user_id}) {"user_id": user_id})
cur.execute(query) cur.execute(query)
row = cur.fetchone() row = cur.fetchone()
return row.get("spot_jwt_iat"), row.get("spot_jwt_refresh_jti"), row.get("spot_jwt_refresh_iat") return users.RefreshSpotJWTs(**row)
def logout(user_id: int): def logout(user_id: int):
@ -26,13 +26,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict: def refresh(user_id: int, tenant_id: int = -1) -> dict:
spot_jwt_iat, spot_jwt_r_jti, spot_jwt_r_iat = refresh_spot_jwt_iat_jti(user_id=user_id) j = refresh_spot_jwt_iat_jti(user_id=user_id)
return { return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_iat, "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_iat,
aud=AUDIENCE, for_spot=True), aud=AUDIENCE, for_spot=True),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_r_iat, "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=spot_jwt_r_jti, for_spot=True), aud=AUDIENCE, jwt_jti=j.spot_jwt_refresh_jti, for_spot=True),
"refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (spot_jwt_iat - spot_jwt_r_iat) "refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (j.spot_jwt_iat - j.spot_jwt_refresh_iat)
} }

View file

@ -1,5 +1,6 @@
import json import json
import secrets import secrets
from typing import Optional
from decouple import config from decouple import config
from fastapi import BackgroundTasks from fastapi import BackgroundTasks
@ -83,7 +84,6 @@ def restore_member(user_id, email, invitation_token, admin, name, owner=False):
"name": name, "invitation_token": invitation_token}) "name": name, "invitation_token": invitation_token})
cur.execute(query) cur.execute(query)
result = cur.fetchone() result = cur.fetchone()
cur.execute(query)
result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"]) result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"])
return helper.dict_to_camel_case(result) return helper.dict_to_camel_case(result)
@ -552,7 +552,7 @@ def refresh_auth_exists(user_id, jwt_jti=None):
return r is not None return r is not None
class ChangeJwt(BaseModel): class FullLoginJWTs(BaseModel):
jwt_iat: int jwt_iat: int
jwt_refresh_jti: str jwt_refresh_jti: str
jwt_refresh_iat: int jwt_refresh_iat: int
@ -565,11 +565,23 @@ class ChangeJwt(BaseModel):
def _transform_data(cls, values): def _transform_data(cls, values):
if values.get("jwt_refresh_jti") is not None: if values.get("jwt_refresh_jti") is not None:
values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"]) values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"])
if values.get("jwt_refresh_jti") is not None: if values.get("spot_jwt_refresh_jti") is not None:
values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"]) values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"])
return values return values
class RefreshLoginJWTs(FullLoginJWTs):
spot_jwt_iat: Optional[int] = None
spot_jwt_refresh_jti: Optional[str] = None
spot_jwt_refresh_iat: Optional[int] = None
class RefreshSpotJWTs(FullLoginJWTs):
jwt_iat: Optional[int] = None
jwt_refresh_jti: Optional[str] = None
jwt_refresh_iat: Optional[int] = None
def change_jwt_iat_jti(user_id): def change_jwt_iat_jti(user_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""UPDATE public.users query = cur.mogrify(f"""UPDATE public.users
@ -589,7 +601,7 @@ def change_jwt_iat_jti(user_id):
{"user_id": user_id}) {"user_id": user_id})
cur.execute(query) cur.execute(query)
row = cur.fetchone() row = cur.fetchone()
return ChangeJwt(**row) return FullLoginJWTs(**row)
def refresh_jwt_iat_jti(user_id): def refresh_jwt_iat_jti(user_id):
@ -604,7 +616,7 @@ def refresh_jwt_iat_jti(user_id):
{"user_id": user_id}) {"user_id": user_id})
cur.execute(query) cur.execute(query)
row = cur.fetchone() row = cur.fetchone()
return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat") return RefreshLoginJWTs(**row)
def authenticate(email, password, for_change_password=False) -> dict | bool | None: def authenticate(email, password, for_change_password=False) -> dict | bool | None:
@ -672,13 +684,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict: def refresh(user_id: int, tenant_id: int = -1) -> dict:
jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id) j = refresh_jwt_iat_jti(user_id=user_id)
return { return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat, "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat,
aud=AUDIENCE), aud=AUDIENCE),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat, "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=jwt_r_jti), aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat) "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat),
} }

View file

@ -1,3 +1,4 @@
from decouple import config
from fastapi import Depends from fastapi import Depends
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
@ -8,7 +9,10 @@ from routers.base import get_routers
public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"]) public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"])
COOKIE_PATH = "/api/spot/refresh" if config("LOCAL_DEV", cast=bool, default=False):
COOKIE_PATH = "/spot/refresh"
else:
COOKIE_PATH = "/api/spot/refresh"
@app.get('/logout') @app.get('/logout')

View file

@ -1,6 +1,7 @@
import json import json
import logging import logging
import secrets import secrets
from typing import Optional
from decouple import config from decouple import config
from fastapi import BackgroundTasks, HTTPException from fastapi import BackgroundTasks, HTTPException
@ -657,7 +658,7 @@ def refresh_auth_exists(user_id, tenant_id, jwt_jti=None):
return r is not None return r is not None
class ChangeJwt(BaseModel): class FullLoginJWTs(BaseModel):
jwt_iat: int jwt_iat: int
jwt_refresh_jti: str jwt_refresh_jti: str
jwt_refresh_iat: int jwt_refresh_iat: int
@ -670,11 +671,23 @@ class ChangeJwt(BaseModel):
def _transform_data(cls, values): def _transform_data(cls, values):
if values.get("jwt_refresh_jti") is not None: if values.get("jwt_refresh_jti") is not None:
values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"]) values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"])
if values.get("jwt_refresh_jti") is not None: if values.get("spot_jwt_refresh_jti") is not None:
values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"]) values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"])
return values return values
class RefreshLoginJWTs(FullLoginJWTs):
spot_jwt_iat: Optional[int] = None
spot_jwt_refresh_jti: Optional[str] = None
spot_jwt_refresh_iat: Optional[int] = None
class RefreshSpotJWTs(FullLoginJWTs):
jwt_iat: Optional[int] = None
jwt_refresh_jti: Optional[str] = None
jwt_refresh_iat: Optional[int] = None
def change_jwt_iat_jti(user_id): def change_jwt_iat_jti(user_id):
with pg_client.PostgresClient() as cur: with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""UPDATE public.users query = cur.mogrify(f"""UPDATE public.users
@ -694,7 +707,7 @@ def change_jwt_iat_jti(user_id):
{"user_id": user_id}) {"user_id": user_id})
cur.execute(query) cur.execute(query)
row = cur.fetchone() row = cur.fetchone()
return ChangeJwt(**row) return FullLoginJWTs(**row)
def refresh_jwt_iat_jti(user_id): def refresh_jwt_iat_jti(user_id):
@ -709,7 +722,7 @@ def refresh_jwt_iat_jti(user_id):
{"user_id": user_id}) {"user_id": user_id})
cur.execute(query) cur.execute(query)
row = cur.fetchone() row = cur.fetchone()
return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat") return RefreshLoginJWTs(**row)
def authenticate(email, password, for_change_password=False) -> dict | bool | None: def authenticate(email, password, for_change_password=False) -> dict | bool | None:
@ -869,13 +882,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict: def refresh(user_id: int, tenant_id: int = -1) -> dict:
jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id) j = refresh_jwt_iat_jti(user_id=user_id)
return { return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat, "jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat,
aud=AUDIENCE), aud=AUDIENCE),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat, "refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=jwt_r_jti), aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat) "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat),
} }

View file

@ -1,3 +1,4 @@
from decouple import config
from fastapi import Depends from fastapi import Depends
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
@ -8,7 +9,10 @@ from routers.base import get_routers
public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"]) public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"])
COOKIE_PATH = "/api/spot/refresh" if config("LOCAL_DEV", cast=bool, default=False):
COOKIE_PATH = "/spot/refresh"
else:
COOKIE_PATH = "/api/spot/refresh"
@app.get('/logout') @app.get('/logout')