diff --git a/ee/api/chalicelib/core/users.py b/ee/api/chalicelib/core/users.py index 65239537a..fae0a1b90 100644 --- a/ee/api/chalicelib/core/users.py +++ b/ee/api/chalicelib/core/users.py @@ -380,7 +380,9 @@ def get_by_email_only(email): (CASE WHEN users.role = 'admin' THEN TRUE ELSE FALSE END) AS admin, (CASE WHEN users.role = 'member' THEN TRUE ELSE FALSE END) AS member, origin, - basic_authentication.password IS NOT NULL AS has_password + basic_authentication.password IS NOT NULL AS has_password, + role_id, + internal_id FROM public.users LEFT JOIN public.basic_authentication ON users.user_id=basic_authentication.user_id WHERE users.email = %(email)s AND users.deleted_at IS NULL diff --git a/ee/api/chalicelib/utils/SAML2_helper.py b/ee/api/chalicelib/utils/SAML2_helper.py index 2878f6df1..02d84abf4 100644 --- a/ee/api/chalicelib/utils/SAML2_helper.py +++ b/ee/api/chalicelib/utils/SAML2_helper.py @@ -10,7 +10,11 @@ from starlette.datastructures import FormData if config("ENABLE_SSO", cast=bool, default=True): from onelogin.saml2.auth import OneLogin_Saml2_Auth -API_PREFIX = "/api" +if config("LOCAL_DEV", default=False, cast=bool): + API_PREFIX = "" +else: + API_PREFIX = "/api" + SAML2 = { "strict": config("saml_strict", cast=bool, default=True), "debug": config("saml_debug", cast=bool, default=True), diff --git a/ee/api/routers/saml.py b/ee/api/routers/saml.py index 7a0f9caf4..27034cc2e 100644 --- a/ee/api/routers/saml.py +++ b/ee/api/routers/saml.py @@ -1,8 +1,12 @@ import json import logging +from decouple import config from fastapi import HTTPException, Request, Response, status +from onelogin.saml2.auth import OneLogin_Saml2_Logout_Request +from starlette.responses import RedirectResponse +from chalicelib.core import users, tenants, roles from chalicelib.utils import SAML2_helper from chalicelib.utils.SAML2_helper import prepare_request, init_saml_auth from routers.base import get_routers @@ -10,12 +14,6 @@ from routers.base import get_routers logger = logging.getLogger(__name__) public_app, app, app_apikey = get_routers() -from decouple import config - -from onelogin.saml2.auth import OneLogin_Saml2_Logout_Request - -from chalicelib.core import users, tenants, roles -from starlette.responses import RedirectResponse @public_app.get("/sso/saml2", tags=["saml2"]) @@ -90,15 +88,19 @@ async def process_sso_assertion(request: Request): logger.error("invalid tenantKey, please copy the correct value from Preferences > Account") return {"errors": ["invalid tenantKey, please copy the correct value from Preferences > Account"]} logger.debug(user_data) - role_name = user_data.get("role", []) - if len(role_name) == 0: + role_names = user_data.get("role", []) + if len(role_names) == 0: logger.info("No role specified, setting role to member") - role_name = ["member"] - role_name = role_name[0] - role = roles.get_role_by_name(tenant_id=t['tenantId'], name=role_name) - if role is None: - return {"errors": [f"role {role_name} not found, please create it in openreplay first"]} + role_names = ["member"] + role = None + for r in role_names: + role = roles.get_role_by_name(tenant_id=t['tenantId'], name=r) + if role is not None: + break + if role is None: + return {"errors": [f"role '{role_names}' not found, please create it in OpenReplay first"]} + logger.info(f"received roles:{role_names}; using:{role['name']}") admin_privileges = user_data.get("adminPrivileges", []) admin_privileges = not (len(admin_privileges) == 0 or admin_privileges[0] is None @@ -122,10 +124,30 @@ async def process_sso_assertion(request: Request): if t['tenantId'] != existing["tenantId"]: logger.warning("user exists for a different tenant") return {"errors": ["user exists for a different tenant"]} - if existing.get("origin") is None: - logger.info(f"== migrating user to {SAML2_helper.get_saml2_provider()} ==") - users.update(tenant_id=t['tenantId'], user_id=existing["userId"], - changes={"origin": SAML2_helper.get_saml2_provider(), "internal_id": internal_id}) + # Check difference between existing user and received data + received_data = { + "role": "admin" if admin_privileges else "member", + "origin": SAML2_helper.get_saml2_provider(), + "name": " ".join(user_data.get("firstName", []) + user_data.get("lastName", [])), + "internal_id": internal_id, + "role_id": role["roleId"] + } + existing_data = { + "role": "admin" if existing["admin"] else "member", + "origin": existing["origin"], + "name": existing["name"], + "internal_id": existing["internalId"], + "role_id": existing["roleId"] + } + to_update = {} + for k in existing_data.keys(): + if (k != "role" or not existing["superAdmin"]) and existing_data[k] != received_data[k]: + to_update[k] = received_data[k] + + if len(to_update.keys()) > 0: + logger.info(f"== Updating user:{existing['userId']}: {to_update} ==") + users.update(tenant_id=t['tenantId'], user_id=existing["userId"], changes=to_update) + expiration = auth.get_session_expiration() expiration = expiration if expiration is not None and expiration > 10 * 60 \ else int(config("sso_exp_delta_seconds", cast=int, default=24 * 60 * 60)) @@ -200,15 +222,19 @@ async def process_sso_assertion_tk(tenantKey: str, request: Request): logger.error("invalid tenantKey, please copy the correct value from Preferences > Account") return {"errors": ["invalid tenantKey, please copy the correct value from Preferences > Account"]} logger.debug(user_data) - role_name = user_data.get("role", []) - if len(role_name) == 0: + role_names = user_data.get("role", []) + if len(role_names) == 0: logger.info("No role specified, setting role to member") - role_name = ["member"] - role_name = role_name[0] - role = roles.get_role_by_name(tenant_id=t['tenantId'], name=role_name) - if role is None: - return {"errors": [f"role {role_name} not found, please create it in openreplay first"]} + role_names = ["member"] + role = None + for r in role_names: + role = roles.get_role_by_name(tenant_id=t['tenantId'], name=r) + if role is not None: + break + if role is None: + return {"errors": [f"role '{role_names}' not found, please create it in OpenReplay first"]} + logger.info(f"received roles:{role_names}; using:{role['name']}") admin_privileges = user_data.get("adminPrivileges", []) admin_privileges = not (len(admin_privileges) == 0 or admin_privileges[0] is None @@ -232,10 +258,30 @@ async def process_sso_assertion_tk(tenantKey: str, request: Request): if t['tenantId'] != existing["tenantId"]: logger.warning("user exists for a different tenant") return {"errors": ["user exists for a different tenant"]} - if existing.get("origin") is None: - logger.info(f"== migrating user to {SAML2_helper.get_saml2_provider()} ==") - users.update(tenant_id=t['tenantId'], user_id=existing["userId"], - changes={"origin": SAML2_helper.get_saml2_provider(), "internal_id": internal_id}) + # Check difference between existing user and received data + received_data = { + "role": "admin" if admin_privileges else "member", + "origin": SAML2_helper.get_saml2_provider(), + "name": " ".join(user_data.get("firstName", []) + user_data.get("lastName", [])), + "internal_id": internal_id, + "role_id": role["roleId"] + } + existing_data = { + "role": "admin" if existing["admin"] else "member", + "origin": existing["origin"], + "name": existing["name"], + "internal_id": existing["internalId"], + "role_id": existing["roleId"] + } + to_update = {} + for k in existing_data.keys(): + if (k != "role" or not existing["superAdmin"]) and existing_data[k] != received_data[k]: + to_update[k] = received_data[k] + + if len(to_update.keys()) > 0: + logger.info(f"== Updating user:{existing['userId']}: {to_update} ==") + users.update(tenant_id=t['tenantId'], user_id=existing["userId"], changes=to_update) + expiration = auth.get_session_expiration() expiration = expiration if expiration is not None and expiration > 10 * 60 \ else int(config("sso_exp_delta_seconds", cast=int, default=24 * 60 * 60))