openreplay/backend/pkg/common/middleware/http.go
2024-10-23 14:11:45 +02:00

101 lines
3.2 KiB
Go

package middleware
import (
"context"
"github.com/gorilla/mux"
"net/http"
"openreplay/backend/internal/http/util"
"openreplay/backend/pkg/common"
"openreplay/backend/pkg/common/api/auth"
"openreplay/backend/pkg/logger"
)
// userDataKey is the unexported context key under which AuthMiddleware stores
// the authenticated *auth.User and from which GetUserData reads it back.
// Because the type is private to this package, no other package can create
// a key that collides with it.
type userDataKey struct{}
// CORS returns a middleware that prepares requests for the API handlers.
//
// Requests for the root path "/" are passed to next untouched. For every
// other path the middleware:
//   - sets permissive Access-Control-* response headers when
//     useAccessControlHeaders is true;
//   - answers OPTIONS requests itself with 200 OK and a one-day
//     Cache-Control max-age, without calling next;
//   - otherwise stores the request method under the "httpMethod" key and the
//     util.SafeString-processed URL path under the "url" key of the request
//     context, then calls next.
func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/" {
				next.ServeHTTP(w, r)
				return
			}

			if useAccessControlHeaders {
				hdr := w.Header()
				hdr.Set("Access-Control-Allow-Origin", "*")
				hdr.Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE")
				hdr.Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding")
			}

			// Preflight requests are answered here and never reach next.
			if r.Method == http.MethodOptions {
				w.Header().Set("Cache-Control", "max-age=86400")
				w.WriteHeader(http.StatusOK)
				return
			}

			// NOTE(review): plain string context keys are flagged by staticcheck
			// (SA1029). They are presumably read back by these exact keys
			// elsewhere (e.g. for request logging), so they are kept as-is.
			ctx := context.WithValue(r.Context(), "httpMethod", r.Method)
			ctx = context.WithValue(ctx, "url", util.SafeString(r.URL.Path))
			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(serve)
	}
}
// AuthMiddleware returns a middleware that authenticates every request except
// the root path "/" and the path/method pairs listed in excludedPaths.
//
// Parameters:
//   - services: provides the Auth service used to validate the request's
//     "Authorization" header.
//   - log: receives an error when the route's path template cannot be read
//     and a warning for every rejected request.
//   - excludedPaths: maps a raw URL path to the set of HTTP methods that skip
//     authentication, e.g. {"/v1/ping": {"GET": true}}. A nil map is allowed.
//   - getPermissions: returns the permissions required for a raw URL path.
//     Its result is passed straight to services.Auth.IsAuthorized.
//
// On success the authenticated user is stored in the request context and can
// be read with GetUserData. On failure the middleware responds with
// 401 Unauthorized and does not call next.
func AuthMiddleware(
	services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.)
	log logger.Logger, // Logger for logging events
	excludedPaths map[string]map[string]bool, // Map of excluded paths with methods
	getPermissions func(path string) []string, // Function to retrieve permissions for a path
) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/" {
				next.ServeHTTP(w, r)
				return
			}
			// Skip authentication for explicitly excluded path/method pairs.
			if methods, ok := excludedPaths[r.URL.Path]; ok && methods[r.Method] {
				next.ServeHTTP(w, r)
				return
			}

			// mux.CurrentRoute returns nil when this handler is not running
			// under a gorilla/mux route match (e.g. when it wraps the whole
			// router instead of being installed with router.Use). Calling
			// GetPathTemplate on a nil *mux.Route panics, so guard it and fall
			// back to an empty template.
			pathTemplate := ""
			if route := mux.CurrentRoute(r); route != nil {
				tmpl, err := route.GetPathTemplate()
				if err != nil {
					log.Error(r.Context(), "failed to get path template: %s", err)
				}
				pathTemplate = tmpl
			}

			// NOTE(review): the third argument is true for every request whose
			// matched route has a non-empty path template. Confirm this is the
			// flag IsAuthorized expects.
			user, err := services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), pathTemplate != "")
			if err != nil {
				log.Warn(r.Context(), "Unauthorized request: %s", err)
				w.WriteHeader(http.StatusUnauthorized)
				return
			}

			// Expose the authenticated user to downstream handlers (see GetUserData).
			ctx := context.WithValue(r.Context(), userDataKey{}, user)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// GetUserData returns the user that AuthMiddleware attached to r's context.
// The boolean is false when no *auth.User is stored there, for example on the
// root path or on excluded routes that bypass authentication.
func GetUserData(r *http.Request) (*auth.User, bool) {
	if u, found := r.Context().Value(userDataKey{}).(*auth.User); found {
		return u, true
	}
	return nil, false
}
// RateLimit is a placeholder middleware. No limiting is implemented yet, so
// every request is forwarded unchanged to next.
func RateLimit(next http.Handler) http.Handler {
	// TODO: enforce a request rate limit before delegating.
	passthrough := func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(passthrough)
}
// Action is a placeholder middleware for per-request action logging or
// processing. It currently does nothing except forward the request unchanged
// to next.
func Action(next http.Handler) http.Handler {
	// TODO: record or process the request's action before delegating.
	passthrough := func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(passthrough)
}