openreplay/backend/pkg/common/middleware/http.go
2024-10-30 10:53:55 +01:00

170 lines
5.2 KiB
Go

package middleware
import (
"bytes"
"io"
"net/http"
"openreplay/backend/pkg/common"
"openreplay/backend/pkg/common/api/auth"
"openreplay/backend/pkg/logger"
)
// userData is the unexported context key type under which GetUserData looks
// up the authenticated *auth.User. A private struct type keeps the key from
// colliding with context keys defined in other packages.
type userData struct{}
// CORS builds a middleware that adds permissive cross-origin headers and
// answers preflight requests.
//
// The root path "/" is forwarded to next untouched. For any other path:
//   - if useAccessControlHeaders is set, the response allows any origin, the
//     POST/GET/PATCH/DELETE methods and the Content-Type, Authorization and
//     Content-Encoding request headers;
//   - OPTIONS requests are answered here with 200 OK and
//     "Cache-Control: max-age=86400" instead of being forwarded;
//   - every other request is forwarded to next.
func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler {
	// Header/value pairs applied when useAccessControlHeaders is enabled.
	allowHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE"},
		{"Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding"},
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			root := r.URL.Path == "/"

			if !root && useAccessControlHeaders {
				hdr := w.Header()
				for _, pair := range allowHeaders {
					hdr.Set(pair[0], pair[1])
				}
			}

			if root || r.Method != http.MethodOptions {
				next.ServeHTTP(w, r)
				return
			}

			// Preflight request: answered here without calling next.
			// NOTE(review): browsers cache preflight results via
			// Access-Control-Max-Age, not Cache-Control — confirm intent.
			w.Header().Set("Cache-Control", "max-age=86400")
			w.WriteHeader(http.StatusOK)
		})
	}
}
// AuthMiddleware returns a middleware that requires an Authorization header on
// every request except the root path "/" and the path/method pairs listed in
// excludedPaths.
//
// Parameters:
//   - Auth: supplies the default JWT column (JWTCol) and secret (Secret).
//   - log: used to warn about requests that carry no Authorization header.
//   - excludedPaths: maps a URL path to the HTTP methods that skip the check.
//   - getPermissions: currently unused; presumably meant to supply the
//     permissions a path requires — confirm.
//   - authOptionsSelector: optional (may be nil); non-empty JwtColumn/Secret
//     values it returns override Auth's defaults for that request.
//
// Requests without an Authorization header get 401 Unauthorized.
//
// NOTE(review): the token itself is NOT validated — the authorization call
// is disabled — so any request with a non-empty Authorization header is
// forwarded, and the resolved options are computed but never used. Nothing
// is stored in the request context for GetUserData to find.
func AuthMiddleware(
	Auth auth.Auth,
	log logger.Logger,
	excludedPaths map[string]map[string]bool,
	getPermissions func(path string) []string,
	authOptionsSelector func(r *http.Request) *auth.Options,
) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// The root path and explicitly excluded path/method pairs need no
			// credentials. Indexing a missing path yields a nil inner map, whose
			// lookups report false, so no separate presence check is needed.
			if r.URL.Path == "/" || excludedPaths[r.URL.Path][r.Method] {
				next.ServeHTTP(w, r)
				return
			}

			authHeader := r.Header.Get("Authorization")
			if len(authHeader) == 0 {
				log.Warn(r.Context(), "Authorization header missing")
				w.WriteHeader(http.StatusUnauthorized)
				return
			}

			// Start from the service-wide defaults and let the per-request
			// selector override any field it sets to a non-empty value.
			options := auth.Options{JwtColumn: Auth.JWTCol(), Secret: Auth.Secret()}
			var override *auth.Options
			if authOptionsSelector != nil {
				override = authOptionsSelector(r)
			}
			if override != nil && override.JwtColumn != "" {
				options.JwtColumn = override.JwtColumn
			}
			if override != nil && override.Secret != "" {
				options.Secret = override.Secret
			}

			// TODO: validate authHeader with Auth using getPermissions(r.URL.Path)
			// and options; on failure respond 401. On success store the user with
			// context.WithValue(r.Context(), userData{}, user) so GetUserData can
			// find it. The earlier draft of this code stored it under the string
			// key "userData", which a userData{} lookup never matches.
			next.ServeHTTP(w, r)
		})
	}
}
// GetUserData extracts the authenticated user from the request context.
//
// It returns ok=false when no *auth.User is stored under the userData key,
// including when a value of some other type is stored there. A typed-nil
// *auth.User stored in the context yields (nil, true).
//
// NOTE(review): nothing in this file stores a value under that key; confirm
// where it is populated.
func GetUserData(r *http.Request) (*auth.User, bool) {
	if u, found := r.Context().Value(userData{}).(*auth.User); found {
		return u, true
	}
	return nil, false
}
// RateLimit returns a middleware that enforces a per-user rate limit.
//
// Requests for the root path "/" are forwarded unchanged. For every other
// request the user is read from the request context via GetUserData:
//   - if no user is present (or the stored pointer is nil), the request is
//     rejected with 401 Unauthorized;
//   - otherwise the limiter for that user (limiter.GetRateLimiter(user.ID))
//     must allow the request, or it is rejected with 429 Too Many Requests.
//
// NOTE(review): this relies on an earlier middleware having stored the user
// under the userData key; if none does, every non-root request gets 401.
func RateLimit(limiter *common.UserRateLimiter) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/" {
				next.ServeHTTP(w, r)
				return
			}

			user, ok := GetUserData(r)
			// A typed-nil *auth.User stored in the context still passes the type
			// assertion in GetUserData; guard before dereferencing user.ID.
			if !ok || user == nil {
				w.WriteHeader(http.StatusUnauthorized)
				return
			}

			rl := limiter.GetRateLimiter(user.ID)
			if !rl.Allow() {
				w.WriteHeader(http.StatusTooManyRequests)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
// statusWriter wraps an http.ResponseWriter and records the final status code
// of the response so middleware can inspect it after the handler returns.
// statusCode stays 0 if the handler wrote neither a final header nor a body.
type statusWriter struct {
	http.ResponseWriter
	statusCode int
}

// WriteHeader forwards the call to the wrapped writer and records statusCode
// the first time a final (non-informational) code is written.
//
// Informational 1xx codes (except 101 Switching Protocols, which is final)
// may be sent several times before the real header, so they are not
// recorded. Later final codes are still forwarded, but they do not overwrite
// the recorded value. net/http ignores such superfluous calls, so the first
// final code is the one actually sent.
func (w *statusWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if w.statusCode == 0 && !informational {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write records 200 OK if no final status was written yet, then forwards b
// to the wrapped writer. net/http sends an implicit 200 on the first Write.
func (w *statusWriter) Write(b []byte) (int, error) {
	if w.statusCode == 0 {
		w.statusCode = http.StatusOK
	}
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped http.ResponseWriter so that
// http.ResponseController can reach optional capabilities (flushing,
// hijacking, deadlines) of the underlying writer through this wrapper.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// Action returns a middleware that buffers the request body so it can be
// re-read, and wraps the response writer in a statusWriter to capture the
// response status code.
//
// Requests for the root path "/" are forwarded to next and nothing else is
// done. For other requests the whole body is read into memory and r.Body is
// replaced with an in-memory reader over the same bytes. If reading fails,
// the client gets 400 Bad Request and next is not called.
//
// NOTE(review): the buffered body and captured status are not used yet (the
// request-logging call is disabled). The body is read without a size limit,
// so large uploads are fully buffered. Consider http.MaxBytesReader if this
// is exposed to untrusted clients.
func Action() func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Path == "/" {
				next.ServeHTTP(w, r)
				// Stop here so the root request is not processed and served a
				// second time below.
				return
			}

			var bodyBytes []byte
			// Server-side requests always have a non-nil Body, but hand-built
			// requests may not, and io.ReadAll on a nil reader panics.
			if r.Body != nil {
				var err error
				bodyBytes, err = io.ReadAll(r.Body)
				if err != nil {
					http.Error(w, "can't read body", http.StatusBadRequest)
					return
				}
				// Replace the consumed body so downstream handlers can read it.
				r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
			}

			// Capture the status code written by downstream handlers.
			sw := &statusWriter{ResponseWriter: w}
			next.ServeHTTP(sw, r)
			// TODO: log the request, e.g. e.logRequest(r, bodyBytes, sw.statusCode).
		})
	}
}