From 985ce2812cc37af45fe4c6433699c088738674be Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 18 Apr 2025 16:24:00 +0200 Subject: [PATCH] feat(assist-api): added correct handlers for 403 and 404 on getByID --- ee/backend/cmd/assist-api/main.go | 4 ++-- ee/backend/pkg/assist/api/handlers.go | 18 ++++++++++++++++-- ee/backend/pkg/assist/builder.go | 19 ++++++------------- ee/backend/pkg/sessionmanager/manager.go | 7 +++++-- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/ee/backend/cmd/assist-api/main.go b/ee/backend/cmd/assist-api/main.go index 77ce142f2..a73b249d7 100644 --- a/ee/backend/cmd/assist-api/main.go +++ b/ee/backend/cmd/assist-api/main.go @@ -41,7 +41,7 @@ func main() { defer redisClient.Close() prefix := api.NoPrefix - builder, err := assist.NewServiceBuilder(log, cfg, webMetrics, dbMetric, pgConn, redisClient, prefix) + builder, err := assist.NewServiceBuilder(log, cfg, webMetrics, dbMetric, pgConn, redisClient) if err != nil { log.Fatal(ctx, "can't init services: %s", err) } @@ -51,7 +51,7 @@ func main() { log.Fatal(ctx, "failed while creating router: %s", err) } router.AddHandlers(prefix, builder.AssistAPI) - router.AddMiddlewares(builder.RateLimiter.Middleware, builder.AuditTrail.Middleware) + router.AddMiddlewares(builder.RateLimiter.Middleware) server.Run(ctx, log, &cfg.HTTP, router) } diff --git a/ee/backend/pkg/assist/api/handlers.go b/ee/backend/pkg/assist/api/handlers.go index 60c6975dc..c26a2f541 100644 --- a/ee/backend/pkg/assist/api/handlers.go +++ b/ee/backend/pkg/assist/api/handlers.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "fmt" "net/http" "time" @@ -13,6 +14,7 @@ import ( "openreplay/backend/pkg/assist/service" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessionmanager" ) type handlersImpl struct { @@ -124,7 +126,13 @@ func (e *handlersImpl) socketsListByProject(w http.ResponseWriter, r *http.Reque resp, err := e.assist.GetByID(projectKey, sessionID) if err != nil { - e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + if errors.Is(err, sessionmanager.ErrSessionNotFound) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) + } else if errors.Is(err, sessionmanager.ErrSessionNotBelongToProject) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, err, startTime, r.URL.Path, bodySize) + } else { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + } return } response := map[string]interface{}{ @@ -183,7 +191,13 @@ func (e *handlersImpl) socketsLiveBySession(w http.ResponseWriter, r *http.Reque resp, err := e.assist.GetByID(projectKey, sessionID) if err != nil { - e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + if errors.Is(err, sessionmanager.ErrSessionNotFound) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) + } else if errors.Is(err, sessionmanager.ErrSessionNotBelongToProject) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, err, startTime, r.URL.Path, bodySize) + } else { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + } return } response := map[string]interface{}{ diff --git a/ee/backend/pkg/assist/builder.go b/ee/backend/pkg/assist/builder.go index 83918a09e..154d5adb0 100644 --- a/ee/backend/pkg/assist/builder.go +++ b/ee/backend/pkg/assist/builder.go @@ -1,49 +1,42 @@ package assist import ( - "openreplay/backend/pkg/db/redis" - "openreplay/backend/pkg/projects" - "openreplay/backend/pkg/sessionmanager" "time" "openreplay/backend/internal/config/assist" assistAPI "openreplay/backend/pkg/assist/api" "openreplay/backend/pkg/assist/service" "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/db/redis" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/metrics/database" "openreplay/backend/pkg/metrics/web" + "openreplay/backend/pkg/projects" "openreplay/backend/pkg/server/api" "openreplay/backend/pkg/server/limiter" - "openreplay/backend/pkg/server/tracer" + "openreplay/backend/pkg/sessionmanager" ) type ServicesBuilder struct { RateLimiter *limiter.UserRateLimiter - AuditTrail tracer.Tracer AssistAPI api.Handlers } -func NewServiceBuilder(log logger.Logger, cfg *assist.Config, webMetrics web.Web, dbMetrics database.Database, pgconn pool.Pool, redis *redis.Client, prefix string) (*ServicesBuilder, error) { +func NewServiceBuilder(log logger.Logger, cfg *assist.Config, webMetrics web.Web, dbMetrics database.Database, pgconn pool.Pool, redis *redis.Client) (*ServicesBuilder, error) { projectsManager := projects.New(log, pgconn, redis, dbMetrics) sessManager, err := sessionmanager.New(log, cfg, redis.Redis) if err != nil { return nil, err } sessManager.Start() - assist := service.NewAssist(log, pgconn, projectsManager, sessManager) - auditrail, err := tracer.NewTracer(log, pgconn, dbMetrics) - if err != nil { - return nil, err - } + assistManager := service.NewAssist(log, pgconn, projectsManager, sessManager) responser := api.NewResponser(webMetrics) - handlers, err := assistAPI.NewHandlers(log, cfg, responser, assist) + handlers, err := assistAPI.NewHandlers(log, cfg, responser, assistManager) if err != nil { return nil, err } return &ServicesBuilder{ RateLimiter: limiter.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), - AuditTrail: auditrail, AssistAPI: handlers, }, nil } diff --git a/ee/backend/pkg/sessionmanager/manager.go b/ee/backend/pkg/sessionmanager/manager.go index d098dd871..477c23034 100644 --- a/ee/backend/pkg/sessionmanager/manager.go +++ b/ee/backend/pkg/sessionmanager/manager.go @@ -368,6 +368,9 @@ func (sm *sessionManagerImpl) updateSessions() { sm.log.Debug(sm.ctx, "Session processing cycle completed in %v. Processed %d sessions", duration, len(sm.cache)) } +var ErrSessionNotFound = errors.New("session not found") +var ErrSessionNotBelongToProject = errors.New("session does not belong to the project") + func (sm *sessionManagerImpl) GetByID(projectID, sessionID string) (interface{}, error) { if sessionID == "" { return nil, fmt.Errorf("session ID is required") @@ -378,10 +381,10 @@ func (sm *sessionManagerImpl) GetByID(projectID, sessionID string) (interface{}, sessionData, exists := sm.cache[sessionID] if !exists { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } if sessionData.ProjectID != projectID { - return nil, fmt.Errorf("session does not belong to the project") + return nil, ErrSessionNotBelongToProject } return sessionData.Raw, nil }