diff --git a/backend/cmd/analytics/main.go b/backend/cmd/analytics/main.go index 747cc5910..6881e4625 100644 --- a/backend/cmd/analytics/main.go +++ b/backend/cmd/analytics/main.go @@ -3,8 +3,9 @@ package main import ( "context" "openreplay/backend/internal/http/server" - "openreplay/backend/pkg/analytics" "openreplay/backend/pkg/analytics/api" + "openreplay/backend/pkg/common" + "openreplay/backend/pkg/common/middleware" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" "os" @@ -25,12 +26,42 @@ func main() { } defer pgConn.Close() - services, err := analytics.NewServiceBuilder(log, cfg, pgConn) + builder := common.NewServiceBuilder(log) + services, err := builder. + WithDatabase(pgConn). + WithJWTSecret(cfg.JWTSecret). + Build() + if err != nil { log.Fatal(ctx, "can't init services: %s", err) } + //services, err := analytics.NewServiceBuilder(log, cfg, pgConn) + //if err != nil { + // log.Fatal(ctx, "can't init services: %s", err) + //} + + // Define excluded paths for this service + excludedPaths := map[string]map[string]bool{ + //"/v1/ping": {"GET": true}, + //"/v1/spots": {"POST": true}, + } + + // Define permission fetching logic + getPermissions := func(path string) []string { + // Example logic to return permissions based on path + if path == "/v1/admin" { + return []string{"admin"} + } + return []string{"user"} + } + router, err := api.NewRouter(cfg, log, services) + router.GetRouter().Use(middleware.CORS(cfg.UseAccessControlHeaders)) + router.GetRouter().Use(middleware.AuthMiddleware(services, log, excludedPaths, getPermissions)) + router.GetRouter().Use(middleware.RateLimit) + router.GetRouter().Use(middleware.Action) + if err != nil { log.Fatal(ctx, "failed while creating router: %s", err) } diff --git a/backend/pkg/analytics/api/router.go b/backend/pkg/analytics/api/router.go index 7608e84ae..4a1fbca49 100644 --- a/backend/pkg/analytics/api/router.go +++ b/backend/pkg/analytics/api/router.go @@ -2,12 +2,13 @@ package api import ( "fmt" - "github.com/gorilla/mux" "net/http" analyticsConfig "openreplay/backend/internal/config/analytics" - "openreplay/backend/pkg/analytics" + "openreplay/backend/pkg/common" "openreplay/backend/pkg/logger" "sync" + + "github.com/gorilla/mux" ) type Router struct { @@ -15,10 +16,10 @@ type Router struct { cfg *analyticsConfig.Config router *mux.Router mutex *sync.RWMutex - services *analytics.ServicesBuilder + services *common.ServicesBuilder } -func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *analytics.ServicesBuilder) (*Router, error) { +func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) { switch { case cfg == nil: return nil, fmt.Errorf("config is empty") @@ -44,10 +45,7 @@ func (e *Router) init() { e.router.HandleFunc("/", e.ping) // Analytics routes - // e.router.HandleFunc("/v1/analytics", e.createAnalytics).Methods("POST", "OPTIONS") - // e.router.HandleFunc("/v1/analytics/{id}", e.getAnalytics).Methods("GET", "OPTIONS") - // e.router.HandleFunc("/v1/analytics/{id}", e.updateAnalytics).Methods("PATCH", "OPTIONS") - // e.router.HandleFunc("/v1/analytics", e.getAnalytics).Methods("GET", "OPTIONS") + e.router.HandleFunc("/v1/analytics/{id}", e.getAnalytics).Methods("GET", "OPTIONS") } func (e *Router) ping(w http.ResponseWriter, r *http.Request) { @@ -57,3 +55,17 @@ func (e *Router) ping(w http.ResponseWriter, r *http.Request) { func (e *Router) GetHandler() http.Handler { return e.router } + +func (e *Router) GetRouter() *mux.Router { + return e.router +} + +func (e *Router) getAnalytics(w http.ResponseWriter, r *http.Request) { + //w.WriteHeader(http.StatusOK) + vars := mux.Vars(r) + id := vars["id"] + e.log.Info(r.Context(), id) + w.WriteHeader(http.StatusOK) + + //e.ResponseWithJSON(w, http.StatusOK, map[string]string{"message": "getAnalytics"}) +} diff --git a/backend/pkg/common/api/router.go b/backend/pkg/common/api/router.go new file mode 100644 index 000000000..734604d18 --- /dev/null +++ b/backend/pkg/common/api/router.go @@ -0,0 +1,17 @@ +package api + +import ( + "github.com/gorilla/mux" + analyticsConfig "openreplay/backend/internal/config/analytics" + "openreplay/backend/pkg/analytics" + "openreplay/backend/pkg/logger" + "sync" +) + +type Router struct { + log logger.Logger + cfg *analyticsConfig.Config + router *mux.Router + mutex *sync.RWMutex + services *analytics.ServicesBuilder +} diff --git a/backend/pkg/common/builder.go b/backend/pkg/common/builder.go index 805d0c79a..23d20a592 100644 --- a/backend/pkg/common/builder.go +++ b/backend/pkg/common/builder.go @@ -1 +1,90 @@ package common + +import ( + "errors" + "openreplay/backend/pkg/common/api/auth" + "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/flakeid" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/objectstorage" +) + +// ServicesBuilder struct to hold service components +type ServicesBuilder struct { + flaker *flakeid.Flaker + objStorage objectstorage.ObjectStorage + Auth auth.Auth + log logger.Logger + pgconn pool.Pool + workerID int + jwtSecret string +} + +// NewServiceBuilder initializes the ServicesBuilder with essential components (logger) +func NewServiceBuilder(log logger.Logger) *ServicesBuilder { + return &ServicesBuilder{ + log: log, + } +} + +// WithFlaker sets the Flaker component +func (b *ServicesBuilder) WithFlaker(flaker *flakeid.Flaker) *ServicesBuilder { + b.flaker = flaker + return b +} + +// WithObjectStorage sets the Object Storage component +func (b *ServicesBuilder) WithObjectStorage(objStorage objectstorage.ObjectStorage) *ServicesBuilder { + b.objStorage = objStorage + return b +} + +// WithAuth sets the Auth component +func (b *ServicesBuilder) WithAuth(auth auth.Auth) *ServicesBuilder { + b.Auth = auth + return b +} + +// WithDatabase sets the database connection pool (Postgres pool.Pool) +func (b *ServicesBuilder) WithDatabase(pgconn pool.Pool) *ServicesBuilder { + b.pgconn = pgconn + return b +} + +// WithWorkerID sets the WorkerID for Flaker +func (b *ServicesBuilder) WithWorkerID(workerID int) *ServicesBuilder { + b.workerID = workerID + return b +} + +// WithJWTSecret sets the JWT secret for Auth +func (b *ServicesBuilder) WithJWTSecret(jwtSecret string) *ServicesBuilder { + b.jwtSecret = jwtSecret + return b +} + +// Build finalizes the service setup and returns an instance of ServicesBuilder with all components +func (b *ServicesBuilder) Build() (*ServicesBuilder, error) { + // Initialize default components if they aren't provided + + // Check if database pool is provided + if b.pgconn == nil { + return nil, errors.New("database connection pool is required") + } + + // Flaker + if b.flaker == nil { + b.flaker = flakeid.NewFlaker(uint16(b.workerID)) + } + + // Auth + if b.Auth == nil { + if b.jwtSecret == "" { + return nil, errors.New("JWT secret is required") + } + b.Auth = auth.NewAuth(b.log, b.jwtSecret, b.pgconn) + } + + // Return the fully constructed service + return b, nil +} diff --git a/backend/pkg/common/middleware/http.go b/backend/pkg/common/middleware/http.go new file mode 100644 index 000000000..0c4630514 --- /dev/null +++ b/backend/pkg/common/middleware/http.go @@ -0,0 +1,101 @@ +package middleware + +import ( + "context" + "github.com/gorilla/mux" + "net/http" + "openreplay/backend/internal/http/util" + "openreplay/backend/pkg/common" + "openreplay/backend/pkg/common/api/auth" + "openreplay/backend/pkg/logger" +) + +type userDataKey struct{} + +func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + next.ServeHTTP(w, r) + return + } + if useAccessControlHeaders { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") + } + if r.Method == http.MethodOptions { + w.Header().Set("Cache-Control", "max-age=86400") + w.WriteHeader(http.StatusOK) + return + } + r = r.WithContext(context.WithValue(r.Context(), "httpMethod", r.Method)) + r = r.WithContext(context.WithValue(r.Context(), "url", util.SafeString(r.URL.Path))) + next.ServeHTTP(w, r) + }) + } +} + +// AuthMiddleware function takes dynamic parameters to handle custom authentication logic +func AuthMiddleware( + services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.) + log logger.Logger, // Logger for logging events + excludedPaths map[string]map[string]bool, // Map of excluded paths with methods + getPermissions func(path string) []string, // Function to retrieve permissions for a path +) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + next.ServeHTTP(w, r) + return + } + + // Exclude specific paths and methods from auth + if methods, ok := excludedPaths[r.URL.Path]; ok && methods[r.Method] { + next.ServeHTTP(w, r) + return + } + + // Check if the route is dynamic and get the path template + pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() + if err != nil { + log.Error(r.Context(), "failed to get path template: %s", err) + } + + // Check if this request is authorized + user, err := services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), pathTemplate != "") + if err != nil { + log.Warn(r.Context(), "Unauthorized request: %s", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Add userData to the context for downstream handlers + ctx := context.WithValue(r.Context(), userDataKey{}, user) + r = r.WithContext(ctx) + + // Call the next handler with the modified request + next.ServeHTTP(w, r) + }) + } +} + +// GetUserData Helper function to retrieve userData from the request context +func GetUserData(r *http.Request) (*auth.User, bool) { + user, ok := r.Context().Value(userDataKey{}).(*auth.User) + return user, ok +} + +func RateLimit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Implement rate limiting logic here + next.ServeHTTP(w, r) + }) +} + +func Action(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Implement action logging or processing logic here + next.ServeHTTP(w, r) + }) +}