feat(backend): analytics - common rate limiter and acton middlewares
This commit is contained in:
parent
1578b891bd
commit
6d4d24c5e0
5 changed files with 156 additions and 14 deletions
|
|
@ -14,6 +14,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
config "openreplay/backend/internal/config/analytics"
|
||||
)
|
||||
|
|
@ -73,12 +74,13 @@ func main() {
|
|||
}
|
||||
|
||||
authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector)
|
||||
limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute))
|
||||
|
||||
router, err := api.NewRouter(cfg, log, services)
|
||||
router.GetRouter().Use(middleware.CORS(cfg.UseAccessControlHeaders))
|
||||
router.GetRouter().Use(authMiddleware)
|
||||
router.GetRouter().Use(middleware.RateLimit)
|
||||
router.GetRouter().Use(middleware.Action)
|
||||
router.GetRouter().Use(limiterMiddleware)
|
||||
router.GetRouter().Use(middleware.Action())
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(ctx, "failed while creating router: %s", err)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type Router struct {
|
|||
router *mux.Router
|
||||
mutex *sync.RWMutex
|
||||
services *common.ServicesBuilder
|
||||
limiter *common.UserRateLimiter
|
||||
}
|
||||
|
||||
func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package api
|
|||
import (
|
||||
"github.com/gorilla/mux"
|
||||
analyticsConfig "openreplay/backend/internal/config/analytics"
|
||||
"openreplay/backend/pkg/analytics"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/logger"
|
||||
"sync"
|
||||
)
|
||||
|
|
@ -13,5 +13,6 @@ type Router struct {
|
|||
cfg *analyticsConfig.Config
|
||||
router *mux.Router
|
||||
mutex *sync.RWMutex
|
||||
services *analytics.ServicesBuilder
|
||||
services *common.ServicesBuilder
|
||||
limiter *common.UserRateLimiter
|
||||
}
|
||||
|
|
|
|||
88
backend/pkg/common/limiter.go
Normal file
88
backend/pkg/common/limiter.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
rate int
|
||||
burst int
|
||||
tokens int
|
||||
lastToken time.Time
|
||||
lastUsed time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewRateLimiter(rate int, burst int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
tokens: burst,
|
||||
lastToken: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastToken)
|
||||
|
||||
rl.tokens += int(elapsed.Seconds()) * rl.rate
|
||||
if rl.tokens > rl.burst {
|
||||
rl.tokens = rl.burst
|
||||
}
|
||||
|
||||
rl.lastToken = now
|
||||
rl.lastUsed = now
|
||||
|
||||
if rl.tokens > 0 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type UserRateLimiter struct {
|
||||
rateLimiters sync.Map
|
||||
rate int
|
||||
burst int
|
||||
cleanupInterval time.Duration
|
||||
maxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func NewUserRateLimiter(rate int, burst int, cleanupInterval time.Duration, maxIdleTime time.Duration) *UserRateLimiter {
|
||||
url := &UserRateLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
cleanupInterval: cleanupInterval,
|
||||
maxIdleTime: maxIdleTime,
|
||||
}
|
||||
go url.cleanup()
|
||||
return url
|
||||
}
|
||||
|
||||
func (url *UserRateLimiter) GetRateLimiter(user uint64) *RateLimiter {
|
||||
value, _ := url.rateLimiters.LoadOrStore(user, NewRateLimiter(url.rate, url.burst))
|
||||
return value.(*RateLimiter)
|
||||
}
|
||||
|
||||
func (url *UserRateLimiter) cleanup() {
|
||||
for {
|
||||
time.Sleep(url.cleanupInterval)
|
||||
now := time.Now()
|
||||
|
||||
url.rateLimiters.Range(func(key, value interface{}) bool {
|
||||
rl := value.(*RateLimiter)
|
||||
rl.mu.Lock()
|
||||
if now.Sub(rl.lastUsed) > url.maxIdleTime {
|
||||
url.rateLimiters.Delete(key)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"openreplay/backend/internal/http/util"
|
||||
"openreplay/backend/pkg/common"
|
||||
|
|
@ -106,16 +108,64 @@ func GetUserData(r *http.Request) (*auth.User, bool) {
|
|||
return user, ok
|
||||
}
|
||||
|
||||
func RateLimit(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Implement rate limiting logic here
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
// RateLimit General rate-limiting middleware
|
||||
func RateLimit(limiter *common.UserRateLimiter) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
user, ok := GetUserData(r)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
rl := limiter.GetRateLimiter(user.ID)
|
||||
if !rl.Allow() {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Action(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Implement action logging or processing logic here
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
if w.statusCode == 0 {
|
||||
w.statusCode = http.StatusOK // Default status code is 200
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func Action() func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
// Read body and restore the io.ReadCloser to its original state
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "can't read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// Use custom response writer to get the status code
|
||||
sw := &statusWriter{ResponseWriter: w}
|
||||
// Serve the request
|
||||
next.ServeHTTP(sw, r)
|
||||
//e.logRequest(r, bodyBytes, sw.statusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue