diff --git a/api/chalicelib/core/users.py b/api/chalicelib/core/users.py index 393c201f0..2b37375b6 100644 --- a/api/chalicelib/core/users.py +++ b/api/chalicelib/core/users.py @@ -644,16 +644,17 @@ def authenticate(email, password, for_change_password=False, include_spot=False) **r } if include_spot: - response = {**response, - "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], - iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE), - "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], - tenant_id=r['tenantId'], - iat=j_r.spot_jwt_refresh_iat, - aud=spot.AUDIENCE, - jwt_jti=j_r.spot_jwt_refresh_jti), - "spotRefreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), - } + response = { + **response, + "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], + iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE), + "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], + tenant_id=r['tenantId'], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti), + "spotRefreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) + } return response return None diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 28d27af71..31b8959d9 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -860,7 +860,7 @@ def refresh(user_id: int, tenant_id: int) -> dict: } -def authenticate_sso(email, internal_id, exp=None): +def authenticate_sso(email: str, internal_id: str, exp=None, include_spot: bool = False): with pg_client.PostgresClient() as cur: query = cur.mogrify( f"""SELECT @@ -886,15 +886,28 @@ def authenticate_sso(email, internal_id, exp=None): if r["serviceAccount"]: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="service account is not authorized to login") - jwt_iat, jwt_r_jti, jwt_r_iat = change_jwt_iat_jti(user_id=r['userId']) - return { - "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], iat=jwt_iat, + j_r = change_jwt_iat_jti(user_id=r['userId'], include_spot=include_spot) + response = { + "jwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], iat=j_r.jwt_iat, aud=AUDIENCE), "refreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], tenant_id=r['tenantId'], - iat=jwt_r_iat, - aud=AUDIENCE, jwt_jti=jwt_r_jti), - "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int), + iat=j_r.jwt_refresh_iat, + aud=AUDIENCE, jwt_jti=j_r.jwt_refresh_jti), + "refreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) } + if include_spot: + response = { + **response, + "spotJwt": authorizers.generate_jwt(user_id=r['userId'], tenant_id=r['tenantId'], + iat=j_r.spot_jwt_iat, aud=spot.AUDIENCE), + "spotRefreshToken": authorizers.generate_jwt_refresh(user_id=r['userId'], + tenant_id=r['tenantId'], + iat=j_r.spot_jwt_refresh_iat, + aud=spot.AUDIENCE, + jwt_jti=j_r.spot_jwt_refresh_jti), + "spotRefreshTokenMaxAge": config("JWT_REFRESH_EXPIRATION", cast=int) + } + return response logger.warning(f"SSO user not found with email: {email} and internal_id: {internal_id}") return None diff --git a/ee/api/chalicelib/utils/SAML2_helper.py b/ee/api/chalicelib/utils/SAML2_helper.py index 2878f6df1..4857dda6f 100644 --- a/ee/api/chalicelib/utils/SAML2_helper.py +++ b/ee/api/chalicelib/utils/SAML2_helper.py @@ -1,4 +1,5 @@ import logging +import urllib from http import cookies from os import environ from urllib.parse import urlparse @@ -137,14 +138,19 @@ def get_saml2_provider(): config("idp_name", default="saml2")) > 0 else None -def get_landing_URL(jwt, redirect_to_link2=False): +def get_landing_URL(query_params: dict = None, redirect_to_link2=False): + if query_params is not None and len(query_params.keys()) > 0: + query_params = "?" + urllib.parse.urlencode(query_params) + else: + query_params = "" + if redirect_to_link2: if len(config("sso_landing_override", default="")) == 0: logging.warning("SSO trying to redirect to custom URL, but sso_landing_override env var is empty") else: - return config("sso_landing_override") + "?jwt=%s" % jwt + return config("sso_landing_override") + query_params - return config("SITE_URL") + config("sso_landing", default="/login?jwt=%s") % jwt + return config("SITE_URL") + config("sso_landing", default="/login") + query_params environ["hastSAML2"] = str(is_saml2_available()) diff --git a/ee/api/routers/saml.py b/ee/api/routers/saml.py index 7a0f9caf4..083386ab2 100644 --- a/ee/api/routers/saml.py +++ b/ee/api/routers/saml.py @@ -20,11 +20,11 @@ from starlette.responses import RedirectResponse @public_app.get("/sso/saml2", tags=["saml2"]) @public_app.get("/sso/saml2/", tags=["saml2"]) -async def start_sso(request: Request, iFrame: bool = False): +async def start_sso(request: Request, iFrame: bool = False, spot: bool = False): request.path = '' req = await prepare_request(request=request) auth = init_saml_auth(req) - sso_built_url = auth.login(return_to=json.dumps({'iFrame': iFrame})) + sso_built_url = auth.login(return_to=json.dumps({'iFrame': iFrame, 'spot': spot})) return RedirectResponse(url=sso_built_url) @@ -47,6 +47,7 @@ async def process_sso_assertion(request: Request): post_data = {} redirect_to_link2 = None + spot = False relay_state = post_data.get('RelayState') if relay_state: if isinstance(relay_state, str): @@ -57,6 +58,7 @@ async def process_sso_assertion(request: Request): logger.error(relay_state) relay_state = {} redirect_to_link2 = relay_state.get("iFrame") + spot = relay_state.get("spot") request_id = None if 'AuthNRequestID' in session: @@ -127,18 +129,25 @@ async def process_sso_assertion(request: Request): users.update(tenant_id=t['tenantId'], user_id=existing["userId"], changes={"origin": SAML2_helper.get_saml2_provider(), "internal_id": internal_id}) expiration = auth.get_session_expiration() + print(">>>>>>>>") + print(expiration) expiration = expiration if expiration is not None and expiration > 10 * 60 \ else int(config("sso_exp_delta_seconds", cast=int, default=24 * 60 * 60)) - jwt = users.authenticate_sso(email=email, internal_id=internal_id, exp=expiration) + jwt = users.authenticate_sso(email=email, internal_id=internal_id, exp=expiration, include_spot=spot) if jwt is None: return {"errors": ["null JWT"]} - refresh_token = jwt["refreshToken"] - refresh_token_max_age = jwt["refreshTokenMaxAge"] - response = Response( - status_code=status.HTTP_302_FOUND, - headers={'Location': SAML2_helper.get_landing_URL(jwt["jwt"], redirect_to_link2=redirect_to_link2)}) - response.set_cookie(key="refreshToken", value=refresh_token, path="/api/refresh", - max_age=refresh_token_max_age, secure=True, httponly=True) + response = Response(status_code=status.HTTP_302_FOUND) + response.set_cookie(key="refreshToken", value=jwt["refreshToken"], path="/api/refresh", + max_age=jwt["refreshTokenMaxAge"], secure=True, httponly=True) + if spot: + response.set_cookie(key="spotRefreshToken", value=jwt["spotRefreshToken"], path="/api/spot/refresh", + max_age=jwt["spotRefreshTokenMaxAge"], secure=True, httponly=True) + headers = {'Location': SAML2_helper.get_landing_URL({"jwt": jwt["jwt"], "spotJwt": jwt["spotJwt"]}, + redirect_to_link2=redirect_to_link2)} + else: + headers = {'Location': SAML2_helper.get_landing_URL({"jwt": jwt["jwt"]}, redirect_to_link2=redirect_to_link2)} + + response.init_headers(headers) return response @@ -161,6 +170,7 @@ async def process_sso_assertion_tk(tenantKey: str, request: Request): post_data = {} redirect_to_link2 = None + spot = False relay_state = post_data.get('RelayState') if relay_state: if isinstance(relay_state, str): @@ -171,6 +181,7 @@ async def process_sso_assertion_tk(tenantKey: str, request: Request): logger.error(relay_state) relay_state = {} redirect_to_link2 = relay_state.get("iFrame") + spot = relay_state.get("spot") request_id = None if 'AuthNRequestID' in session: @@ -239,16 +250,21 @@ async def process_sso_assertion_tk(tenantKey: str, request: Request): expiration = auth.get_session_expiration() expiration = expiration if expiration is not None and expiration > 10 * 60 \ else int(config("sso_exp_delta_seconds", cast=int, default=24 * 60 * 60)) - jwt = users.authenticate_sso(email=email, internal_id=internal_id, exp=expiration) + jwt = users.authenticate_sso(email=email, internal_id=internal_id, exp=expiration, include_spot=spot) if jwt is None: return {"errors": ["null JWT"]} - refresh_token = jwt["refreshToken"] - refresh_token_max_age = jwt["refreshTokenMaxAge"] - response = Response( - status_code=status.HTTP_302_FOUND, - headers={'Location': SAML2_helper.get_landing_URL(jwt["jwt"], redirect_to_link2=redirect_to_link2)}) - response.set_cookie(key="refreshToken", value=refresh_token, path="/api/refresh", - max_age=refresh_token_max_age, secure=True, httponly=True) + response = Response(status_code=status.HTTP_302_FOUND) + response.set_cookie(key="refreshToken", value=jwt["refreshToken"], path="/api/refresh", + max_age=jwt["refreshTokenMaxAge"], secure=True, httponly=True) + if spot: + response.set_cookie(key="spotRefreshToken", value=jwt["spotRefreshToken"], path="/api/spot/refresh", + max_age=jwt["spotRefreshTokenMaxAge"], secure=True, httponly=True) + headers = {'Location': SAML2_helper.get_landing_URL({"jwt": jwt["jwt"], "spotJwt": jwt["spotJwt"]}, + redirect_to_link2=redirect_to_link2)} + else: + headers = {'Location': SAML2_helper.get_landing_URL({"jwt": jwt["jwt"]}, redirect_to_link2=redirect_to_link2)} + + response.init_headers(headers) return response