diff --git a/backend/cmd/analytics/main.go b/backend/cmd/analytics/main.go index a8c224f6f..96d91358a 100644 --- a/backend/cmd/analytics/main.go +++ b/backend/cmd/analytics/main.go @@ -31,11 +31,7 @@ func main() { } defer pgConn.Close() - services, err := analytics.NewServiceBuilder(log). - WithDatabase(pgConn). - WithJWTSecret(cfg.JWTSecret, cfg.JWTSpotSecret). - WithObjectStorage(&cfg.ObjectsConfig). - Build() + services := analytics.NewServiceBuilder(log, cfg) if err != nil { log.Fatal(ctx, "can't init services: %s", err) @@ -74,7 +70,7 @@ func main() { return nil } - authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector) + authMiddleware := middleware.AuthMiddleware(services.Auth, log, excludedPaths, getPermissions, authOptionsSelector) limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute)) router, err := api.NewRouter(cfg, log, services) diff --git a/backend/cmd/spot/main.go b/backend/cmd/spot/main.go index b4204486e..8fdc0be54 100644 --- a/backend/cmd/spot/main.go +++ b/backend/cmd/spot/main.go @@ -10,7 +10,6 @@ import ( spotConfig "openreplay/backend/internal/config/spot" "openreplay/backend/internal/http/server" - "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/metrics" databaseMetrics "openreplay/backend/pkg/metrics/database" @@ -23,13 +22,7 @@ func main() { cfg := spotConfig.New(log) metrics.New(log, append(spotMetrics.List(), databaseMetrics.List()...)) - pgConn, err := pool.New(cfg.Postgres.String()) - if err != nil { - log.Fatal(ctx, "can't init postgres connection: %s", err) - } - defer pgConn.Close() - - services, err := spot.NewServiceBuilder(log, cfg, pgConn) + services, err := spot.NewServiceBuilder(log, cfg) if err != nil { log.Fatal(ctx, "can't init services: %s", err) } diff --git a/backend/pkg/analytics/api/router.go b/backend/pkg/analytics/api/router.go index 2b098dbf9..034e99c76 100644 --- a/backend/pkg/analytics/api/router.go +++ b/backend/pkg/analytics/api/router.go @@ -3,6 +3,7 @@ package api import ( "fmt" analyticsConfig "openreplay/backend/internal/config/analytics" + "openreplay/backend/pkg/analytics" "openreplay/backend/pkg/common" "openreplay/backend/pkg/common/api" "openreplay/backend/pkg/logger" @@ -14,7 +15,7 @@ type Router struct { limiter *common.UserRateLimiter } -func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) { +func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *analytics.ServiceBuilder) (*Router, error) { switch { case cfg == nil: return nil, fmt.Errorf("config is empty") @@ -25,7 +26,7 @@ func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common. } e := &Router{ - Router: api.NewRouter(log, services), + Router: api.NewRouter(log), cfg: cfg, limiter: common.NewUserRateLimiter(10, 30, 1, 5), } diff --git a/backend/pkg/analytics/builder.go b/backend/pkg/analytics/builder.go index b19a1376f..b5be5e62d 100644 --- a/backend/pkg/analytics/builder.go +++ b/backend/pkg/analytics/builder.go @@ -1,6 +1,7 @@ package analytics import ( + "openreplay/backend/internal/config/analytics" "openreplay/backend/pkg/common" "openreplay/backend/pkg/logger" ) @@ -9,17 +10,13 @@ type ServiceBuilder struct { *common.ServicesBuilder } -func NewServiceBuilder(log logger.Logger) *ServiceBuilder { +func NewServiceBuilder(log logger.Logger, cfg *analytics.Config) *ServiceBuilder { + builder := common.NewServiceBuilder(log). + WithDatabase(cfg.Postgres.String()). + WithJWTSecret(cfg.JWTSecret, cfg.JWTSpotSecret). + WithObjectStorage(&cfg.ObjectsConfig) + return &ServiceBuilder{ - ServicesBuilder: common.NewServiceBuilder(log), + ServicesBuilder: builder, } } - -func (sb *ServiceBuilder) Build() (*ServiceBuilder, error) { - // Build common services - if _, err := sb.ServicesBuilder.Build(); err != nil { - return nil, err - } - - return sb, nil -} diff --git a/backend/pkg/common/api/router.go b/backend/pkg/common/api/router.go index b766548b8..2d536379a 100644 --- a/backend/pkg/common/api/router.go +++ b/backend/pkg/common/api/router.go @@ -6,36 +6,28 @@ import ( "github.com/gorilla/mux" "io" "net/http" - "openreplay/backend/pkg/common" "openreplay/backend/pkg/logger" "sync" "time" ) type Router struct { - log logger.Logger - router *mux.Router - mutex *sync.RWMutex - services *common.ServicesBuilder + log logger.Logger + router *mux.Router + mutex *sync.RWMutex } -func NewRouter(log logger.Logger, services *common.ServicesBuilder) *Router { +func NewRouter(log logger.Logger) *Router { e := &Router{ - router: mux.NewRouter(), - log: log, - mutex: &sync.RWMutex{}, - services: services, + router: mux.NewRouter(), + log: log, + mutex: &sync.RWMutex{}, } e.router.HandleFunc("/ping", e.ping).Methods("GET") return e } -// Get return log, router, mutex, services -func (e *Router) Get() (logger.Logger, *mux.Router, *sync.RWMutex, *common.ServicesBuilder) { - return e.log, e.router, e.mutex, e.services -} - func (e *Router) ping(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } diff --git a/backend/pkg/common/builder.go b/backend/pkg/common/builder.go index 0f35c1982..97ab8fe35 100644 --- a/backend/pkg/common/builder.go +++ b/backend/pkg/common/builder.go @@ -1,7 +1,8 @@ package common import ( - "errors" + "context" + "openreplay/backend/internal/config/common" objConfig "openreplay/backend/internal/config/objectstorage" "openreplay/backend/pkg/common/api/auth" "openreplay/backend/pkg/db/postgres/pool" @@ -13,11 +14,12 @@ import ( // ServicesBuilder struct to hold service components type ServicesBuilder struct { - flaker *flakeid.Flaker - objStorage objectstorage.ObjectStorage + Config *common.Config + Flaker *flakeid.Flaker + ObjStorage objectstorage.ObjectStorage Auth auth.Auth - log logger.Logger - pgconn pool.Pool + Log logger.Logger + Pgconn pool.Pool workerID int jwtSecret string extraSecret string @@ -26,13 +28,13 @@ type ServicesBuilder struct { // NewServiceBuilder initializes the ServicesBuilder with essential components (logger) func NewServiceBuilder(log logger.Logger) *ServicesBuilder { return &ServicesBuilder{ - log: log, + Log: log, } } // WithFlaker sets the Flaker component -func (b *ServicesBuilder) WithFlaker(flaker *flakeid.Flaker) *ServicesBuilder { - b.flaker = flaker +func (b *ServicesBuilder) WithFlaker(workerID uint16) *ServicesBuilder { + b.Flaker = flakeid.NewFlaker(workerID) return b } @@ -42,19 +44,28 @@ func (b *ServicesBuilder) WithObjectStorage(config *objConfig.ObjectsConfig) *Se if err != nil { return nil } - b.objStorage = objStore + b.ObjStorage = objStore return b } // WithAuth sets the Auth component -func (b *ServicesBuilder) WithAuth(auth auth.Auth) *ServicesBuilder { - b.Auth = auth +func (b *ServicesBuilder) WithAuth(jwtSecret string, extraSecret ...string) *ServicesBuilder { + b.jwtSecret = jwtSecret + if len(extraSecret) > 0 { + b.extraSecret = extraSecret[0] + } + b.Auth = auth.NewAuth(b.Log, "jwt_iat", b.jwtSecret, b.extraSecret, b.Pgconn) return b } -// WithDatabase sets the database connection pool (Postgres pool.Pool) -func (b *ServicesBuilder) WithDatabase(pgconn pool.Pool) *ServicesBuilder { - b.pgconn = pgconn +// WithDatabase sets the database connection pool +func (b *ServicesBuilder) WithDatabase(url string) *ServicesBuilder { + pgConn, err := pool.New(url) + if err != nil { + b.Log.Fatal(context.Background(), "can't init postgres connection: %s", err) + } + + b.Pgconn = pgConn return b } @@ -72,28 +83,3 @@ func (b *ServicesBuilder) WithJWTSecret(jwtSecret string, extraSecret ...string) } return b } - -// Build finalizes the service setup and returns an instance of ServicesBuilder with all components -func (b *ServicesBuilder) Build() (*ServicesBuilder, error) { - // Initialize default components if they aren't provided - // Check if database pool is provided - if b.pgconn == nil { - return nil, errors.New("database connection pool is required") - } - - // Flaker - if b.flaker == nil { - b.flaker = flakeid.NewFlaker(uint16(b.workerID)) - } - - // Auth - if b.Auth == nil { - if b.jwtSecret == "" { - return nil, errors.New("JWT secret is required") - } - b.Auth = auth.NewAuth(b.log, "jwt_iat", b.jwtSecret, b.extraSecret, b.pgconn) - } - - // Return the fully constructed service - return b, nil -} diff --git a/backend/pkg/common/middleware/http.go b/backend/pkg/common/middleware/http.go index 4904d8aba..1566c4359 100644 --- a/backend/pkg/common/middleware/http.go +++ b/backend/pkg/common/middleware/http.go @@ -2,16 +2,14 @@ package middleware import ( "bytes" - "context" "io" "net/http" - "openreplay/backend/internal/http/util" "openreplay/backend/pkg/common" "openreplay/backend/pkg/common/api/auth" "openreplay/backend/pkg/logger" ) -type userDataKey struct{} +type userData struct{} func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -30,8 +28,9 @@ func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler { w.WriteHeader(http.StatusOK) return } - r = r.WithContext(context.WithValue(r.Context(), "httpMethod", r.Method)) - r = r.WithContext(context.WithValue(r.Context(), "url", util.SafeString(r.URL.Path))) + //r = r.WithContext(context.WithValue(r.Context(), "httpMethod", r.Method)) + //r = r.WithContext(context.WithValue(r.Context(), "url", util.SafeString(r.URL.Path))) + //r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) next.ServeHTTP(w, r) }) } @@ -39,7 +38,8 @@ func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler { // AuthMiddleware function takes dynamic parameters to handle custom authentication logic func AuthMiddleware( - services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.) + //services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.) + Auth auth.Auth, // Auth interface for authorization log logger.Logger, // Logger for logging events excludedPaths map[string]map[string]bool, // Map of excluded paths with methods getPermissions func(path string) []string, @@ -67,8 +67,8 @@ func AuthMiddleware( // Get AuthOptions for the request options := auth.Options{ - JwtColumn: services.Auth.JWTCol(), // Default JWT column from ServicesBuilder - Secret: services.Auth.Secret(), // Default secret from ServicesBuilder + JwtColumn: Auth.JWTCol(), // Default JWT column from ServicesBuilder + Secret: Auth.Secret(), // Default secret from ServicesBuilder } if authOptionsSelector != nil { @@ -85,16 +85,15 @@ func AuthMiddleware( } // Check if this request is authorized - user, err := services.Auth.IsAuthorized(authHeader, getPermissions(r.URL.Path), options) - if err != nil { - log.Warn(r.Context(), "Unauthorized request: %s", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - // Add userData to the context for downstream handlers - ctx := context.WithValue(r.Context(), userDataKey{}, user) - r = r.WithContext(ctx) + //user, err := Auth.IsAuthorized(authHeader, getPermissions(r.URL.Path), options) + //if err != nil { + // log.Warn(r.Context(), "Unauthorized request: %s", err) + // w.WriteHeader(http.StatusUnauthorized) + // return + //} + // + //// Add userData to the context for downstream handlers + //r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user})) // Call the next handler with the modified request next.ServeHTTP(w, r) @@ -104,7 +103,7 @@ func AuthMiddleware( // GetUserData Helper function to retrieve userData from the request context func GetUserData(r *http.Request) (*auth.User, bool) { - user, ok := r.Context().Value(userDataKey{}).(*auth.User) + user, ok := r.Context().Value(userData{}).(*auth.User) return user, ok } diff --git a/backend/pkg/spot/api/router.go b/backend/pkg/spot/api/router.go index a6fda7b6e..e9e8d77b4 100644 --- a/backend/pkg/spot/api/router.go +++ b/backend/pkg/spot/api/router.go @@ -1,28 +1,22 @@ package api import ( - "bytes" "fmt" - "io" "net/http" + "openreplay/backend/pkg/common" + "openreplay/backend/pkg/common/api" + "openreplay/backend/pkg/common/middleware" "openreplay/backend/pkg/spot" - "openreplay/backend/pkg/spot/auth" - "sync" "time" - "github.com/docker/distribution/context" - "github.com/gorilla/mux" - spotConfig "openreplay/backend/internal/config/spot" - "openreplay/backend/internal/http/util" "openreplay/backend/pkg/logger" ) type Router struct { + *api.Router log logger.Logger cfg *spotConfig.Config - router *mux.Router - mutex *sync.RWMutex services *spot.ServicesBuilder limiter *UserRateLimiter } @@ -37,9 +31,9 @@ func NewRouter(cfg *spotConfig.Config, log logger.Logger, services *spot.Service return nil, fmt.Errorf("logger is empty") } e := &Router{ + Router: api.NewRouter(log), log: log, cfg: cfg, - mutex: &sync.RWMutex{}, services: services, limiter: NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), } @@ -48,166 +42,175 @@ func NewRouter(cfg *spotConfig.Config, log logger.Logger, services *spot.Service } func (e *Router) init() { - e.router = mux.NewRouter() - - // Root route - e.router.HandleFunc("/", e.ping) + //e.router = mux.NewRouter() // Spot routes - e.router.HandleFunc("/v1/spots", e.createSpot).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}", e.getSpot).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}", e.updateSpot).Methods("PATCH", "OPTIONS") - e.router.HandleFunc("/v1/spots", e.getSpots).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots", e.deleteSpots).Methods("DELETE", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/comment", e.addComment).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/uploaded", e.uploadedSpot).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/video", e.getSpotVideo).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/public-key", e.getPublicKey).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/public-key", e.updatePublicKey).Methods("PATCH", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/status", e.spotStatus).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/ping", e.ping).Methods("GET", "OPTIONS") + //e.AddRoute("/v1/spots", e.createSpot, "POST") + //e.AddRoute("/v1/spots/{id}", e.getSpot, "GET") + //e.AddRoute("/v1/spots/{id}", e.updateSpot, "PATCH") + e.AddRoute("/v1/spots", e.getSpots, "GET") + //e.AddRoute("/v1/spots", e.deleteSpots, "DELETE") + //e.AddRoute("/v1/spots/{id}/comment", e.addComment, "POST") + //e.AddRoute("/v1/spots/{id}/uploaded", e.uploadedSpot, "POST") + //e.AddRoute("/v1/spots/{id}/video", e.getSpotVideo, "GET") + //e.AddRoute("/v1/spots/{id}/public-key", e.getPublicKey, "GET") + //e.AddRoute("/v1/spots/{id}/public-key", e.updatePublicKey, "PATCH") + //e.AddRoute("/v1/spots/{id}/status", e.spotStatus, "GET") + e.AddRoute("/v1/ping", e.ping, "GET") + + excludedPaths := map[string]map[string]bool{ + //"/v1/ping": {"GET": true}, + //"/v1/spots": {"POST": true}, + } + + authMiddleware := middleware.AuthMiddleware(e.services.Auth, e.log, excludedPaths, getPermissions, nil) + limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute)) + e.Use(middleware.CORS(e.cfg.UseAccessControlHeaders)) + e.Use(authMiddleware) + e.Use(limiterMiddleware) + e.Use(middleware.Action()) // CORS middleware - e.router.Use(e.corsMiddleware) - e.router.Use(e.authMiddleware) - e.router.Use(e.rateLimitMiddleware) - e.router.Use(e.actionMiddleware) + //e.router.Use(e.corsMiddleware) + //e.router.Use(e.authMiddleware) + //e.router.Use(e.rateLimitMiddleware) + //e.router.Use(e.actionMiddleware) } func (e *Router) ping(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func (e *Router) corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - if e.cfg.UseAccessControlHeaders { - // Prepare headers for preflight requests - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") - } - if r.Method == http.MethodOptions { - w.Header().Set("Cache-Control", "max-age=86400") - w.WriteHeader(http.StatusOK) - return - } - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) +//func (e *Router) corsMiddleware(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if r.URL.Path == "/" { +// next.ServeHTTP(w, r) +// } +// if e.cfg.UseAccessControlHeaders { +// // Prepare headers for preflight requests +// w.Header().Set("Access-Control-Allow-Origin", "*") +// w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") +// w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") +// } +// if r.Method == http.MethodOptions { +// w.Header().Set("Cache-Control", "max-age=86400") +// w.WriteHeader(http.StatusOK) +// return +// } +// r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) +// +// next.ServeHTTP(w, r) +// }) +//} - next.ServeHTTP(w, r) - }) -} +//func (e *Router) authMiddleware(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if r.URL.Path == "/" { +// next.ServeHTTP(w, r) +// } +// isExtension := false +// pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() +// if err != nil { +// e.log.Error(r.Context(), "failed to get path template: %s", err) +// } else { +// if pathTemplate == "/v1/ping" || +// (pathTemplate == "/v1/spots" && r.Method == "POST") || +// (pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") { +// isExtension = true +// } +// } +// +// // Check if the request is authorized +// user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), isExtension) +// if err != nil { +// e.log.Warn(r.Context(), "Unauthorized request: %s", err) +// if !isSpotWithKeyRequest(r) { +// w.WriteHeader(http.StatusUnauthorized) +// return +// } +// +// user, err = e.services.Keys.IsValid(r.URL.Query().Get("key")) +// if err != nil { +// e.log.Warn(r.Context(), "Wrong public key: %s", err) +// w.WriteHeader(http.StatusUnauthorized) +// return +// } +// } +// +// r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user})) +// next.ServeHTTP(w, r) +// }) +//} -func (e *Router) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - isExtension := false - pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() - if err != nil { - e.log.Error(r.Context(), "failed to get path template: %s", err) - } else { - if pathTemplate == "/v1/ping" || - (pathTemplate == "/v1/spots" && r.Method == "POST") || - (pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") { - isExtension = true - } - } +//func isSpotWithKeyRequest(r *http.Request) bool { +// pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() +// if err != nil { +// return false +// } +// getSpotPrefix := "/v1/spots/{id}" // GET +// addCommentPrefix := "/v1/spots/{id}/comment" // POST +// getStatusPrefix := "/v1/spots/{id}/status" // GET +// if (pathTemplate == getSpotPrefix && r.Method == "GET") || +// (pathTemplate == addCommentPrefix && r.Method == "POST") || +// (pathTemplate == getStatusPrefix && r.Method == "GET") { +// return true +// } +// return false +//} - // Check if the request is authorized - user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), isExtension) - if err != nil { - e.log.Warn(r.Context(), "Unauthorized request: %s", err) - if !isSpotWithKeyRequest(r) { - w.WriteHeader(http.StatusUnauthorized) - return - } +//func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if r.URL.Path == "/" { +// next.ServeHTTP(w, r) +// } +// user := r.Context().Value("userData").(*auth.User) +// rl := e.limiter.GetRateLimiter(user.ID) +// +// if !rl.Allow() { +// http.Error(w, "Too Many Requests", http.StatusTooManyRequests) +// return +// } +// next.ServeHTTP(w, r) +// }) +//} - user, err = e.services.Keys.IsValid(r.URL.Query().Get("key")) - if err != nil { - e.log.Warn(r.Context(), "Wrong public key: %s", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - } +//type statusWriter struct { +// http.ResponseWriter +// statusCode int +//} +// +//func (w *statusWriter) WriteHeader(statusCode int) { +// w.statusCode = statusCode +// w.ResponseWriter.WriteHeader(statusCode) +//} +// +//func (w *statusWriter) Write(b []byte) (int, error) { +// if w.statusCode == 0 { +// w.statusCode = http.StatusOK // Default status code is 200 +// } +// return w.ResponseWriter.Write(b) +//} - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user})) - next.ServeHTTP(w, r) - }) -} +//func (e *Router) actionMiddleware(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if r.URL.Path == "/" { +// next.ServeHTTP(w, r) +// } +// // Read body and restore the io.ReadCloser to its original state +// bodyBytes, err := io.ReadAll(r.Body) +// if err != nil { +// http.Error(w, "can't read body", http.StatusBadRequest) +// return +// } +// r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) +// // Use custom response writer to get the status code +// sw := &statusWriter{ResponseWriter: w} +// // Serve the request +// next.ServeHTTP(sw, r) +// e.logRequest(r, bodyBytes, sw.statusCode) +// }) +//} -func isSpotWithKeyRequest(r *http.Request) bool { - pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() - if err != nil { - return false - } - getSpotPrefix := "/v1/spots/{id}" // GET - addCommentPrefix := "/v1/spots/{id}/comment" // POST - getStatusPrefix := "/v1/spots/{id}/status" // GET - if (pathTemplate == getSpotPrefix && r.Method == "GET") || - (pathTemplate == addCommentPrefix && r.Method == "POST") || - (pathTemplate == getStatusPrefix && r.Method == "GET") { - return true - } - return false -} - -func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - user := r.Context().Value("userData").(*auth.User) - rl := e.limiter.GetRateLimiter(user.ID) - - if !rl.Allow() { - http.Error(w, "Too Many Requests", http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) -} - -type statusWriter struct { - http.ResponseWriter - statusCode int -} - -func (w *statusWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *statusWriter) Write(b []byte) (int, error) { - if w.statusCode == 0 { - w.statusCode = http.StatusOK // Default status code is 200 - } - return w.ResponseWriter.Write(b) -} - -func (e *Router) actionMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - // Read body and restore the io.ReadCloser to its original state - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - // Use custom response writer to get the status code - sw := &statusWriter{ResponseWriter: w} - // Serve the request - next.ServeHTTP(sw, r) - e.logRequest(r, bodyBytes, sw.statusCode) - }) -} - -func (e *Router) GetHandler() http.Handler { - return e.router -} +//func (e *Router) GetHandler() http.Handler { +// return e.router +//} diff --git a/backend/pkg/spot/builder.go b/backend/pkg/spot/builder.go index 047318844..1a1ebf1ae 100644 --- a/backend/pkg/spot/builder.go +++ b/backend/pkg/spot/builder.go @@ -2,38 +2,35 @@ package spot import ( "openreplay/backend/internal/config/spot" - "openreplay/backend/pkg/db/postgres/pool" - "openreplay/backend/pkg/flakeid" + "openreplay/backend/pkg/common" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/objectstorage" - "openreplay/backend/pkg/objectstorage/store" - "openreplay/backend/pkg/spot/auth" "openreplay/backend/pkg/spot/service" "openreplay/backend/pkg/spot/transcoder" ) type ServicesBuilder struct { - Flaker *flakeid.Flaker - ObjStorage objectstorage.ObjectStorage - Auth auth.Auth + *common.ServicesBuilder Spots service.Spots Keys service.Keys Transcoder transcoder.Transcoder + cfg *spot.Config } -func NewServiceBuilder(log logger.Logger, cfg *spot.Config, pgconn pool.Pool) (*ServicesBuilder, error) { - objStore, err := store.NewStore(&cfg.ObjectsConfig) - if err != nil { - return nil, err - } - flaker := flakeid.NewFlaker(cfg.WorkerID) - spots := service.NewSpots(log, pgconn, flaker) +func NewServiceBuilder(log logger.Logger, cfg *spot.Config) (*ServicesBuilder, error) { + builder := common.NewServiceBuilder(log). + WithDatabase(cfg.Postgres.String()). + WithAuth(cfg.JWTSecret, cfg.JWTSpotSecret). + WithObjectStorage(&cfg.ObjectsConfig) + + keys := service.NewKeys(log, builder.Pgconn) + spots := service.NewSpots(log, builder.Pgconn, builder.Flaker) + tc := transcoder.NewTranscoder(cfg, log, builder.ObjStorage, builder.Pgconn, spots) + return &ServicesBuilder{ - Flaker: flaker, - ObjStorage: objStore, - Auth: auth.NewAuth(log, cfg.JWTSecret, cfg.JWTSpotSecret, pgconn), - Spots: spots, - Keys: service.NewKeys(log, pgconn), - Transcoder: transcoder.NewTranscoder(cfg, log, objStore, pgconn, spots), + ServicesBuilder: builder, + Spots: spots, + Keys: keys, + Transcoder: tc, + cfg: cfg, }, nil }