feat(backend): spot - refactor
This commit is contained in:
parent
0f62a291c3
commit
a38130486a
9 changed files with 243 additions and 279 deletions
|
|
@ -31,11 +31,7 @@ func main() {
|
|||
}
|
||||
defer pgConn.Close()
|
||||
|
||||
services, err := analytics.NewServiceBuilder(log).
|
||||
WithDatabase(pgConn).
|
||||
WithJWTSecret(cfg.JWTSecret, cfg.JWTSpotSecret).
|
||||
WithObjectStorage(&cfg.ObjectsConfig).
|
||||
Build()
|
||||
services := analytics.NewServiceBuilder(log, cfg)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(ctx, "can't init services: %s", err)
|
||||
|
|
@ -74,7 +70,7 @@ func main() {
|
|||
return nil
|
||||
}
|
||||
|
||||
authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector)
|
||||
authMiddleware := middleware.AuthMiddleware(services.Auth, log, excludedPaths, getPermissions, authOptionsSelector)
|
||||
limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute))
|
||||
|
||||
router, err := api.NewRouter(cfg, log, services)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
|
||||
spotConfig "openreplay/backend/internal/config/spot"
|
||||
"openreplay/backend/internal/http/server"
|
||||
"openreplay/backend/pkg/db/postgres/pool"
|
||||
"openreplay/backend/pkg/logger"
|
||||
"openreplay/backend/pkg/metrics"
|
||||
databaseMetrics "openreplay/backend/pkg/metrics/database"
|
||||
|
|
@ -23,13 +22,7 @@ func main() {
|
|||
cfg := spotConfig.New(log)
|
||||
metrics.New(log, append(spotMetrics.List(), databaseMetrics.List()...))
|
||||
|
||||
pgConn, err := pool.New(cfg.Postgres.String())
|
||||
if err != nil {
|
||||
log.Fatal(ctx, "can't init postgres connection: %s", err)
|
||||
}
|
||||
defer pgConn.Close()
|
||||
|
||||
services, err := spot.NewServiceBuilder(log, cfg, pgConn)
|
||||
services, err := spot.NewServiceBuilder(log, cfg)
|
||||
if err != nil {
|
||||
log.Fatal(ctx, "can't init services: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"fmt"
|
||||
analyticsConfig "openreplay/backend/internal/config/analytics"
|
||||
"openreplay/backend/pkg/analytics"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/common/api"
|
||||
"openreplay/backend/pkg/logger"
|
||||
|
|
@ -14,7 +15,7 @@ type Router struct {
|
|||
limiter *common.UserRateLimiter
|
||||
}
|
||||
|
||||
func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.ServicesBuilder) (*Router, error) {
|
||||
func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *analytics.ServiceBuilder) (*Router, error) {
|
||||
switch {
|
||||
case cfg == nil:
|
||||
return nil, fmt.Errorf("config is empty")
|
||||
|
|
@ -25,7 +26,7 @@ func NewRouter(cfg *analyticsConfig.Config, log logger.Logger, services *common.
|
|||
}
|
||||
|
||||
e := &Router{
|
||||
Router: api.NewRouter(log, services),
|
||||
Router: api.NewRouter(log),
|
||||
cfg: cfg,
|
||||
limiter: common.NewUserRateLimiter(10, 30, 1, 5),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package analytics
|
||||
|
||||
import (
|
||||
"openreplay/backend/internal/config/analytics"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/logger"
|
||||
)
|
||||
|
|
@ -9,17 +10,13 @@ type ServiceBuilder struct {
|
|||
*common.ServicesBuilder
|
||||
}
|
||||
|
||||
func NewServiceBuilder(log logger.Logger) *ServiceBuilder {
|
||||
func NewServiceBuilder(log logger.Logger, cfg *analytics.Config) *ServiceBuilder {
|
||||
builder := common.NewServiceBuilder(log).
|
||||
WithDatabase(cfg.Postgres.String()).
|
||||
WithJWTSecret(cfg.JWTSecret, cfg.JWTSpotSecret).
|
||||
WithObjectStorage(&cfg.ObjectsConfig)
|
||||
|
||||
return &ServiceBuilder{
|
||||
ServicesBuilder: common.NewServiceBuilder(log),
|
||||
ServicesBuilder: builder,
|
||||
}
|
||||
}
|
||||
|
||||
func (sb *ServiceBuilder) Build() (*ServiceBuilder, error) {
|
||||
// Build common services
|
||||
if _, err := sb.ServicesBuilder.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sb, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,36 +6,28 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"io"
|
||||
"net/http"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/logger"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
log logger.Logger
|
||||
router *mux.Router
|
||||
mutex *sync.RWMutex
|
||||
services *common.ServicesBuilder
|
||||
log logger.Logger
|
||||
router *mux.Router
|
||||
mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRouter(log logger.Logger, services *common.ServicesBuilder) *Router {
|
||||
func NewRouter(log logger.Logger) *Router {
|
||||
e := &Router{
|
||||
router: mux.NewRouter(),
|
||||
log: log,
|
||||
mutex: &sync.RWMutex{},
|
||||
services: services,
|
||||
router: mux.NewRouter(),
|
||||
log: log,
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
e.router.HandleFunc("/ping", e.ping).Methods("GET")
|
||||
return e
|
||||
}
|
||||
|
||||
// Get return log, router, mutex, services
|
||||
func (e *Router) Get() (logger.Logger, *mux.Router, *sync.RWMutex, *common.ServicesBuilder) {
|
||||
return e.log, e.router, e.mutex, e.services
|
||||
}
|
||||
|
||||
func (e *Router) ping(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"openreplay/backend/internal/config/common"
|
||||
objConfig "openreplay/backend/internal/config/objectstorage"
|
||||
"openreplay/backend/pkg/common/api/auth"
|
||||
"openreplay/backend/pkg/db/postgres/pool"
|
||||
|
|
@ -13,11 +14,12 @@ import (
|
|||
|
||||
// ServicesBuilder struct to hold service components
|
||||
type ServicesBuilder struct {
|
||||
flaker *flakeid.Flaker
|
||||
objStorage objectstorage.ObjectStorage
|
||||
Config *common.Config
|
||||
Flaker *flakeid.Flaker
|
||||
ObjStorage objectstorage.ObjectStorage
|
||||
Auth auth.Auth
|
||||
log logger.Logger
|
||||
pgconn pool.Pool
|
||||
Log logger.Logger
|
||||
Pgconn pool.Pool
|
||||
workerID int
|
||||
jwtSecret string
|
||||
extraSecret string
|
||||
|
|
@ -26,13 +28,13 @@ type ServicesBuilder struct {
|
|||
// NewServiceBuilder initializes the ServicesBuilder with essential components (logger)
|
||||
func NewServiceBuilder(log logger.Logger) *ServicesBuilder {
|
||||
return &ServicesBuilder{
|
||||
log: log,
|
||||
Log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// WithFlaker sets the Flaker component
|
||||
func (b *ServicesBuilder) WithFlaker(flaker *flakeid.Flaker) *ServicesBuilder {
|
||||
b.flaker = flaker
|
||||
func (b *ServicesBuilder) WithFlaker(workerID uint16) *ServicesBuilder {
|
||||
b.Flaker = flakeid.NewFlaker(workerID)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -42,19 +44,28 @@ func (b *ServicesBuilder) WithObjectStorage(config *objConfig.ObjectsConfig) *Se
|
|||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
b.objStorage = objStore
|
||||
b.ObjStorage = objStore
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAuth sets the Auth component
|
||||
func (b *ServicesBuilder) WithAuth(auth auth.Auth) *ServicesBuilder {
|
||||
b.Auth = auth
|
||||
func (b *ServicesBuilder) WithAuth(jwtSecret string, extraSecret ...string) *ServicesBuilder {
|
||||
b.jwtSecret = jwtSecret
|
||||
if len(extraSecret) > 0 {
|
||||
b.extraSecret = extraSecret[0]
|
||||
}
|
||||
b.Auth = auth.NewAuth(b.Log, "jwt_iat", b.jwtSecret, b.extraSecret, b.Pgconn)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDatabase sets the database connection pool (Postgres pool.Pool)
|
||||
func (b *ServicesBuilder) WithDatabase(pgconn pool.Pool) *ServicesBuilder {
|
||||
b.pgconn = pgconn
|
||||
// WithDatabase sets the database connection pool
|
||||
func (b *ServicesBuilder) WithDatabase(url string) *ServicesBuilder {
|
||||
pgConn, err := pool.New(url)
|
||||
if err != nil {
|
||||
b.Log.Fatal(context.Background(), "can't init postgres connection: %s", err)
|
||||
}
|
||||
|
||||
b.Pgconn = pgConn
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -72,28 +83,3 @@ func (b *ServicesBuilder) WithJWTSecret(jwtSecret string, extraSecret ...string)
|
|||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the service setup and returns an instance of ServicesBuilder with all components
|
||||
func (b *ServicesBuilder) Build() (*ServicesBuilder, error) {
|
||||
// Initialize default components if they aren't provided
|
||||
// Check if database pool is provided
|
||||
if b.pgconn == nil {
|
||||
return nil, errors.New("database connection pool is required")
|
||||
}
|
||||
|
||||
// Flaker
|
||||
if b.flaker == nil {
|
||||
b.flaker = flakeid.NewFlaker(uint16(b.workerID))
|
||||
}
|
||||
|
||||
// Auth
|
||||
if b.Auth == nil {
|
||||
if b.jwtSecret == "" {
|
||||
return nil, errors.New("JWT secret is required")
|
||||
}
|
||||
b.Auth = auth.NewAuth(b.log, "jwt_iat", b.jwtSecret, b.extraSecret, b.pgconn)
|
||||
}
|
||||
|
||||
// Return the fully constructed service
|
||||
return b, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,16 +2,14 @@ package middleware
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"openreplay/backend/internal/http/util"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/common/api/auth"
|
||||
"openreplay/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type userDataKey struct{}
|
||||
type userData struct{}
|
||||
|
||||
func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
|
@ -30,8 +28,9 @@ func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), "httpMethod", r.Method))
|
||||
r = r.WithContext(context.WithValue(r.Context(), "url", util.SafeString(r.URL.Path)))
|
||||
//r = r.WithContext(context.WithValue(r.Context(), "httpMethod", r.Method))
|
||||
//r = r.WithContext(context.WithValue(r.Context(), "url", util.SafeString(r.URL.Path)))
|
||||
//r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)}))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
@ -39,7 +38,8 @@ func CORS(useAccessControlHeaders bool) func(http.Handler) http.Handler {
|
|||
|
||||
// AuthMiddleware function takes dynamic parameters to handle custom authentication logic
|
||||
func AuthMiddleware(
|
||||
services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.)
|
||||
//services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.)
|
||||
Auth auth.Auth, // Auth interface for authorization
|
||||
log logger.Logger, // Logger for logging events
|
||||
excludedPaths map[string]map[string]bool, // Map of excluded paths with methods
|
||||
getPermissions func(path string) []string,
|
||||
|
|
@ -67,8 +67,8 @@ func AuthMiddleware(
|
|||
|
||||
// Get AuthOptions for the request
|
||||
options := auth.Options{
|
||||
JwtColumn: services.Auth.JWTCol(), // Default JWT column from ServicesBuilder
|
||||
Secret: services.Auth.Secret(), // Default secret from ServicesBuilder
|
||||
JwtColumn: Auth.JWTCol(), // Default JWT column from ServicesBuilder
|
||||
Secret: Auth.Secret(), // Default secret from ServicesBuilder
|
||||
}
|
||||
|
||||
if authOptionsSelector != nil {
|
||||
|
|
@ -85,16 +85,15 @@ func AuthMiddleware(
|
|||
}
|
||||
|
||||
// Check if this request is authorized
|
||||
user, err := services.Auth.IsAuthorized(authHeader, getPermissions(r.URL.Path), options)
|
||||
if err != nil {
|
||||
log.Warn(r.Context(), "Unauthorized request: %s", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add userData to the context for downstream handlers
|
||||
ctx := context.WithValue(r.Context(), userDataKey{}, user)
|
||||
r = r.WithContext(ctx)
|
||||
//user, err := Auth.IsAuthorized(authHeader, getPermissions(r.URL.Path), options)
|
||||
//if err != nil {
|
||||
// log.Warn(r.Context(), "Unauthorized request: %s", err)
|
||||
// w.WriteHeader(http.StatusUnauthorized)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//// Add userData to the context for downstream handlers
|
||||
//r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user}))
|
||||
|
||||
// Call the next handler with the modified request
|
||||
next.ServeHTTP(w, r)
|
||||
|
|
@ -104,7 +103,7 @@ func AuthMiddleware(
|
|||
|
||||
// GetUserData Helper function to retrieve userData from the request context
|
||||
func GetUserData(r *http.Request) (*auth.User, bool) {
|
||||
user, ok := r.Context().Value(userDataKey{}).(*auth.User)
|
||||
user, ok := r.Context().Value(userData{}).(*auth.User)
|
||||
return user, ok
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +1,22 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/common/api"
|
||||
"openreplay/backend/pkg/common/middleware"
|
||||
"openreplay/backend/pkg/spot"
|
||||
"openreplay/backend/pkg/spot/auth"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/docker/distribution/context"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
spotConfig "openreplay/backend/internal/config/spot"
|
||||
"openreplay/backend/internal/http/util"
|
||||
"openreplay/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
*api.Router
|
||||
log logger.Logger
|
||||
cfg *spotConfig.Config
|
||||
router *mux.Router
|
||||
mutex *sync.RWMutex
|
||||
services *spot.ServicesBuilder
|
||||
limiter *UserRateLimiter
|
||||
}
|
||||
|
|
@ -37,9 +31,9 @@ func NewRouter(cfg *spotConfig.Config, log logger.Logger, services *spot.Service
|
|||
return nil, fmt.Errorf("logger is empty")
|
||||
}
|
||||
e := &Router{
|
||||
Router: api.NewRouter(log),
|
||||
log: log,
|
||||
cfg: cfg,
|
||||
mutex: &sync.RWMutex{},
|
||||
services: services,
|
||||
limiter: NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute),
|
||||
}
|
||||
|
|
@ -48,166 +42,175 @@ func NewRouter(cfg *spotConfig.Config, log logger.Logger, services *spot.Service
|
|||
}
|
||||
|
||||
func (e *Router) init() {
|
||||
e.router = mux.NewRouter()
|
||||
|
||||
// Root route
|
||||
e.router.HandleFunc("/", e.ping)
|
||||
//e.router = mux.NewRouter()
|
||||
|
||||
// Spot routes
|
||||
e.router.HandleFunc("/v1/spots", e.createSpot).Methods("POST", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}", e.getSpot).Methods("GET", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}", e.updateSpot).Methods("PATCH", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots", e.getSpots).Methods("GET", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots", e.deleteSpots).Methods("DELETE", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/comment", e.addComment).Methods("POST", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/uploaded", e.uploadedSpot).Methods("POST", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/video", e.getSpotVideo).Methods("GET", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/public-key", e.getPublicKey).Methods("GET", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/public-key", e.updatePublicKey).Methods("PATCH", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/spots/{id}/status", e.spotStatus).Methods("GET", "OPTIONS")
|
||||
e.router.HandleFunc("/v1/ping", e.ping).Methods("GET", "OPTIONS")
|
||||
//e.AddRoute("/v1/spots", e.createSpot, "POST")
|
||||
//e.AddRoute("/v1/spots/{id}", e.getSpot, "GET")
|
||||
//e.AddRoute("/v1/spots/{id}", e.updateSpot, "PATCH")
|
||||
e.AddRoute("/v1/spots", e.getSpots, "GET")
|
||||
//e.AddRoute("/v1/spots", e.deleteSpots, "DELETE")
|
||||
//e.AddRoute("/v1/spots/{id}/comment", e.addComment, "POST")
|
||||
//e.AddRoute("/v1/spots/{id}/uploaded", e.uploadedSpot, "POST")
|
||||
//e.AddRoute("/v1/spots/{id}/video", e.getSpotVideo, "GET")
|
||||
//e.AddRoute("/v1/spots/{id}/public-key", e.getPublicKey, "GET")
|
||||
//e.AddRoute("/v1/spots/{id}/public-key", e.updatePublicKey, "PATCH")
|
||||
//e.AddRoute("/v1/spots/{id}/status", e.spotStatus, "GET")
|
||||
e.AddRoute("/v1/ping", e.ping, "GET")
|
||||
|
||||
excludedPaths := map[string]map[string]bool{
|
||||
//"/v1/ping": {"GET": true},
|
||||
//"/v1/spots": {"POST": true},
|
||||
}
|
||||
|
||||
authMiddleware := middleware.AuthMiddleware(e.services.Auth, e.log, excludedPaths, getPermissions, nil)
|
||||
limiterMiddleware := middleware.RateLimit(common.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute))
|
||||
e.Use(middleware.CORS(e.cfg.UseAccessControlHeaders))
|
||||
e.Use(authMiddleware)
|
||||
e.Use(limiterMiddleware)
|
||||
e.Use(middleware.Action())
|
||||
|
||||
// CORS middleware
|
||||
e.router.Use(e.corsMiddleware)
|
||||
e.router.Use(e.authMiddleware)
|
||||
e.router.Use(e.rateLimitMiddleware)
|
||||
e.router.Use(e.actionMiddleware)
|
||||
//e.router.Use(e.corsMiddleware)
|
||||
//e.router.Use(e.authMiddleware)
|
||||
//e.router.Use(e.rateLimitMiddleware)
|
||||
//e.router.Use(e.actionMiddleware)
|
||||
}
|
||||
|
||||
func (e *Router) ping(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (e *Router) corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
if e.cfg.UseAccessControlHeaders {
|
||||
// Prepare headers for preflight requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.Header().Set("Cache-Control", "max-age=86400")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)}))
|
||||
//func (e *Router) corsMiddleware(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// if r.URL.Path == "/" {
|
||||
// next.ServeHTTP(w, r)
|
||||
// }
|
||||
// if e.cfg.UseAccessControlHeaders {
|
||||
// // Prepare headers for preflight requests
|
||||
// w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
// w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE")
|
||||
// w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding")
|
||||
// }
|
||||
// if r.Method == http.MethodOptions {
|
||||
// w.Header().Set("Cache-Control", "max-age=86400")
|
||||
// w.WriteHeader(http.StatusOK)
|
||||
// return
|
||||
// }
|
||||
// r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)}))
|
||||
//
|
||||
// next.ServeHTTP(w, r)
|
||||
// })
|
||||
//}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
//func (e *Router) authMiddleware(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// if r.URL.Path == "/" {
|
||||
// next.ServeHTTP(w, r)
|
||||
// }
|
||||
// isExtension := false
|
||||
// pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate()
|
||||
// if err != nil {
|
||||
// e.log.Error(r.Context(), "failed to get path template: %s", err)
|
||||
// } else {
|
||||
// if pathTemplate == "/v1/ping" ||
|
||||
// (pathTemplate == "/v1/spots" && r.Method == "POST") ||
|
||||
// (pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") {
|
||||
// isExtension = true
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Check if the request is authorized
|
||||
// user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), isExtension)
|
||||
// if err != nil {
|
||||
// e.log.Warn(r.Context(), "Unauthorized request: %s", err)
|
||||
// if !isSpotWithKeyRequest(r) {
|
||||
// w.WriteHeader(http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// user, err = e.services.Keys.IsValid(r.URL.Query().Get("key"))
|
||||
// if err != nil {
|
||||
// e.log.Warn(r.Context(), "Wrong public key: %s", err)
|
||||
// w.WriteHeader(http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user}))
|
||||
// next.ServeHTTP(w, r)
|
||||
// })
|
||||
//}
|
||||
|
||||
func (e *Router) authMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
isExtension := false
|
||||
pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate()
|
||||
if err != nil {
|
||||
e.log.Error(r.Context(), "failed to get path template: %s", err)
|
||||
} else {
|
||||
if pathTemplate == "/v1/ping" ||
|
||||
(pathTemplate == "/v1/spots" && r.Method == "POST") ||
|
||||
(pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") {
|
||||
isExtension = true
|
||||
}
|
||||
}
|
||||
//func isSpotWithKeyRequest(r *http.Request) bool {
|
||||
// pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate()
|
||||
// if err != nil {
|
||||
// return false
|
||||
// }
|
||||
// getSpotPrefix := "/v1/spots/{id}" // GET
|
||||
// addCommentPrefix := "/v1/spots/{id}/comment" // POST
|
||||
// getStatusPrefix := "/v1/spots/{id}/status" // GET
|
||||
// if (pathTemplate == getSpotPrefix && r.Method == "GET") ||
|
||||
// (pathTemplate == addCommentPrefix && r.Method == "POST") ||
|
||||
// (pathTemplate == getStatusPrefix && r.Method == "GET") {
|
||||
// return true
|
||||
// }
|
||||
// return false
|
||||
//}
|
||||
|
||||
// Check if the request is authorized
|
||||
user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), isExtension)
|
||||
if err != nil {
|
||||
e.log.Warn(r.Context(), "Unauthorized request: %s", err)
|
||||
if !isSpotWithKeyRequest(r) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
//func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// if r.URL.Path == "/" {
|
||||
// next.ServeHTTP(w, r)
|
||||
// }
|
||||
// user := r.Context().Value("userData").(*auth.User)
|
||||
// rl := e.limiter.GetRateLimiter(user.ID)
|
||||
//
|
||||
// if !rl.Allow() {
|
||||
// http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
// return
|
||||
// }
|
||||
// next.ServeHTTP(w, r)
|
||||
// })
|
||||
//}
|
||||
|
||||
user, err = e.services.Keys.IsValid(r.URL.Query().Get("key"))
|
||||
if err != nil {
|
||||
e.log.Warn(r.Context(), "Wrong public key: %s", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
//type statusWriter struct {
|
||||
// http.ResponseWriter
|
||||
// statusCode int
|
||||
//}
|
||||
//
|
||||
//func (w *statusWriter) WriteHeader(statusCode int) {
|
||||
// w.statusCode = statusCode
|
||||
// w.ResponseWriter.WriteHeader(statusCode)
|
||||
//}
|
||||
//
|
||||
//func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
// if w.statusCode == 0 {
|
||||
// w.statusCode = http.StatusOK // Default status code is 200
|
||||
// }
|
||||
// return w.ResponseWriter.Write(b)
|
||||
//}
|
||||
|
||||
r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user}))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
//func (e *Router) actionMiddleware(next http.Handler) http.Handler {
|
||||
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// if r.URL.Path == "/" {
|
||||
// next.ServeHTTP(w, r)
|
||||
// }
|
||||
// // Read body and restore the io.ReadCloser to its original state
|
||||
// bodyBytes, err := io.ReadAll(r.Body)
|
||||
// if err != nil {
|
||||
// http.Error(w, "can't read body", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
// r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// // Use custom response writer to get the status code
|
||||
// sw := &statusWriter{ResponseWriter: w}
|
||||
// // Serve the request
|
||||
// next.ServeHTTP(sw, r)
|
||||
// e.logRequest(r, bodyBytes, sw.statusCode)
|
||||
// })
|
||||
//}
|
||||
|
||||
func isSpotWithKeyRequest(r *http.Request) bool {
|
||||
pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
getSpotPrefix := "/v1/spots/{id}" // GET
|
||||
addCommentPrefix := "/v1/spots/{id}/comment" // POST
|
||||
getStatusPrefix := "/v1/spots/{id}/status" // GET
|
||||
if (pathTemplate == getSpotPrefix && r.Method == "GET") ||
|
||||
(pathTemplate == addCommentPrefix && r.Method == "POST") ||
|
||||
(pathTemplate == getStatusPrefix && r.Method == "GET") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
user := r.Context().Value("userData").(*auth.User)
|
||||
rl := e.limiter.GetRateLimiter(user.ID)
|
||||
|
||||
if !rl.Allow() {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
if w.statusCode == 0 {
|
||||
w.statusCode = http.StatusOK // Default status code is 200
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (e *Router) actionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
// Read body and restore the io.ReadCloser to its original state
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "can't read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
// Use custom response writer to get the status code
|
||||
sw := &statusWriter{ResponseWriter: w}
|
||||
// Serve the request
|
||||
next.ServeHTTP(sw, r)
|
||||
e.logRequest(r, bodyBytes, sw.statusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Router) GetHandler() http.Handler {
|
||||
return e.router
|
||||
}
|
||||
//func (e *Router) GetHandler() http.Handler {
|
||||
// return e.router
|
||||
//}
|
||||
|
|
|
|||
|
|
@ -2,38 +2,35 @@ package spot
|
|||
|
||||
import (
|
||||
"openreplay/backend/internal/config/spot"
|
||||
"openreplay/backend/pkg/db/postgres/pool"
|
||||
"openreplay/backend/pkg/flakeid"
|
||||
"openreplay/backend/pkg/common"
|
||||
"openreplay/backend/pkg/logger"
|
||||
"openreplay/backend/pkg/objectstorage"
|
||||
"openreplay/backend/pkg/objectstorage/store"
|
||||
"openreplay/backend/pkg/spot/auth"
|
||||
"openreplay/backend/pkg/spot/service"
|
||||
"openreplay/backend/pkg/spot/transcoder"
|
||||
)
|
||||
|
||||
type ServicesBuilder struct {
|
||||
Flaker *flakeid.Flaker
|
||||
ObjStorage objectstorage.ObjectStorage
|
||||
Auth auth.Auth
|
||||
*common.ServicesBuilder
|
||||
Spots service.Spots
|
||||
Keys service.Keys
|
||||
Transcoder transcoder.Transcoder
|
||||
cfg *spot.Config
|
||||
}
|
||||
|
||||
func NewServiceBuilder(log logger.Logger, cfg *spot.Config, pgconn pool.Pool) (*ServicesBuilder, error) {
|
||||
objStore, err := store.NewStore(&cfg.ObjectsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flaker := flakeid.NewFlaker(cfg.WorkerID)
|
||||
spots := service.NewSpots(log, pgconn, flaker)
|
||||
func NewServiceBuilder(log logger.Logger, cfg *spot.Config) (*ServicesBuilder, error) {
|
||||
builder := common.NewServiceBuilder(log).
|
||||
WithDatabase(cfg.Postgres.String()).
|
||||
WithAuth(cfg.JWTSecret, cfg.JWTSpotSecret).
|
||||
WithObjectStorage(&cfg.ObjectsConfig)
|
||||
|
||||
keys := service.NewKeys(log, builder.Pgconn)
|
||||
spots := service.NewSpots(log, builder.Pgconn, builder.Flaker)
|
||||
tc := transcoder.NewTranscoder(cfg, log, builder.ObjStorage, builder.Pgconn, spots)
|
||||
|
||||
return &ServicesBuilder{
|
||||
Flaker: flaker,
|
||||
ObjStorage: objStore,
|
||||
Auth: auth.NewAuth(log, cfg.JWTSecret, cfg.JWTSpotSecret, pgconn),
|
||||
Spots: spots,
|
||||
Keys: service.NewKeys(log, pgconn),
|
||||
Transcoder: transcoder.NewTranscoder(cfg, log, objStore, pgconn, spots),
|
||||
ServicesBuilder: builder,
|
||||
Spots: spots,
|
||||
Keys: keys,
|
||||
Transcoder: tc,
|
||||
cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue