diff --git a/backend/cmd/analytics/main.go b/backend/cmd/analytics/main.go index 1d7b15b87..e184a6d0b 100644 --- a/backend/cmd/analytics/main.go +++ b/backend/cmd/analytics/main.go @@ -14,6 +14,7 @@ import ( "os" "os/signal" "syscall" + "time" config "openreplay/backend/internal/config/analytics" ) @@ -73,12 +74,13 @@ func main() { } authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector) + limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute)) router, err := api.NewRouter(cfg, log, services) router.GetRouter().Use(middleware.CORS(cfg.UseAccessControlHeaders)) router.GetRouter().Use(authMiddleware) - router.GetRouter().Use(middleware.RateLimit) - router.GetRouter().Use(middleware.Action) + router.GetRouter().Use(limiterMiddleware) + router.GetRouter().Use(middleware.Action()) if err != nil { log.Fatal(ctx, "failed while creating router: %s", err) diff --git a/backend/pkg/analytics/api/router.go b/backend/pkg/analytics/api/router.go index 40ee3d3a9..c2ba94427 100644 --- a/backend/pkg/analytics/api/router.go +++ b/backend/pkg/analytics/api/router.go @@ -17,6 +17,7 @@ type Router struct { router *mux.Router mutex *sync.RWMutex services *common.ServicesBuilder + limiter *common.UserRateLimiter } func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) { diff --git a/backend/pkg/common/api/router.go b/backend/pkg/common/api/router.go index 734604d18..c496bf5a8 100644 --- a/backend/pkg/common/api/router.go +++ b/backend/pkg/common/api/router.go @@ -3,7 +3,7 @@ package api import ( "github.com/gorilla/mux" analyticsConfig "openreplay/backend/internal/config/analytics" - "openreplay/backend/pkg/analytics" + "openreplay/backend/pkg/common" "openreplay/backend/pkg/logger" "sync" ) @@ -13,5 +13,6 @@ type Router struct { cfg *analyticsConfig.Config router *mux.Router mutex *sync.RWMutex - services *analytics.ServicesBuilder + services *common.ServicesBuilder + limiter *common.UserRateLimiter } diff --git a/backend/pkg/common/limiter.go b/backend/pkg/common/limiter.go new file mode 100644 index 000000000..2e38200e8 --- /dev/null +++ b/backend/pkg/common/limiter.go @@ -0,0 +1,88 @@ +package common + +import ( + "sync" + "time" +) + +type RateLimiter struct { + rate int + burst int + tokens int + lastToken time.Time + lastUsed time.Time + mu sync.Mutex +} + +func NewRateLimiter(rate int, burst int) *RateLimiter { + return &RateLimiter{ + rate: rate, + burst: burst, + tokens: burst, + lastToken: time.Now(), + lastUsed: time.Now(), + } +} + +func (rl *RateLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(rl.lastToken) + + rl.tokens += int(elapsed.Seconds()) * rl.rate + if rl.tokens > rl.burst { + rl.tokens = rl.burst + } + + rl.lastToken = now + rl.lastUsed = now + + if rl.tokens > 0 { + rl.tokens-- + return true + } + return false +} + +type UserRateLimiter struct { + rateLimiters sync.Map + rate int + burst int + cleanupInterval time.Duration + maxIdleTime time.Duration +} + +func NewUserRateLimiter(rate int, burst int, cleanupInterval time.Duration, maxIdleTime time.Duration) *UserRateLimiter { + url := &UserRateLimiter{ + rate: rate, + burst: burst, + cleanupInterval: cleanupInterval, + maxIdleTime: maxIdleTime, + } + go url.cleanup() + return url +} + +func (url *UserRateLimiter) GetRateLimiter(user uint64) *RateLimiter { + value, _ := url.rateLimiters.LoadOrStore(user, NewRateLimiter(url.rate, url.burst)) + return value.(*RateLimiter) +} + +func (url *UserRateLimiter) cleanup() { + for { + time.Sleep(url.cleanupInterval) + now := time.Now() + + url.rateLimiters.Range(func(key, value interface{}) bool { + rl := value.(*RateLimiter) + rl.mu.Lock() + if now.Sub(rl.lastUsed) > url.maxIdleTime { + url.rateLimiters.Delete(key) + } + rl.mu.Unlock() + return true + }) + } +} diff --git a/backend/pkg/common/middleware/http.go b/backend/pkg/common/middleware/http.go index 5c6d6a1e6..4904d8aba 100644 --- a/backend/pkg/common/middleware/http.go +++ b/backend/pkg/common/middleware/http.go @@ -1,7 +1,9 @@ package middleware import ( + "bytes" "context" + "io" "net/http" "openreplay/backend/internal/http/util" "openreplay/backend/pkg/common" @@ -106,16 +108,64 @@ func GetUserData(r *http.Request) (*auth.User, bool) { return user, ok } -func RateLimit(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Implement rate limiting logic here - next.ServeHTTP(w, r) - }) +// RateLimit General rate-limiting middleware +func RateLimit(limiter *common.UserRateLimiter) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + next.ServeHTTP(w, r) + return + } + user, ok := GetUserData(r) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + rl := limiter.GetRateLimiter(user.ID) + if !rl.Allow() { + w.WriteHeader(http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } } -func Action(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Implement action logging or processing logic here - next.ServeHTTP(w, r) - }) +type statusWriter struct { + http.ResponseWriter + statusCode int +} + +func (w *statusWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *statusWriter) Write(b []byte) (int, error) { + if w.statusCode == 0 { + w.statusCode = http.StatusOK // Default status code is 200 + } + return w.ResponseWriter.Write(b) +} + +func Action() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + next.ServeHTTP(w, r) + } + // Read body and restore the io.ReadCloser to its original state + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "can't read body", http.StatusBadRequest) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // Use custom response writer to get the status code + sw := &statusWriter{ResponseWriter: w} + // Serve the request + next.ServeHTTP(sw, r) + //e.logRequest(r, bodyBytes, sw.statusCode) + }) + } }