diff --git a/ee/backend/cmd/assist-api/main.go b/ee/backend/cmd/assist-api/main.go index a73b249d7..b7c944e56 100644 --- a/ee/backend/cmd/assist-api/main.go +++ b/ee/backend/cmd/assist-api/main.go @@ -45,6 +45,9 @@ func main() { if err != nil { log.Fatal(ctx, "can't init services: %s", err) } + defer func() { + builder.AssistStats.Stop() + }() router, err := api.NewRouter(&cfg.HTTP, log) if err != nil { diff --git a/ee/backend/pkg/assist/builder.go b/ee/backend/pkg/assist/builder.go index 154d5adb0..c08067651 100644 --- a/ee/backend/pkg/assist/builder.go +++ b/ee/backend/pkg/assist/builder.go @@ -20,6 +20,7 @@ import ( type ServicesBuilder struct { RateLimiter *limiter.UserRateLimiter AssistAPI api.Handlers + AssistStats service.AssistStats } func NewServiceBuilder(log logger.Logger, cfg *assist.Config, webMetrics web.Web, dbMetrics database.Database, pgconn pool.Pool, redis *redis.Client) (*ServicesBuilder, error) { @@ -29,6 +30,10 @@ func NewServiceBuilder(log logger.Logger, cfg *assist.Config, webMetrics web.Web return nil, err } sessManager.Start() + assistStats, err := service.NewAssistStats(log, pgconn, redis.Redis) + if err != nil { + return nil, err + } assistManager := service.NewAssist(log, pgconn, projectsManager, sessManager) responser := api.NewResponser(webMetrics) handlers, err := assistAPI.NewHandlers(log, cfg, responser, assistManager) @@ -38,5 +43,6 @@ func NewServiceBuilder(log logger.Logger, cfg *assist.Config, webMetrics web.Web return &ServicesBuilder{ RateLimiter: limiter.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), AssistAPI: handlers, + AssistStats: assistStats, }, nil } diff --git a/ee/backend/pkg/assist/service/stats.go b/ee/backend/pkg/assist/service/stats.go new file mode 100644 index 000000000..475b27d3d --- /dev/null +++ b/ee/backend/pkg/assist/service/stats.go @@ -0,0 +1,122 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/redis/go-redis/v9" + + "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/logger" +) + +type assistStatsImpl struct { + log logger.Logger + pgClient pool.Pool + redisClient *redis.Client + ticker *time.Ticker + stopChan chan struct{} +} + +type AssistStats interface { + Stop() +} + +func NewAssistStats(log logger.Logger, pgClient pool.Pool, redisClient *redis.Client) (AssistStats, error) { + switch { + case log == nil: + return nil, errors.New("logger is empty") + case pgClient == nil: + return nil, errors.New("pg client is empty") + case redisClient == nil: + return nil, errors.New("redis client is empty") + } + stats := &assistStatsImpl{ + log: log, + pgClient: pgClient, + redisClient: redisClient, + ticker: time.NewTicker(time.Minute), + stopChan: make(chan struct{}), + } + stats.init() + return stats, nil +} + +func (as *assistStatsImpl) init() { + as.log.Debug(context.Background(), "Starting assist stats") + + go func() { + for { + select { + case <-as.ticker.C: + as.loadData() + case <-as.stopChan: + as.log.Debug(context.Background(), "Stopping assist stats") + return + } + } + }() +} + +type AssistStatsEvent struct { + ProjectID uint32 `json:"project_id"` + SessionID string `json:"session_id"` + AgentID string `json:"agent_id"` + EventID string `json:"event_id"` + EventType string `json:"event_type"` + EventState string `json:"event_state"` + Timestamp int64 `json:"timestamp"` +} + +func (as *assistStatsImpl) loadData() { + ctx := context.Background() + + events, err := as.redisClient.LPopCount(ctx, "assist:stats", 1000).Result() + if err != nil { + as.log.Error(ctx, "Failed to load data from redis: ", err) + return + } + if len(events) == 0 { + as.log.Debug(ctx, "No data to load from redis") + return + } + as.log.Debug(ctx, "Loaded %d events from redis", len(events)) + + for _, event := range events { + e := &AssistStatsEvent{} + err := json.Unmarshal([]byte(event), &e) + if err != nil { + as.log.Error(ctx, "Failed to unmarshal event: ", err) + continue + } + switch e.EventType { + case "start": + err = as.insertEvent(e) + case "end": + err = as.updateEvent(e) + default: + as.log.Warn(ctx, "Unknown event type: %s", e.EventType) + } + if err != nil { + as.log.Error(ctx, "Failed to process event: ", err) + continue + } + } +} + +func (as *assistStatsImpl) insertEvent(event *AssistStatsEvent) error { + insertQuery := `INSERT INTO assist_events (event_id, project_id, session_id, agent_id, event_type, timestamp) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (event_id) DO NOTHING` + return as.pgClient.Exec(insertQuery, event.EventID, event.ProjectID, event.SessionID, event.AgentID, event.EventType, event.Timestamp) +} + +func (as *assistStatsImpl) updateEvent(event *AssistStatsEvent) error { + updateQuery := `UPDATE assist_events SET duration = $1 - timestamp WHERE event_id = $2` + return as.pgClient.Exec(updateQuery, event.Timestamp, event.EventID) +} + +func (as *assistStatsImpl) Stop() { + close(as.stopChan) + as.ticker.Stop() +} diff --git a/ee/backend/pkg/sessionmanager/manager.go b/ee/backend/pkg/sessionmanager/manager.go index 477c23034..40ee2a620 100644 --- a/ee/backend/pkg/sessionmanager/manager.go +++ b/ee/backend/pkg/sessionmanager/manager.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/go-redis/redis" + "github.com/redis/go-redis/v9" "openreplay/backend/internal/config/assist" "openreplay/backend/pkg/logger" @@ -119,7 +119,7 @@ func (sm *sessionManagerImpl) getNodeIDs() ([]string, error) { var cursor uint64 = 0 for { - keys, nextCursor, err := sm.client.Scan(cursor, NodeKeyPattern, 100).Result() + keys, nextCursor, err := sm.client.Scan(sm.ctx, cursor, NodeKeyPattern, 100).Result() if err != nil { return nil, fmt.Errorf("scan failed: %v", err) } @@ -144,7 +144,7 @@ func (sm *sessionManagerImpl) getAllNodeSessions(nodeIDs []string) map[string]st go func(id string) { defer wg.Done() - sessionListJSON, err := sm.client.Get(id).Result() + sessionListJSON, err := sm.client.Get(sm.ctx, id).Result() if err != nil { if errors.Is(err, redis.Nil) { return @@ -198,7 +198,7 @@ func (sm *sessionManagerImpl) getSessionData(sessionIDs []string) map[string]*Se keys[j] = ActiveSessionPrefix + id } - results, err := sm.client.MGet(keys...).Result() + results, err := sm.client.MGet(sm.ctx, keys...).Result() if err != nil { sm.log.Debug(sm.ctx, "Error in MGET operation: %v", err) continue // TODO: Handle the error @@ -294,7 +294,7 @@ func (sm *sessionManagerImpl) getAllRecentlyUpdatedSessions() (map[string]struct ) for { - batchIDs, cursor, err = sm.client.SScan(RecentlyUpdatedSessions, cursor, "*", sm.scanSize).Result() + batchIDs, cursor, err = sm.client.SScan(sm.ctx, RecentlyUpdatedSessions, cursor, "*", sm.scanSize).Result() if err != nil { sm.log.Debug(sm.ctx, "Error scanning updated session IDs: %v", err) return nil, err @@ -316,7 +316,7 @@ func (sm *sessionManagerImpl) getAllRecentlyUpdatedSessions() (map[string]struct for id := range allIDs { sessionIDsSlice = append(sessionIDsSlice, id) } - removed := sm.client.SRem(RecentlyUpdatedSessions, sessionIDsSlice...).Val() + removed := sm.client.SRem(sm.ctx, RecentlyUpdatedSessions, sessionIDsSlice...).Val() sm.log.Debug(sm.ctx, "Fetched and removed %d session IDs from updated_session_set", removed) return allIDs, nil