fix(chalice): fix multi-refresh token

fix(chalice): fix spot multi-refresh token
This commit is contained in:
Taha Yassine Kraiem 2025-03-10 13:14:10 +01:00 committed by Kraiem Taha Yassine
parent c40e32d624
commit af4a344c85
6 changed files with 60 additions and 30 deletions

View file

@ -28,9 +28,6 @@ def jwt_authorizer(scheme: str, token: str, leeway=0) -> dict | None:
if scheme.lower() != "bearer":
return None
try:
logger.warning("Checking JWT token: %s", token)
logger.warning("Against: %s", config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"))
logger.warning(get_supported_audience())
payload = jwt.decode(jwt=token,
key=config("JWT_SECRET") if not is_spot_token(token) else config("JWT_SPOT_SECRET"),
algorithms=config("JWT_ALGORITHM"),

View file

@ -18,7 +18,7 @@ def refresh_spot_jwt_iat_jti(user_id):
{"user_id": user_id})
cur.execute(query)
row = cur.fetchone()
return row.get("spot_jwt_iat"), row.get("spot_jwt_refresh_jti"), row.get("spot_jwt_refresh_iat")
return users.RefreshSpotJWTs(**row)
def logout(user_id: int):
@ -26,13 +26,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict:
spot_jwt_iat, spot_jwt_r_jti, spot_jwt_r_iat = refresh_spot_jwt_iat_jti(user_id=user_id)
j = refresh_spot_jwt_iat_jti(user_id=user_id)
return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_iat,
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_iat,
aud=AUDIENCE, for_spot=True),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=spot_jwt_r_iat,
aud=AUDIENCE, jwt_jti=spot_jwt_r_jti, for_spot=True),
"refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (spot_jwt_iat - spot_jwt_r_iat)
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.spot_jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=j.spot_jwt_refresh_jti, for_spot=True),
"refreshTokenMaxAge": config("JWT_SPOT_REFRESH_EXPIRATION", cast=int) - (j.spot_jwt_iat - j.spot_jwt_refresh_iat)
}

View file

@ -1,5 +1,6 @@
import json
import secrets
from typing import Optional
from decouple import config
from fastapi import BackgroundTasks
@ -83,7 +84,6 @@ def restore_member(user_id, email, invitation_token, admin, name, owner=False):
"name": name, "invitation_token": invitation_token})
cur.execute(query)
result = cur.fetchone()
cur.execute(query)
result["created_at"] = TimeUTC.datetime_to_timestamp(result["created_at"])
return helper.dict_to_camel_case(result)
@ -552,7 +552,7 @@ def refresh_auth_exists(user_id, jwt_jti=None):
return r is not None
class ChangeJwt(BaseModel):
class FullLoginJWTs(BaseModel):
jwt_iat: int
jwt_refresh_jti: str
jwt_refresh_iat: int
@ -565,11 +565,23 @@ class ChangeJwt(BaseModel):
def _transform_data(cls, values):
if values.get("jwt_refresh_jti") is not None:
values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"])
if values.get("jwt_refresh_jti") is not None:
if values.get("spot_jwt_refresh_jti") is not None:
values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"])
return values
class RefreshLoginJWTs(FullLoginJWTs):
spot_jwt_iat: Optional[int] = None
spot_jwt_refresh_jti: Optional[str] = None
spot_jwt_refresh_iat: Optional[int] = None
class RefreshSpotJWTs(FullLoginJWTs):
jwt_iat: Optional[int] = None
jwt_refresh_jti: Optional[str] = None
jwt_refresh_iat: Optional[int] = None
def change_jwt_iat_jti(user_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""UPDATE public.users
@ -589,7 +601,7 @@ def change_jwt_iat_jti(user_id):
{"user_id": user_id})
cur.execute(query)
row = cur.fetchone()
return ChangeJwt(**row)
return FullLoginJWTs(**row)
def refresh_jwt_iat_jti(user_id):
@ -604,7 +616,7 @@ def refresh_jwt_iat_jti(user_id):
{"user_id": user_id})
cur.execute(query)
row = cur.fetchone()
return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat")
return RefreshLoginJWTs(**row)
def authenticate(email, password, for_change_password=False) -> dict | bool | None:
@ -672,13 +684,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict:
jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id)
j = refresh_jwt_iat_jti(user_id=user_id)
return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat,
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat,
aud=AUDIENCE),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat,
aud=AUDIENCE, jwt_jti=jwt_r_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat)
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat),
}

