From 293619e61e38eb50f6ad2c0890f7445c7a65fb1e Mon Sep 17 00:00:00 2001 From: Taha Yassine Kraiem Date: Fri, 21 Jan 2022 14:25:18 +0100 Subject: [PATCH] feat(api): EE-SSO extra endpoint for response processing using tenantKey --- ee/api/routers/saml.py | 85 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/ee/api/routers/saml.py b/ee/api/routers/saml.py index 50723a1db..73b0e7393 100644 --- a/ee/api/routers/saml.py +++ b/ee/api/routers/saml.py @@ -43,6 +43,8 @@ async def process_sso_assertion(request: Request): user_data = auth.get_attributes() elif auth.get_settings().is_debug_active(): error_reason = auth.get_last_error_reason() + print("SAML2 error:") + print(error_reason) return {"errors": [error_reason]} email = auth.get_nameid() @@ -108,6 +110,88 @@ async def process_sso_assertion(request: Request): headers={'Location': SAML2_helper.get_landing_URL(jwt)}) +@public_app.post('/sso/saml2/acs/{tenantKey}', tags=["saml2"]) +async def process_sso_assertion_tk(tenantKey: str, request: Request): + req = await prepare_request(request=request) + session = req["cookie"]["session"] + auth = init_saml_auth(req) + + request_id = None + if 'AuthNRequestID' in session: + request_id = session['AuthNRequestID'] + + auth.process_response(request_id=request_id) + errors = auth.get_errors() + user_data = {} + if len(errors) == 0: + if 'AuthNRequestID' in session: + del session['AuthNRequestID'] + user_data = auth.get_attributes() + elif auth.get_settings().is_debug_active(): + error_reason = auth.get_last_error_reason() + print("SAML2 error:") + print(error_reason) + return {"errors": [error_reason]} + + email = auth.get_nameid() + print("received nameId:") + print(email) + existing = users.get_by_email_only(auth.get_nameid()) + + internal_id = next(iter(user_data.get("internalId", [])), None) + + t = tenants.get_by_tenant_key(tenantKey) + if t is None: + print("invalid tenantKey, please copy the correct value from Preferences > Account") + return {"errors": ["invalid tenantKey, please copy the correct value from Preferences > Account"]} + print(user_data) + role_name = user_data.get("role", []) + if len(role_name) == 0: + print("No role specified, setting role to member") + role_name = ["member"] + role_name = role_name[0] + role = roles.get_role_by_name(tenant_id=t['tenantId'], name=role_name) + if role is None: + return {"errors": [f"role {role_name} not found, please create it in openreplay first"]} + + admin_privileges = user_data.get("adminPrivileges", []) + admin_privileges = not (len(admin_privileges) == 0 + or admin_privileges[0] is None + or admin_privileges[0].lower() == "false") + + if existing is None: + deleted = users.get_deleted_user_by_email(auth.get_nameid()) + if deleted is not None: + print("== restore deleted user ==") + users.restore_sso_user(user_id=deleted["userId"], tenant_id=t['tenantId'], email=email, + admin=admin_privileges, origin=SAML2_helper.get_saml2_provider(), + name=" ".join(user_data.get("firstName", []) + user_data.get("lastName", [])), + internal_id=internal_id, role_id=role["roleId"]) + else: + print("== new user ==") + users.create_sso_user(tenant_id=t['tenantId'], email=email, admin=admin_privileges, + origin=SAML2_helper.get_saml2_provider(), + name=" ".join(user_data.get("firstName", []) + user_data.get("lastName", [])), + internal_id=internal_id, role_id=role["roleId"]) + else: + if t['tenantId'] != existing["tenantId"]: + print("user exists for a different tenant") + return {"errors": ["user exists for a different tenant"]} + if existing.get("origin") is None: + print(f"== migrating user to {SAML2_helper.get_saml2_provider()} ==") + users.update(tenant_id=t['tenantId'], user_id=existing["id"], + changes={"origin": SAML2_helper.get_saml2_provider(), "internal_id": internal_id}) + expiration = auth.get_session_expiration() + expiration = expiration if expiration is not None and expiration > 10 * 60 \ + else int(config("sso_exp_delta_seconds", cast=int, default=24 * 60 * 60)) + jwt = users.authenticate_sso(email=email, internal_id=internal_id, exp=expiration) + if jwt is None: + return {"errors": ["null JWT"]} + return Response( + status_code=status.HTTP_302_FOUND, + headers={'Location': SAML2_helper.get_landing_URL(jwt)}) + + @public_app.get('/sso/saml2/sls', tags=["saml2"]) async def process_sls_assertion(request: Request): req = await prepare_request(request=request) @@ -143,6 +227,7 @@ async def process_sls_assertion(request: Request): return RedirectResponse(url=config("SITE_URL")) +@public_app.get('/sso/saml2/metadata/', tags=["saml2"]) @public_app.get('/sso/saml2/metadata', tags=["saml2"]) async def saml2_metadata(request: Request): req = await prepare_request(request=request)