feat(backend): analytics - common rate limiter and acton middlewares

This commit is contained in:
Shekar Siri 2024-10-28 13:28:10 +01:00
parent 1578b891bd
commit 6d4d24c5e0
5 changed files with 156 additions and 14 deletions

View file

@ -14,6 +14,7 @@ import (
"os"
"os/signal"
"syscall"
"time"
config "openreplay/backend/internal/config/analytics"
)
@ -73,12 +74,13 @@ func main() {
}
authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector)
limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute))
router, err := api.NewRouter(cfg, log, services)
router.GetRouter().Use(middleware.CORS(cfg.UseAccessControlHeaders))
router.GetRouter().Use(authMiddleware)
router.GetRouter().Use(middleware.RateLimit)
router.GetRouter().Use(middleware.Action)
router.GetRouter().Use(limiterMiddleware)
router.GetRouter().Use(middleware.Action())
if err != nil {
log.Fatal(ctx, "failed while creating router: %s", err)

View file

@ -17,6 +17,7 @@ type Router struct {
router *mux.Router
mutex *sync.RWMutex
services *common.ServicesBuilder
limiter *common.UserRateLimiter
}
func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) {

View file

@ -3,7 +3,7 @@ package api
import (
"github.com/gorilla/mux"
analyticsConfig "openreplay/backend/internal/config/analytics"
"openreplay/backend/pkg/analytics"
"openreplay/backend/pkg/common"
"openreplay/backend/pkg/logger"
"sync"
)
@ -13,5 +13,6 @@ type Router struct {
cfg *analyticsConfig.Config
router *mux.Router
mutex *sync.RWMutex
services *analytics.ServicesBuilder
services *common.ServicesBuilder
limiter *common.UserRateLimiter
}

View file

@ -0,0 +1,88 @@
package common
import (
"sync"
"time"
)
type RateLimiter struct {
rate int
burst int
tokens int
lastToken time.Time
lastUsed time.Time
mu sync.Mutex
}
func NewRateLimiter(rate int, burst int) *RateLimiter {
return &RateLimiter{
rate: rate,
burst: burst,
tokens: burst,
lastToken: time.Now(),
lastUsed: time.Now(),
}
}
func (rl *RateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
elapsed := now.Sub(rl.lastToken)
rl.tokens += int(elapsed.Seconds()) * rl.rate
if rl.tokens > rl.burst {
rl.tokens = rl.burst
}
rl.lastToken = now
rl.lastUsed = now
if rl.tokens > 0 {
rl.tokens--
return true
}
return false
}
type UserRateLimiter struct {
rateLimiters sync.Map
rate int
burst int
cleanupInterval time.Duration
maxIdleTime time.Duration
}
func NewUserRateLimiter(rate int, burst int, cleanupInterval time.Duration, maxIdleTime time.Duration) *UserRateLimiter {
url := &UserRateLimiter{
rate: rate,
burst: burst,
cleanupInterval: cleanupInterval,
maxIdleTime: maxIdleTime,
}
go url.cleanup()
return url
}
func (url *UserRateLimiter) GetRateLimiter(user uint64) *RateLimiter {
value, _ := url.rateLimiters.LoadOrStore(user, NewRateLimiter(url.rate, url.burst))
return value.(*RateLimiter)
}
func (url *UserRateLimiter) cleanup() {
for {
time.Sleep(url.cleanupInterval)
now := time.Now()
url.rateLimiters.Range(func(key, value interface{}) bool {
rl := value.(*RateLimiter)
rl.mu.Lock()
if now.Sub(rl.lastUsed) > url.maxIdleTime {
url.rateLimiters.Delete(key)
}
rl.mu.Unlock()
return true
})
}
}

View file

@ -1,7 +1,9 @@
package middleware
import (
"bytes"
"context"
"io"
"net/http"
"openreplay/backend/internal/http/util"
"openreplay/backend/pkg/common"
@ -106,16 +108,64 @@ func GetUserData(r *http.Request) (*auth.User, bool) {
return user, ok
}
func RateLimit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Implement rate limiting logic here
next.ServeHTTP(w, r)
})
// RateLimit General rate-limiting middleware
func RateLimit(limiter *common.UserRateLimiter) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
next.ServeHTTP(w, r)
return
}
user, ok := GetUserData(r)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
rl := limiter.GetRateLimiter(user.ID)
if !rl.Allow() {
w.WriteHeader(http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
func Action(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Implement action logging or processing logic here
next.ServeHTTP(w, r)
})
type statusWriter struct {
http.ResponseWriter
statusCode int
}
func (w *statusWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *statusWriter) Write(b []byte) (int, error) {
if w.statusCode == 0 {
w.statusCode = http.StatusOK // Default status code is 200
}
return w.ResponseWriter.Write(b)
}
func Action() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
next.ServeHTTP(w, r)
}
// Read body and restore the io.ReadCloser to its original state
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "can't read body", http.StatusBadRequest)
return
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// Use custom response writer to get the status code
sw := &statusWriter{ResponseWriter: w}
// Serve the request
next.ServeHTTP(sw, r)
//e.logRequest(r, bodyBytes, sw.statusCode)
})
}
}