From fcc0195528acea535deec8df69f6e5a398d890d7 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 24 Apr 2024 16:45:52 +0200 Subject: [PATCH] feat(backend): added projects filter to connector logic (#2130) --- backend/internal/config/connector/config.go | 22 ++++++++++++++++++ ee/backend/pkg/connector/saver.go | 25 +++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/backend/internal/config/connector/config.go b/backend/internal/config/connector/config.go index 0015bb140..491413035 100644 --- a/backend/internal/config/connector/config.go +++ b/backend/internal/config/connector/config.go @@ -5,6 +5,8 @@ import ( "openreplay/backend/internal/config/configurator" "openreplay/backend/internal/config/objectstorage" "openreplay/backend/internal/config/redis" + "strconv" + "strings" "time" ) @@ -24,6 +26,7 @@ type Config struct { TopicAnalytics string `env:"TOPIC_ANALYTICS,required"` CommitBatchTimeout time.Duration `env:"COMMIT_BATCH_TIMEOUT,default=5s"` UseProfiler bool `env:"PROFILER_ENABLED,default=false"` + ProjectIDs string `env:"PROJECT_IDS"` } func New() *Config { @@ -31,3 +34,22 @@ func New() *Config { configurator.Process(cfg) return cfg } + +func (c *Config) GetAllowedProjectIDs() []int { + stringIDs := strings.Split(c.ProjectIDs, ",") + if len(stringIDs) == 0 { + return nil + } + ids := make([]int, 0, len(stringIDs)) + for _, id := range stringIDs { + intID, err := strconv.Atoi(id) + if err != nil { + continue + } + ids = append(ids, intID) + } + if len(ids) == 0 { + return nil + } + return ids +} diff --git a/ee/backend/pkg/connector/saver.go b/ee/backend/pkg/connector/saver.go index 315d89e0e..a202c8191 100644 --- a/ee/backend/pkg/connector/saver.go +++ b/ee/backend/pkg/connector/saver.go @@ -26,6 +26,7 @@ type Saver struct { lastUpdate map[uint64]time.Time finishedSessions []uint64 events []map[string]string + projectIDs map[uint32]bool } func New(log logger.Logger, cfg *config.Config, db Database, sessions sessions.Sessions, projects projects.Projects) *Saver { @@ -41,6 +42,16 @@ func New(log logger.Logger, cfg *config.Config, db Database, sessions sessions.S if err := validateColumnNames(eventColumns); err != nil { log.Error(ctx, "can't validate events column names: %s", err) } + // Parse project IDs + projectIDs := make(map[uint32]bool, len(cfg.ProjectIDs)) + if len(cfg.GetAllowedProjectIDs()) == 0 { + log.Info(ctx, "empty project IDs white list") + projectIDs = nil + } else { + for _, id := range cfg.GetAllowedProjectIDs() { + projectIDs[uint32(id)] = true + } + } return &Saver{ log: log, cfg: cfg, @@ -49,6 +60,7 @@ func New(log logger.Logger, cfg *config.Config, db Database, sessions sessions.S projModule: projects, updatedSessions: make(map[uint64]bool, 0), lastUpdate: make(map[uint64]time.Time, 0), + projectIDs: projectIDs, } } @@ -410,6 +422,19 @@ func (s *Saver) handleSession(msg messages.Message) { } func (s *Saver) Handle(msg messages.Message) { + if s.projectIDs != nil { + // Check if project ID is allowed + sessInfo, err := s.sessModule.Get(msg.SessionID()) + if err != nil { + s.log.Error(context.Background(), "can't get session info: %s, skipping message", err) + return + } + if !s.projectIDs[sessInfo.ProjectID] { + s.log.Debug(context.Background(), "project ID %d is not allowed, skipping message", sessInfo.ProjectID) + return + } + s.log.Debug(context.Background(), "project ID %d is allowed", sessInfo.ProjectID) + } newEvent := handleEvent(msg) if newEvent != nil { if s.events == nil {