View file

@ -1,3 +1,4 @@
from decouple import config
from fastapi import Depends
from starlette.responses import JSONResponse, Response
@ -8,7 +9,10 @@ from routers.base import get_routers
public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"])
COOKIE_PATH = "/api/spot/refresh"
if config("LOCAL_DEV", cast=bool, default=False):
COOKIE_PATH = "/spot/refresh"
else:
COOKIE_PATH = "/api/spot/refresh"
@app.get('/logout')

View file

@ -1,6 +1,7 @@
import json
import logging
import secrets
from typing import Optional
from decouple import config
from fastapi import BackgroundTasks, HTTPException
@ -657,7 +658,7 @@ def refresh_auth_exists(user_id, tenant_id, jwt_jti=None):
return r is not None
class ChangeJwt(BaseModel):
class FullLoginJWTs(BaseModel):
jwt_iat: int
jwt_refresh_jti: str
jwt_refresh_iat: int
@ -670,11 +671,23 @@ class ChangeJwt(BaseModel):
def _transform_data(cls, values):
if values.get("jwt_refresh_jti") is not None:
values["jwt_refresh_jti"] = str(values["jwt_refresh_jti"])
if values.get("jwt_refresh_jti") is not None:
if values.get("spot_jwt_refresh_jti") is not None:
values["spot_jwt_refresh_jti"] = str(values["spot_jwt_refresh_jti"])
return values
class RefreshLoginJWTs(FullLoginJWTs):
spot_jwt_iat: Optional[int] = None
spot_jwt_refresh_jti: Optional[str] = None
spot_jwt_refresh_iat: Optional[int] = None
class RefreshSpotJWTs(FullLoginJWTs):
jwt_iat: Optional[int] = None
jwt_refresh_jti: Optional[str] = None
jwt_refresh_iat: Optional[int] = None
def change_jwt_iat_jti(user_id):
with pg_client.PostgresClient() as cur:
query = cur.mogrify(f"""UPDATE public.users
@ -694,7 +707,7 @@ def change_jwt_iat_jti(user_id):
{"user_id": user_id})
cur.execute(query)
row = cur.fetchone()
return ChangeJwt(**row)
return FullLoginJWTs(**row)
def refresh_jwt_iat_jti(user_id):
@ -709,7 +722,7 @@ def refresh_jwt_iat_jti(user_id):
{"user_id": user_id})
cur.execute(query)
row = cur.fetchone()
return row.get("jwt_iat"), row.get("jwt_refresh_jti"), row.get("jwt_refresh_iat")
return RefreshLoginJWTs(**row)
def authenticate(email, password, for_change_password=False) -> dict | bool | None:
@ -869,13 +882,13 @@ def logout(user_id: int):
def refresh(user_id: int, tenant_id: int = -1) -> dict:
jwt_iat, jwt_r_jti, jwt_r_iat = refresh_jwt_iat_jti(user_id=user_id)
j = refresh_jwt_iat_jti(user_id=user_id)
return {
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=jwt_iat,
"jwt": authorizers.generate_jwt(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_iat,
aud=AUDIENCE),
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=jwt_r_iat,
aud=AUDIENCE, jwt_jti=jwt_r_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (jwt_iat - jwt_r_iat)
"refreshToken": authorizers.generate_jwt_refresh(user_id=user_id, tenant_id=tenant_id, iat=j.jwt_refresh_iat,
aud=AUDIENCE, jwt_jti=j.jwt_refresh_jti),
"refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) - (j.jwt_iat - j.jwt_refresh_iat),
}

View file

@ -1,3 +1,4 @@
from decouple import config
from fastapi import Depends
from starlette.responses import JSONResponse, Response
@ -8,7 +9,10 @@ from routers.base import get_routers
public_app, app, app_apikey = get_routers(prefix="/spot", tags=["spot"])
COOKIE_PATH = "/api/spot/refresh"
if config("LOCAL_DEV", cast=bool, default=False):
COOKIE_PATH = "/spot/refresh"
else:
COOKIE_PATH = "/api/spot/refresh"
@app.get('/logout')