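// Package sessionmanager maintains an in-memory, periodically refreshed view
// of the live assist sessions that nodes publish to Redis, and serves lookup,
// filtered listing and autocomplete queries over that view.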
package sessionmanager

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/go-redis/redis"

	"openreplay/backend/pkg/logger"
)

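// Redis key layout and tuning knobs. Assist nodes publish a JSON-encoded list
// of their session IDs under keys matching NodeKeyPattern, store each live
// session's payload under ActiveSessionPrefix+sessionID, and add recently
// changed session IDs to the RecentlyUpdatedSessions set.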
const (
	TickInterval            = 10 * time.Second
	NodeKeyPattern          = "assist:nodes:*"
	ActiveSessionPrefix     = "assist:online_sessions:"
	RecentlyUpdatedSessions = "assist:updated_sessions"
	BatchSize               = 1000
)

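// SessionData is the decoded payload of one live session as stored in Redis.
// Raw keeps the original JSON document as a generic map so API responses can
// include fields that are not explicitly modelled here.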
type SessionData struct {
	Timestamp    uint64             `json:"timestamp"`
	ProjectID    string             `json:"projectID"`
	SessionID    string             `json:"sessionID"`
	UserID       *string            `json:"userID"`
	UserUUID     *string            `json:"userUUID"`
	UserOS       *string            `json:"userOs"`
	UserBrowser  *string            `json:"userBrowser"`
	UserDevice   *string            `json:"userDevice"`
	UserPlatform *string            `json:"userDeviceType"`
	UserCountry  *string            `json:"userCountry"`
	UserState    *string            `json:"userState"`
	UserCity     *string            `json:"userCity"`
	Metadata     *map[string]string `json:"metadata"`
	Raw          interface{}
}

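// SessionManager keeps a cache of live assist sessions that is periodically
// refreshed from Redis and exposes lookups, filtered listing and autocomplete
// over that cache. Start launches the refresh loop; Stop shuts it down and
// closes the Redis connection.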
type SessionManager interface {
	Start()
	Stop()
	GetByID(projectID, sessionID string) (*SessionData, error)
	GetAll(projectID string, filters []*Filter, sortOrder SortOrder, page, limit int) ([]interface{}, int, map[string]map[string]int, error)
	Autocomplete(projectID string, key FilterType, value string) ([]interface{}, error)
}

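// sessionManagerImpl refreshes its cache on a fixed ticker: loadSessions does
// the initial full load and updateSessions applies incremental changes driven
// by the RecentlyUpdatedSessions set. cache is keyed by session ID and sorted
// holds the same sessions ordered by timestamp for listing.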
type sessionManagerImpl struct {
	ctx      context.Context
	log      logger.Logger
	client   *redis.Client
	ticker   *time.Ticker
	wg       *sync.WaitGroup
	stopChan chan struct{}
	mutex    *sync.RWMutex
	cache    map[string]*SessionData
	sorted   []*SessionData
}

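// New validates its dependencies and builds a SessionManager; call Start to
// begin refreshing the cache. A minimal usage sketch (the redis.Options and
// the logger value shown here are assumptions, not provided by this package):
//
//	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
//	sm, err := sessionmanager.New(log, client)
//	if err != nil {
//		// handle the error
//	}
//	sm.Start()
//	defer sm.Stop()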
func New(log logger.Logger, redis *redis.Client) (SessionManager, error) {
	switch {
	case log == nil:
		return nil, fmt.Errorf("logger is required")
	case redis == nil:
		return nil, fmt.Errorf("redis client is required")
	}
	sm := &sessionManagerImpl{
		ctx:      context.Background(),
		log:      log,
		client:   redis,
		ticker:   time.NewTicker(TickInterval),
		wg:       &sync.WaitGroup{},
		stopChan: make(chan struct{}),
		mutex:    &sync.RWMutex{},
		cache:    make(map[string]*SessionData),
		sorted:   make([]*SessionData, 0),
	}
	return sm, nil
}

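// Start launches the background refresh loop: one full load of all online
// sessions, then a re-synchronisation on every tick until Stop is called.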
func (sm *sessionManagerImpl) Start() {
	sm.log.Debug(sm.ctx, "Starting session manager...")

	// Track the refresh goroutine so Stop can wait for it to exit.
	sm.wg.Add(1)
	go func() {
		defer sm.wg.Done()
		sm.loadSessions()
		for {
			select {
			case <-sm.ticker.C:
				sm.updateSessions()
			case <-sm.stopChan:
				sm.log.Debug(sm.ctx, "Stopping session manager...")
				return
			}
		}
	}()
}

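// Stop terminates the refresh loop, waits for it to exit and closes the Redis
// connection.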
func (sm *sessionManagerImpl) Stop() {
	close(sm.stopChan)
	sm.ticker.Stop()
	sm.wg.Wait()
	if err := sm.client.Close(); err != nil {
		sm.log.Debug(sm.ctx, "Error closing Redis connection: %v", err)
	}
	sm.log.Debug(sm.ctx, "Session manager stopped")
}

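// getNodeIDs returns the Redis keys of all known assist nodes by scanning for
// NodeKeyPattern.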
func (sm *sessionManagerImpl) getNodeIDs() ([]string, error) {
	var nodeIDs = make([]string, 0, 16) // Let's assume we have at most 16 nodes
	var cursor uint64 = 0

	for {
		keys, nextCursor, err := sm.client.Scan(cursor, NodeKeyPattern, 100).Result()
		if err != nil {
			return nil, fmt.Errorf("scan failed: %v", err)
		}
		nodeIDs = append(nodeIDs, keys...)
		cursor = nextCursor
		if cursor == 0 {
			break
		}
	}
	return nodeIDs, nil
}

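// getAllNodeSessions fetches, in parallel, the JSON-encoded session ID list
// stored under each node key and merges the IDs into a single set.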
func (sm *sessionManagerImpl) getAllNodeSessions(nodeIDs []string) map[string]struct{} {
	allSessionIDs := make(map[string]struct{})
	var mu sync.Mutex
	var wg sync.WaitGroup

	for _, nodeID := range nodeIDs {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()

			sessionListJSON, err := sm.client.Get(id).Result()
			if err != nil {
				if errors.Is(err, redis.Nil) {
					return
				}
				sm.log.Debug(sm.ctx, "Error getting session list for node %s: %v", id, err)
				return
			}

			var sessionList []string
			if err = json.Unmarshal([]byte(sessionListJSON), &sessionList); err != nil {
				sm.log.Debug(sm.ctx, "Error unmarshalling session list for node %s: %v", id, err)
				return
			}

			mu.Lock()
			for _, sessionID := range sessionList {
				allSessionIDs[sessionID] = struct{}{}
			}
			mu.Unlock()
		}(nodeID)
	}
	wg.Wait()
	return allSessionIDs
}

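// getOnlineSessionIDs returns the set of session IDs currently reported as
// online by any node.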
func (sm *sessionManagerImpl) getOnlineSessionIDs() (map[string]struct{}, error) {
	nodeIDs, err := sm.getNodeIDs()
	if err != nil {
		sm.log.Debug(sm.ctx, "Error getting node IDs: %v", err)
		return nil, err
	}
	sm.log.Debug(sm.ctx, "Found %d nodes", len(nodeIDs))

	allSessionIDs := sm.getAllNodeSessions(nodeIDs)
	sm.log.Debug(sm.ctx, "Collected %d unique session IDs", len(allSessionIDs))
	return allSessionIDs, nil
}

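// getSessionData loads full session payloads from Redis in batches of
// BatchSize using MGET, decoding each entry both into a SessionData and into
// the raw map kept for API responses.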
func (sm *sessionManagerImpl) getSessionData(sessionIDs []string) map[string]*SessionData {
	sessionData := make(map[string]*SessionData, len(sessionIDs))

	for i := 0; i < len(sessionIDs); i += BatchSize {
		end := i + BatchSize
		if end > len(sessionIDs) {
			end = len(sessionIDs)
		}
		batch := sessionIDs[i:end]

		keys := make([]string, len(batch))
		for j, id := range batch {
			keys[j] = ActiveSessionPrefix + id
		}

		results, err := sm.client.MGet(keys...).Result()
		if err != nil {
			sm.log.Debug(sm.ctx, "Error in MGET operation: %v", err)
			continue // TODO: Handle the error
		}

		for j, result := range results {
			if result == nil {
				continue
			}

			strVal, ok := result.(string)
			if !ok {
				sm.log.Debug(sm.ctx, "Unexpected type for session data: %T", result)
				continue
			}

			var data SessionData
			if err := json.Unmarshal([]byte(strVal), &data); err != nil {
				sm.log.Debug(sm.ctx, "Error unmarshalling session data: %v", err)
				continue
			}
			raw := make(map[string]interface{})
			if err := json.Unmarshal([]byte(strVal), &raw); err != nil {
				sm.log.Debug(sm.ctx, "Error unmarshalling raw session data: %v", err)
				continue
			}
			data.Raw = raw
			sessionData[batch[j]] = &data
		}
		sm.log.Debug(sm.ctx, "Collected %d sessions", len(results))
	}

	return sessionData
}

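// updateCache applies removals and additions to the cache under the write
// lock and rebuilds the timestamp-sorted slice used for listing.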
func (sm *sessionManagerImpl) updateCache(sessionsToAdd map[string]*SessionData, sessionsToRemove []string) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	if sessionsToRemove != nil {
		for _, sessID := range sessionsToRemove {
			delete(sm.cache, sessID)
		}
	}
	if sessionsToAdd == nil {
		return
	}
	for sessID, session := range sessionsToAdd {
		sm.cache[sessID] = session
	}

	sessionList := make([]*SessionData, 0, len(sm.cache))
	for _, s := range sm.cache {
		sessionList = append(sessionList, s)
	}
	sort.Slice(sessionList, func(i, j int) bool {
		return sessionList[i].Timestamp < sessionList[j].Timestamp
	})
	sm.sorted = sessionList
}

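// loadSessions performs the initial synchronisation: it collects every online
// session ID, loads the corresponding payloads and populates the cache.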
func (sm *sessionManagerImpl) loadSessions() {
	startTime := time.Now()
	sm.log.Debug(sm.ctx, "Starting session processing cycle")

	sessIDs, err := sm.getOnlineSessionIDs()
	if err != nil {
		sm.log.Debug(sm.ctx, "Error getting online session IDs: %v", err)
		return
	}
	if len(sessIDs) == 0 {
		sm.log.Debug(sm.ctx, "No sessions found for nodes")
		return
	}
	allSessionIDsList := make([]string, 0, len(sessIDs))
	for sessionID := range sessIDs {
		allSessionIDsList = append(allSessionIDsList, sessionID)
	}
	sessionMap := sm.getSessionData(allSessionIDsList)
	sm.updateCache(sessionMap, nil)

	duration := time.Since(startTime)
	sm.log.Debug(sm.ctx, "Session processing cycle completed in %v. Processed %d sessions", duration, len(sm.cache))
}

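// getAllRecentlyUpdatedSessions drains the RecentlyUpdatedSessions set: it
// scans all members, removes them from the set and returns them so the caller
// can reload their payloads.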
func (sm *sessionManagerImpl) getAllRecentlyUpdatedSessions() (map[string]struct{}, error) {
	var (
		cursor   uint64
		allIDs   = make(map[string]struct{})
		batchIDs []string
		err      error
	)

	for {
		batchIDs, cursor, err = sm.client.SScan(RecentlyUpdatedSessions, cursor, "*", 1000).Result()
		if err != nil {
			sm.log.Debug(sm.ctx, "Error scanning updated session IDs: %v", err)
			return nil, err
		}
		for _, id := range batchIDs {
			allIDs[id] = struct{}{}
		}
		if cursor == 0 {
			break
		}
	}

	if len(allIDs) == 0 {
		sm.log.Debug(sm.ctx, "No updated session IDs found")
		return allIDs, nil
	}

	var sessionIDsSlice []interface{}
	for id := range allIDs {
		sessionIDsSlice = append(sessionIDsSlice, id)
	}
	removed := sm.client.SRem(RecentlyUpdatedSessions, sessionIDsSlice...).Val()
	sm.log.Debug(sm.ctx, "Fetched and removed %d session IDs from updated_session_set", removed)

	return allIDs, nil
}

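// updateSessions runs on every tick. It reloads sessions that were marked as
// updated or that are online but missing from the cache, and evicts cached
// sessions that are no longer reported online by any node.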
func (sm *sessionManagerImpl) updateSessions() {
	startTime := time.Now()
	sm.log.Debug(sm.ctx, "Starting session processing cycle")

	sessIDs, err := sm.getOnlineSessionIDs()
	if err != nil {
		sm.log.Debug(sm.ctx, "Error getting online session IDs: %v", err)
		return
	}

	updatedSessIDs, err := sm.getAllRecentlyUpdatedSessions()
	if err != nil {
		sm.log.Debug(sm.ctx, "Error getting recently updated sessions: %v", err)
		return
	}
	if updatedSessIDs == nil {
		updatedSessIDs = make(map[string]struct{})
	}

	sm.mutex.RLock()
	toAdd := make([]string, 0, len(updatedSessIDs))
	for sessID := range sessIDs {
		if _, exists := sm.cache[sessID]; !exists {
			updatedSessIDs[sessID] = struct{}{} // Add to updated sessions if not in cache
		}
	}
	for sessID := range updatedSessIDs {
		toAdd = append(toAdd, sessID)
	}

	toRemove := make([]string, 0, len(sessIDs)/16)
	for sessID := range sm.cache {
		if _, exists := sessIDs[sessID]; !exists {
			toRemove = append(toRemove, sessID)
		}
	}
	sm.mutex.RUnlock()

	// Load full session data from Redis
	newCache := sm.getSessionData(toAdd)
	sm.updateCache(newCache, toRemove)

	duration := time.Since(startTime)
	sm.log.Debug(sm.ctx, "Session processing cycle completed in %v. Processed %d sessions", duration, len(sm.cache))
}

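// GetByID returns the cached session with the given ID, verifying that it
// belongs to the requested project.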
func (sm *sessionManagerImpl) GetByID(projectID, sessionID string) (*SessionData, error) {
	if sessionID == "" {
		return nil, fmt.Errorf("session ID is required")
	}

	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	sessionData, exists := sm.cache[sessionID]
	if !exists {
		return nil, fmt.Errorf("session not found")
	}
	if sessionData.ProjectID != projectID {
		return nil, fmt.Errorf("session does not belong to the project")
	}
	return sessionData, nil
}

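// GetAll returns one page of the project's cached sessions that match the
// given filters, together with the total number of matches and per-filter
// value counters. Results follow the cache's timestamp order; the sortOrder
// argument is currently not applied.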
func (sm *sessionManagerImpl) GetAll(projectID string, filters []*Filter, sortOrder SortOrder, page, limit int) ([]interface{}, int, map[string]map[string]int, error) {
	if page < 1 || limit < 1 {
		page, limit = 1, 10 // Set default values
	}

	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	// Prepare filter counter
	counter := make(map[string]map[string]int)
	for _, filter := range filters {
		if _, ok := counter[string(filter.Type)]; !ok {
			counter[string(filter.Type)] = make(map[string]int)
		}
		for _, value := range filter.Value {
			counter[string(filter.Type)][value] = 0
		}
	}
	filtered := make([]interface{}, 0, limit)
	for _, session := range sm.sorted {
		sm.log.Debug(sm.ctx, "projectID: %s, sessionID: %s", session.ProjectID, session.SessionID)
		if session.ProjectID != projectID {
			continue
		}
		if matchesFilters(session, filters, counter) {
			filtered = append(filtered, session.Raw)
		}
	}

	start := (page - 1) * limit
	end := start + limit
	if start > len(filtered) {
		return []interface{}{}, 0, make(map[string]map[string]int), nil
	}
	if end > len(filtered) {
		end = len(filtered)
	}
	return filtered[start:end], len(filtered), counter, nil
}

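// matchesFilters reports whether a session satisfies every filter. For each
// filter it resolves the session field (or metadata entry) named by the
// filter type and compares it against the filter values using the Is or
// Contains operator; values of fully matched filters are tallied into counter.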
func matchesFilters(session *SessionData, filters []*Filter, counter map[string]map[string]int) bool {
	if len(filters) == 0 {
		return true
	}
	matchedFilters := make(map[string][]string, len(filters))
	for _, filter := range filters {
		name := string(filter.Type)
		if _, ok := matchedFilters[name]; !ok {
			matchedFilters[name] = make([]string, 0, len(filter.Value))
		}
		var value string

		switch filter.Type {
		case UserID:
			if session.UserID != nil {
				value = *session.UserID
			}
		case UserAnonymousID:
			if session.UserUUID != nil {
				value = *session.UserUUID
			}
		case UserOS:
			if session.UserOS != nil {
				value = *session.UserOS
			}
		case UserBrowser:
			if session.UserBrowser != nil {
				value = *session.UserBrowser
			}
		case UserDevice:
			if session.UserDevice != nil {
				value = *session.UserDevice
			}
		case UserPlatform:
			if session.UserPlatform != nil {
				value = *session.UserPlatform
			}
		case UserCountry:
			if session.UserCountry != nil {
				value = *session.UserCountry
			}
		case UserState:
			if session.UserState != nil {
				value = *session.UserState
			}
		case UserCity:
			if session.UserCity != nil {
				value = *session.UserCity
			}
		default:
			// Any other filter type is looked up in the session metadata;
			// guard against sessions that carry no metadata at all.
			if session.Metadata != nil {
				if val, ok := (*session.Metadata)[name]; ok {
					value = val
				}
			}
		}

		matched := false
		for _, filterValue := range filter.Value {
			if filter.Operator == Is && value != filterValue {
				continue
			} else if filter.Operator == Contains && !strings.Contains(strings.ToLower(value), strings.ToLower(filterValue)) {
				continue
			}
			matched = true
			matchedFilters[name] = append(matchedFilters[name], value)
		}
		if !matched {
			return false
		}
	}
	for name, values := range matchedFilters {
		for _, value := range values {
			counter[name][value]++
		}
	}
	return true
}

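// Autocomplete returns the distinct values of the given field among the
// project's cached sessions that contain the query string (case-insensitive),
// wrapped as {type, value} pairs.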
func (sm *sessionManagerImpl) Autocomplete(projectID string, key FilterType, value string) ([]interface{}, error) {
	matches := make(map[string]struct{}) // To ensure uniqueness
	lowerValue := strings.ToLower(value)

	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	for _, session := range sm.sorted {
		if session.ProjectID != projectID {
			continue
		}

		var fieldValue string
		switch key {
		case UserID:
			if session.UserID != nil {
				fieldValue = *session.UserID
			}
		case UserAnonymousID:
			if session.UserUUID != nil {
				fieldValue = *session.UserUUID
			}
		case UserOS:
			if session.UserOS != nil {
				fieldValue = *session.UserOS
			}
		case UserBrowser:
			if session.UserBrowser != nil {
				fieldValue = *session.UserBrowser
			}
		case UserDevice:
			if session.UserDevice != nil {
				fieldValue = *session.UserDevice
			}
		case UserState:
			if session.UserState != nil {
				fieldValue = *session.UserState
			}
		case UserCity:
			if session.UserCity != nil {
				fieldValue = *session.UserCity
			}
		default:
			// Treat any other key as a metadata field; skip sessions without metadata.
			if session.Metadata != nil {
				if v, ok := (*session.Metadata)[string(key)]; ok {
					fieldValue = v
				}
			}
		}

		if fieldValue != "" && strings.Contains(strings.ToLower(fieldValue), lowerValue) {
			matches[fieldValue] = struct{}{}
		}
	}

	results := make([]interface{}, 0, len(matches))
	keyName := strings.ToUpper(string(key))
	type pair struct {
		Type  string `json:"type"`
		Value string `json:"value"`
	}
	for k := range matches {
		results = append(results, pair{Type: keyName, Value: k})
	}
	return results, nil
}