openreplay/api/chalicelib/core/authorizers.py
2025-03-10 12:02:03 +01:00

105 lines
3.6 KiB
Python

import logging
import jwt
from decouple import config
from chalicelib.core import tenants
from chalicelib.core import users, spot
from chalicelib.utils.TimeUTC import TimeUTC
# Module-level logger named after this module; used by the authorizers below.
logger = logging.getLogger(__name__)
def get_supported_audience():
    """Return the list of audiences accepted when verifying tokens.

    The list holds the regular users audience followed by the Spot audience.
    A fresh list is built on every call.
    """
    accepted = (users.AUDIENCE, spot.AUDIENCE)
    return list(accepted)
def is_spot_token(token: str) -> bool:
    """Return True when the token's ``aud`` claim equals the Spot audience.

    The token is decoded WITHOUT verifying its signature or expiry. The
    result is only meant to select which secret the token is then verified
    with, so it should not be used as an authorization decision on its own.

    :param token: the encoded JWT.
    :return: ``True`` if ``aud == spot.AUDIENCE``, else ``False``.
    :raises jwt.InvalidTokenError: if the token cannot be decoded at all.
    """
    try:
        decoded_token = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
    except jwt.InvalidTokenError:
        # Never log the token itself: it is a bearer credential.
        logger.error("Invalid token for is_spot_token")
        raise
    # NOTE(review): a list-valued ``aud`` claim never matches here -- confirm
    # that issued tokens always carry a single string audience.
    return decoded_token.get("aud") == spot.AUDIENCE
def jwt_authorizer(scheme: str, token: str, leeway=0) -> dict | None:
    """Validate a bearer access token and return its decoded payload.

    The verification secret depends on the token's (unverified) audience.
    Spot tokens use ``JWT_SPOT_SECRET``; all others use ``JWT_SECRET``.
    ``jwt.decode`` then checks the signature, the expiry (with ``leeway``
    seconds of tolerance) and the audience against
    ``get_supported_audience()``.

    :param scheme: authorization scheme; only ``bearer`` (case-insensitive)
        is accepted.
    :param token: the encoded JWT.
    :param leeway: clock-skew tolerance forwarded to ``jwt.decode``.
    :return: the decoded payload, or ``None`` when the scheme is not bearer,
        the token has expired, or validation fails for any other reason.
    """
    if scheme.lower() != "bearer":
        return None
    try:
        # Pick the secret once (a single unverified decode). Neither the token
        # nor the secret may ever be written to logs: both are credentials.
        secret = config("JWT_SPOT_SECRET") if is_spot_token(token) else config("JWT_SECRET")
        payload = jwt.decode(jwt=token,
                             key=secret,
                             # PyJWT expects a list of allowed algorithms, not a bare string.
                             algorithms=[config("JWT_ALGORITHM")],
                             audience=get_supported_audience(),
                             leeway=leeway)
    except jwt.ExpiredSignatureError:
        logger.debug("! JWT Expired signature")
        return None
    except Exception as e:
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit propagate.
        logger.warning("! JWT Base Exception", exc_info=e)
        return None
    return payload
def jwt_refresh_authorizer(scheme: str, token: str) -> dict | None:
    """Validate a bearer refresh token and return its decoded payload.

    Spot tokens are verified with ``JWT_SPOT_REFRESH_SECRET``; all others use
    ``JWT_REFRESH_SECRET``. The audience must be one of
    ``get_supported_audience()``.

    :param scheme: authorization scheme; only ``bearer`` (case-insensitive)
        is accepted.
    :param token: the encoded refresh JWT.
    :return: the decoded payload, or ``None`` when the scheme is not bearer,
        the token has expired, or validation fails for any other reason.
    """
    if scheme.lower() != "bearer":
        return None
    try:
        secret = config("JWT_SPOT_REFRESH_SECRET") if is_spot_token(token) \
            else config("JWT_REFRESH_SECRET")
        payload = jwt.decode(jwt=token,
                             key=secret,
                             # PyJWT expects a list of allowed algorithms, not a bare string.
                             algorithms=[config("JWT_ALGORITHM")],
                             audience=get_supported_audience())
    except jwt.ExpiredSignatureError:
        logger.debug("! JWT-refresh Expired signature")
        return None
    except Exception as e:
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit propagate.
        logger.error("! JWT-refresh Base Exception", exc_info=e)
        return None
    return payload
def generate_jwt(user_id, tenant_id, iat, aud, for_spot=False):
    """Sign and return an access token for ``user_id`` in ``tenant_id``.

    Spot tokens take their lifetime from ``JWT_SPOT_EXPIRATION`` and are
    signed with ``JWT_SPOT_SECRET``. Other tokens use ``JWT_EXPIRATION`` and
    ``JWT_SECRET``. ``exp`` is computed as ``iat`` plus that lifetime, so both
    must share a unit (presumably epoch seconds -- confirm with callers).
    """
    if for_spot:
        lifetime = config("JWT_SPOT_EXPIRATION", cast=int)
    else:
        lifetime = config("JWT_EXPIRATION", cast=int)
    claims = {
        "userId": user_id,
        "tenantId": tenant_id,
        "exp": iat + lifetime,
        "iss": config("JWT_ISSUER"),
        "iat": iat,
        "aud": aud
    }
    signing_key = config("JWT_SPOT_SECRET") if for_spot else config("JWT_SECRET")
    return jwt.encode(payload=claims,
                      key=signing_key,
                      algorithm=config("JWT_ALGORITHM"))
def generate_jwt_refresh(user_id, tenant_id, iat, aud, jwt_jti, for_spot=False):
    """Sign and return a refresh token carrying ``jwt_jti`` as its ``jti`` claim.

    Spot tokens take their lifetime from ``JWT_SPOT_REFRESH_EXPIRATION`` and
    are signed with ``JWT_SPOT_REFRESH_SECRET``. Other tokens use
    ``JWT_REFRESH_EXPIRATION`` and ``JWT_REFRESH_SECRET``. ``exp`` is
    ``iat`` plus that lifetime.
    """
    if for_spot:
        lifetime = config("JWT_SPOT_REFRESH_EXPIRATION", cast=int)
    else:
        lifetime = config("JWT_REFRESH_EXPIRATION", cast=int)
    claims = {
        "userId": user_id,
        "tenantId": tenant_id,
        "exp": iat + lifetime,
        "iss": config("JWT_ISSUER"),
        "iat": iat,
        "aud": aud,
        "jti": jwt_jti
    }
    signing_key = config("JWT_SPOT_REFRESH_SECRET") if for_spot else config("JWT_REFRESH_SECRET")
    return jwt.encode(payload=claims,
                      key=signing_key,
                      algorithm=config("JWT_ALGORITHM"))
def api_key_authorizer(token):
    """Look up the tenant owning API key ``token``.

    When a record is found, its ``createdAt`` value is converted in place
    with ``TimeUTC.datetime_to_timestamp`` and the record is returned.
    Otherwise ``None`` is returned, which is whatever
    ``tenants.get_by_api_key`` produced.
    """
    tenant = tenants.get_by_api_key(token)
    if tenant is None:
        return None
    tenant["createdAt"] = TimeUTC.datetime_to_timestamp(tenant["createdAt"])
    return tenant