diff --git a/ee/api/auth/router_security.py b/ee/api/auth/router_security.py deleted file mode 100644 index 1b0c98980..000000000 --- a/ee/api/auth/router_security.py +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi import HTTPException, Depends -from fastapi.security import SecurityScopes - -import schemas_ee -from or_dependencies import OR_context - - -def check(security_scopes: SecurityScopes, context: schemas_ee.CurrentContext = Depends(OR_context)): - for scope in security_scopes.scopes: - if scope not in context.permissions: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not enough permissions", - ) - \ No newline at end of file diff --git a/ee/api/or_dependencies.py b/ee/api/or_dependencies.py index 4ca35476d..fed974c0d 100644 --- a/ee/api/or_dependencies.py +++ b/ee/api/or_dependencies.py @@ -1,13 +1,15 @@ import json from typing import Callable +from fastapi import HTTPException, Depends +from fastapi import Security from fastapi.routing import APIRoute +from fastapi.security import SecurityScopes from starlette import status from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import Response, JSONResponse -import schemas import schemas_ee from chalicelib.core import traces @@ -44,3 +46,14 @@ class ORRoute(APIRoute): return response return custom_route_handler + + +def check_permissions(security_scopes: SecurityScopes, context: schemas_ee.CurrentContext = Depends(OR_context)): + for scope in security_scopes.scopes: + if scope not in context.permissions: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions") + + +def OR_scope(*scopes): + return Security(check_permissions, scopes=list(scopes)) diff --git a/ee/api/routers/core_dynamic.py b/ee/api/routers/core_dynamic.py index 75806aeca..a2a09c92d 100644 --- a/ee/api/routers/core_dynamic.py +++ b/ee/api/routers/core_dynamic.py @@ -1,25 +1,22 @@ from typing import Optional, Union from decouple import config -from fastapi import Body, Depends, BackgroundTasks, Security, HTTPException -from fastapi.security import SecurityScopes -from starlette import status +from fastapi import Body, Depends, BackgroundTasks from starlette.responses import RedirectResponse import schemas import schemas_ee -from schemas_ee import Permissions -from auth import router_security from chalicelib.core import sessions +from chalicelib.core import sessions_viewed from chalicelib.core import tenants, users, projects, license from chalicelib.core import webhook -from chalicelib.core import sessions_viewed from chalicelib.core.collaboration_slack import Slack from chalicelib.utils import SAML2_helper from chalicelib.utils import helper from chalicelib.utils.TimeUTC import TimeUTC -from or_dependencies import OR_context +from or_dependencies import OR_context, OR_scope from routers.base import get_routers +from schemas_ee import Permissions public_app, app, app_apikey = get_routers() @@ -179,10 +176,8 @@ def get_projects(context: schemas.CurrentContext = Depends(OR_context)): stack_integrations=True, user_id=context.user_id)} -@app.get('/{projectId}/sessions/{sessionId}', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) -@app.get('/{projectId}/sessions2/{sessionId}', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) +@app.get('/{projectId}/sessions/{sessionId}', tags=["sessions"], dependencies=[OR_scope(Permissions.session_replay)]) +@app.get('/{projectId}/sessions2/{sessionId}', tags=["sessions"], dependencies=[OR_scope(Permissions.session_replay)]) def get_session(projectId: int, sessionId: Union[int, str], background_tasks: BackgroundTasks, context: schemas.CurrentContext = Depends(OR_context)): if isinstance(sessionId, str): @@ -200,11 +195,9 @@ def get_session(projectId: int, sessionId: Union[int, str], background_tasks: Ba @app.get('/{projectId}/sessions/{sessionId}/errors/{errorId}/sourcemaps', tags=["sessions", "sourcemaps"], - dependencies=[Security(router_security.check, - scopes=[Permissions.session_replay, Permissions.errors])]) + dependencies=[OR_scope(Permissions.session_replay, Permissions.errors)]) @app.get('/{projectId}/sessions2/{sessionId}/errors/{errorId}/sourcemaps', tags=["sessions", "sourcemaps"], - dependencies=[Security(router_security.check, - scopes=[Permissions.session_replay, Permissions.errors])]) + dependencies=[OR_scope(Permissions.session_replay, Permissions.errors)]) def get_error_trace(projectId: int, sessionId: int, errorId: str, context: schemas.CurrentContext = Depends(OR_context)): data = errors.get_trace(project_id=projectId, error_id=errorId) @@ -215,25 +208,21 @@ def get_error_trace(projectId: int, sessionId: int, errorId: str, } -@app.post('/{projectId}/errors/search', tags=['errors'], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.post('/{projectId}/errors/search', tags=['errors'], dependencies=[OR_scope(Permissions.errors)]) def errors_search(projectId: int, data: schemas.SearchErrorsSchema = Body(...), context: schemas.CurrentContext = Depends(OR_context)): return {"data": errors.search(data, projectId, user_id=context.user_id)} -@app.get('/{projectId}/errors/stats', tags=['errors'], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.get('/{projectId}/errors/stats', tags=['errors'], dependencies=[OR_scope(Permissions.errors)]) def errors_stats(projectId: int, startTimestamp: int, endTimestamp: int, context: schemas.CurrentContext = Depends(OR_context)): return errors.stats(projectId, user_id=context.user_id, startTimestamp=startTimestamp, endTimestamp=endTimestamp) -@app.get('/{projectId}/errors/{errorId}', tags=['errors'], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.get('/{projectId}/errors/{errorId}', tags=['errors'], dependencies=[OR_scope(Permissions.errors)]) def errors_get_details(projectId: int, errorId: str, background_tasks: BackgroundTasks, density24: int = 24, - density30: int = 30, - context: schemas.CurrentContext = Depends(OR_context)): + density30: int = 30, context: schemas.CurrentContext = Depends(OR_context)): data = errors.get_details(project_id=projectId, user_id=context.user_id, error_id=errorId, **{"density24": density24, "density30": density30}) if data.get("data") is not None: @@ -242,8 +231,7 @@ def errors_get_details(projectId: int, errorId: str, background_tasks: Backgroun return data -@app.get('/{projectId}/errors/{errorId}/stats', tags=['errors'], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.get('/{projectId}/errors/{errorId}/stats', tags=['errors'], dependencies=[OR_scope(Permissions.errors)]) def errors_get_details_right_column(projectId: int, errorId: str, startDate: int = TimeUTC.now(-7), endDate: int = TimeUTC.now(), density: int = 7, context: schemas.CurrentContext = Depends(OR_context)): @@ -252,8 +240,7 @@ def errors_get_details_right_column(projectId: int, errorId: str, startDate: int return data -@app.get('/{projectId}/errors/{errorId}/sourcemaps', tags=['errors'], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.get('/{projectId}/errors/{errorId}/sourcemaps', tags=['errors'], dependencies=[OR_scope(Permissions.errors)]) def errors_get_details_sourcemaps(projectId: int, errorId: str, context: schemas.CurrentContext = Depends(OR_context)): data = errors.get_trace(project_id=projectId, error_id=errorId) @@ -264,8 +251,7 @@ def errors_get_details_sourcemaps(projectId: int, errorId: str, } -@app.get('/{projectId}/errors/{errorId}/{action}', tags=["errors"], - dependencies=[Security(router_security.check, scopes=[Permissions.errors])]) +@app.get('/{projectId}/errors/{errorId}/{action}', tags=["errors"], dependencies=[OR_scope(Permissions.errors)]) def add_remove_favorite_error(projectId: int, errorId: str, action: str, startDate: int = TimeUTC.now(-7), endDate: int = TimeUTC.now(), context: schemas.CurrentContext = Depends(OR_context)): if action == "favorite": @@ -282,8 +268,7 @@ def add_remove_favorite_error(projectId: int, errorId: str, action: str, startDa return {"errors": ["undefined action"]} -@app.get('/{projectId}/assist/sessions/{sessionId}', tags=["assist"], - dependencies=[Security(router_security.check, scopes=[Permissions.assist_live])]) +@app.get('/{projectId}/assist/sessions/{sessionId}', tags=["assist"], dependencies=[OR_scope(Permissions.assist_live)]) def get_live_session(projectId: int, sessionId: str, background_tasks: BackgroundTasks, context: schemas.CurrentContext = Depends(OR_context)): data = assist.get_live_session_by_id(project_id=projectId, session_id=sessionId) @@ -298,11 +283,10 @@ def get_live_session(projectId: int, sessionId: str, background_tasks: Backgroun return {'data': data} -@app.get('/{projectId}/unprocessed/{sessionId}', tags=["assist"], dependencies=[Security(router_security.check, scopes=[ - Permissions.assist_live, Permissions.session_replay])]) -@app.get('/{projectId}/assist/sessions/{sessionId}/replay', tags=["assist"], dependencies=[ - Security(router_security.check, - scopes=[Permissions.assist_live, Permissions.session_replay])]) +@app.get('/{projectId}/unprocessed/{sessionId}', tags=["assist"], + dependencies=[OR_scope(Permissions.assist_live, Permissions.session_replay)]) +@app.get('/{projectId}/assist/sessions/{sessionId}/replay', tags=["assist"], + dependencies=[OR_scope(Permissions.assist_live, Permissions.session_replay)]) def get_live_session_replay_file(projectId: int, sessionId: Union[int, str], context: schemas.CurrentContext = Depends(OR_context)): if isinstance(sessionId, str) or not sessions.session_exists(project_id=projectId, session_id=sessionId): @@ -319,28 +303,26 @@ def get_live_session_replay_file(projectId: int, sessionId: Union[int, str], return FileResponse(path=path, media_type="application/octet-stream") -@app.post('/{projectId}/heatmaps/url', tags=["heatmaps"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) +@app.post('/{projectId}/heatmaps/url', tags=["heatmaps"], dependencies=[OR_scope(Permissions.session_replay)]) def get_heatmaps_by_url(projectId: int, data: schemas.GetHeatmapPayloadSchema = Body(...), context: schemas.CurrentContext = Depends(OR_context)): return {"data": heatmaps.get_by_url(project_id=projectId, data=data.dict())} @app.get('/{projectId}/sessions/{sessionId}/favorite', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.get('/{projectId}/sessions2/{sessionId}/favorite', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) -def add_remove_favorite_session2(projectId: int, sessionId: int, - context: schemas.CurrentContext = Depends(OR_context)): + dependencies=[OR_scope(Permissions.session_replay)]) +def add_remove_favorite_session2(projectId: int, sessionId: int, context: schemas.CurrentContext = Depends(OR_context)): return { "data": sessions_favorite.favorite_session(project_id=projectId, user_id=context.user_id, session_id=sessionId)} @app.get('/{projectId}/sessions/{sessionId}/assign', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.get('/{projectId}/sessions2/{sessionId}/assign', tags=["sessions"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) def assign_session(projectId: int, sessionId, context: schemas.CurrentContext = Depends(OR_context)): data = sessions_assignments.get_by_session(project_id=projectId, session_id=sessionId, tenant_id=context.tenant_id, @@ -353,9 +335,9 @@ def assign_session(projectId: int, sessionId, context: schemas.CurrentContext = @app.get('/{projectId}/sessions/{sessionId}/assign/{issueId}', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.get('/{projectId}/sessions2/{sessionId}/assign/{issueId}', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) def assign_session(projectId: int, sessionId: int, issueId: str, context: schemas.CurrentContext = Depends(OR_context)): data = sessions_assignments.get(project_id=projectId, session_id=sessionId, assignment_id=issueId, @@ -368,13 +350,13 @@ def assign_session(projectId: int, sessionId: int, issueId: str, @app.post('/{projectId}/sessions/{sessionId}/assign/{issueId}/comment', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.put('/{projectId}/sessions/{sessionId}/assign/{issueId}/comment', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.post('/{projectId}/sessions2/{sessionId}/assign/{issueId}/comment', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) @app.put('/{projectId}/sessions2/{sessionId}/assign/{issueId}/comment', tags=["sessions", "issueTracking"], - dependencies=[Security(router_security.check, scopes=[Permissions.session_replay])]) + dependencies=[OR_scope(Permissions.session_replay)]) def comment_assignment(projectId: int, sessionId: int, issueId: str, data: schemas.CommentAssignmentSchema = Body(...), context: schemas.CurrentContext = Depends(OR_context)): data = sessions_assignments.comment(tenant_id=context.tenant_id, project_id=projectId,