From 1578b891bde0303785cd8dc3c7f1501ea7d5a543 Mon Sep 17 00:00:00 2001 From: Shekar Siri Date: Mon, 28 Oct 2024 12:34:47 +0100 Subject: [PATCH] feat(backend): analytics - common auth middleware with extra secret to support spot like service --- backend/cmd/analytics/main.go | 32 ++++++++++++---- backend/internal/config/analytics/config.go | 1 + backend/pkg/analytics/api/handler.go | 5 +++ backend/pkg/analytics/api/router.go | 5 ++- backend/pkg/analytics/builder.go | 29 --------------- backend/pkg/common/api/auth/auth.go | 41 +++++++++++++++++---- backend/pkg/common/api/auth/authorizer.go | 8 ++-- backend/pkg/common/api/auth/storage.go | 10 ++--- backend/pkg/common/builder.go | 24 +++++++----- backend/pkg/common/middleware/http.go | 34 +++++++++++++---- 10 files changed, 118 insertions(+), 71 deletions(-) diff --git a/backend/cmd/analytics/main.go b/backend/cmd/analytics/main.go index 6881e4625..1d7b15b87 100644 --- a/backend/cmd/analytics/main.go +++ b/backend/cmd/analytics/main.go @@ -2,9 +2,12 @@ package main import ( "context" + "github.com/gorilla/mux" + "net/http" "openreplay/backend/internal/http/server" "openreplay/backend/pkg/analytics/api" "openreplay/backend/pkg/common" + "openreplay/backend/pkg/common/api/auth" "openreplay/backend/pkg/common/middleware" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" @@ -29,18 +32,13 @@ func main() { builder := common.NewServiceBuilder(log) services, err := builder. WithDatabase(pgConn). - WithJWTSecret(cfg.JWTSecret). + WithJWTSecret(cfg.JWTSecret, cfg.JWTSpotSecret). Build() if err != nil { log.Fatal(ctx, "can't init services: %s", err) } - //services, err := analytics.NewServiceBuilder(log, cfg, pgConn) - //if err != nil { - // log.Fatal(ctx, "can't init services: %s", err) - //} - // Define excluded paths for this service excludedPaths := map[string]map[string]bool{ //"/v1/ping": {"GET": true}, @@ -56,9 +54,29 @@ func main() { return []string{"user"} } + authOptionsSelector := func(r *http.Request) *auth.Options { + pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() + if err != nil { + log.Error(r.Context(), "failed to get path template: %s", err) + return nil // Use default options if there’s an error + } + + // Customize based on route and method + if pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST" { + column := "spot_jwt_iat" + secret := cfg.JWTSpotSecret + return &auth.Options{JwtColumn: column, Secret: secret} + } + + // Return nil to signal default options in AuthMiddleware + return nil + } + + authMiddleware := middleware.AuthMiddleware(services, log, excludedPaths, getPermissions, authOptionsSelector) + router, err := api.NewRouter(cfg, log, services) router.GetRouter().Use(middleware.CORS(cfg.UseAccessControlHeaders)) - router.GetRouter().Use(middleware.AuthMiddleware(services, log, excludedPaths, getPermissions)) + router.GetRouter().Use(authMiddleware) router.GetRouter().Use(middleware.RateLimit) router.GetRouter().Use(middleware.Action) diff --git a/backend/internal/config/analytics/config.go b/backend/internal/config/analytics/config.go index ea44a4842..dfbf05d12 100644 --- a/backend/internal/config/analytics/config.go +++ b/backend/internal/config/analytics/config.go @@ -25,6 +25,7 @@ type Config struct { UseAccessControlHeaders bool `env:"USE_CORS,default=false"` ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` JWTSecret string `env:"JWT_SECRET,required"` + JWTSpotSecret string `env:"JWT_SPOT_SECRET,required"` // TODO: remove this MinimumStreamDuration int `env:"MINIMUM_STREAM_DURATION,default=15000"` // 15s WorkerID uint16 } diff --git a/backend/pkg/analytics/api/handler.go b/backend/pkg/analytics/api/handler.go index 48569cc26..4cd0c5eeb 100644 --- a/backend/pkg/analytics/api/handler.go +++ b/backend/pkg/analytics/api/handler.go @@ -11,6 +11,11 @@ import ( "time" ) +func (e *Router) spotTest(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Welcome to NSE Live API")) +} + func (e *Router) createDashboard(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 diff --git a/backend/pkg/analytics/api/router.go b/backend/pkg/analytics/api/router.go index 6bba46946..40ee3d3a9 100644 --- a/backend/pkg/analytics/api/router.go +++ b/backend/pkg/analytics/api/router.go @@ -42,8 +42,9 @@ func (e *Router) init() { e.router = mux.NewRouter() e.router.HandleFunc("/", e.ping) - e.router.HandleFunc("/{projectId}/dashboards", e.createDashboard).Methods("POST") - e.router.HandleFunc("/{projectId}/dashboards", e.getDashboards).Methods("GET") + e.router.HandleFunc("/{projectId}/dashboards", e.createDashboard).Methods("POST", "OPTIONS") + e.router.HandleFunc("/v1/spots/{id}/uploaded", e.spotTest).Methods("POST", "OPTIONS") + e.router.HandleFunc("/{projectId}/dashboards", e.getDashboards).Methods("GET", "OPTIONS") e.router.HandleFunc("/{projectId}/dashboards/{dashboardId}", e.getDashboard).Methods("GET") e.router.HandleFunc("/{projectId}/dashboards/{dashboardId}", e.updateDashboard).Methods("PUT") e.router.HandleFunc("/{projectId}/dashboards/{dashboardId}", e.deleteDashboard).Methods("DELETE") diff --git a/backend/pkg/analytics/builder.go b/backend/pkg/analytics/builder.go index bb977cc5a..472287ac9 100644 --- a/backend/pkg/analytics/builder.go +++ b/backend/pkg/analytics/builder.go @@ -1,30 +1 @@ package analytics - -//import ( -// "openreplay/backend/internal/config/analytics" -// "openreplay/backend/pkg/common/api/auth" -// "openreplay/backend/pkg/db/postgres/pool" -// "openreplay/backend/pkg/flakeid" -// "openreplay/backend/pkg/logger" -// "openreplay/backend/pkg/objectstorage" -// "openreplay/backend/pkg/objectstorage/store" -//) -// -//type ServicesBuilder struct { -// Flaker *flakeid.Flaker -// ObjStorage objectstorage.ObjectStorage -// Auth auth.Auth -//} -// -//func NewServiceBuilder(log logger.Logger, cfg *analytics.Config, pgconn pool.Pool) (*ServicesBuilder, error) { -// objStore, err := store.NewStore(&cfg.ObjectsConfig) -// if err != nil { -// return nil, err -// } -// flaker := flakeid.NewFlaker(cfg.WorkerID) -// return &ServicesBuilder{ -// Flaker: flaker, -// ObjStorage: objStore, -// Auth: auth.NewAuth(log, cfg.JWTSecret, pgconn), -// }, nil -//} diff --git a/backend/pkg/common/api/auth/auth.go b/backend/pkg/common/api/auth/auth.go index 4f4fd62c5..311b255dd 100644 --- a/backend/pkg/common/api/auth/auth.go +++ b/backend/pkg/common/api/auth/auth.go @@ -10,21 +10,46 @@ import ( "openreplay/backend/pkg/logger" ) +// Options struct to hold optional JWT column and secret +type Options struct { + JwtColumn string // The JWT column to use (e.g., "jwt_iat" or "spot_jwt_iat") + Secret string // An optional secret; if nil, default secret is used +} + type Auth interface { - IsAuthorized(authHeader string, permissions []string, isExtension bool) (*User, error) + IsAuthorized(authHeader string, permissions []string, options Options) (*User, error) + Secret() string + JWTCol() string + ExtraSecret() string } type authImpl struct { - log logger.Logger - secret string - pgconn pool.Pool + log logger.Logger + secret string + extraSecret string + pgconn pool.Pool + jwtCol string } -func NewAuth(log logger.Logger, jwtSecret string, conn pool.Pool) Auth { +func (a *authImpl) Secret() string { + return a.secret +} + +func (a *authImpl) JWTCol() string { + return a.jwtCol +} + +func (a *authImpl) ExtraSecret() string { + return a.extraSecret +} + +func NewAuth(log logger.Logger, jwtCol string, jwtSecret string, extraSecret string, conn pool.Pool) Auth { return &authImpl{ - log: log, - secret: jwtSecret, - pgconn: conn, + log: log, + secret: jwtSecret, + extraSecret: extraSecret, + pgconn: conn, + jwtCol: jwtCol, } } diff --git a/backend/pkg/common/api/auth/authorizer.go b/backend/pkg/common/api/auth/authorizer.go index b36ce81bb..c1f0cb8bc 100644 --- a/backend/pkg/common/api/auth/authorizer.go +++ b/backend/pkg/common/api/auth/authorizer.go @@ -1,10 +1,12 @@ package auth -func (a *authImpl) IsAuthorized(authHeader string, permissions []string, isExtension bool) (*User, error) { - secret := a.secret +func (a *authImpl) IsAuthorized(authHeader string, permissions []string, options Options) (*User, error) { + jwtCol := options.JwtColumn + secret := options.Secret + jwtInfo, err := parseJWT(authHeader, secret) if err != nil { return nil, err } - return authUser(a.pgconn, jwtInfo.UserId, jwtInfo.TenantID, int(jwtInfo.IssuedAt.Unix()), isExtension) + return authUser(a.pgconn, jwtInfo.UserId, jwtInfo.TenantID, int(jwtInfo.IssuedAt.Unix()), jwtCol) } diff --git a/backend/pkg/common/api/auth/storage.go b/backend/pkg/common/api/auth/storage.go index 3830f94f9..f64bc9248 100644 --- a/backend/pkg/common/api/auth/storage.go +++ b/backend/pkg/common/api/auth/storage.go @@ -5,15 +5,15 @@ import ( "openreplay/backend/pkg/db/postgres/pool" ) -func authUser(conn pool.Pool, userID, tenantID, jwtIAT int, isExtension bool) (*User, error) { - sql := ` - SELECT user_id, name, email +func authUser(conn pool.Pool, userID, tenantID, jwtIAT int, jwtCol string) (*User, error) { + sql := fmt.Sprintf(` + SELECT user_id, name, email, EXTRACT(epoch FROM %s)::BIGINT AS jwt_iat FROM public.users WHERE user_id = $1 AND deleted_at IS NULL - LIMIT 1;` + LIMIT 1;`, jwtCol) user := &User{TenantID: 1, AuthMethod: "jwt"} if err := conn.QueryRow(sql, userID).Scan(&user.ID, &user.Name, &user.Email, &user.JwtIat); err != nil { - return nil, fmt.Errorf("user not found") + return nil, fmt.Errorf("user not found") // TODO should be a proper message with error message } if user.JwtIat == 0 || abs(jwtIAT-user.JwtIat) > 1 { return nil, fmt.Errorf("token has been updated") diff --git a/backend/pkg/common/builder.go b/backend/pkg/common/builder.go index 23d20a592..74a7900a8 100644 --- a/backend/pkg/common/builder.go +++ b/backend/pkg/common/builder.go @@ -11,13 +11,14 @@ import ( // ServicesBuilder struct to hold service components type ServicesBuilder struct { - flaker *flakeid.Flaker - objStorage objectstorage.ObjectStorage - Auth auth.Auth - log logger.Logger - pgconn pool.Pool - workerID int - jwtSecret string + flaker *flakeid.Flaker + objStorage objectstorage.ObjectStorage + Auth auth.Auth + log logger.Logger + pgconn pool.Pool + workerID int + jwtSecret string + extraSecret string } // NewServiceBuilder initializes the ServicesBuilder with essential components (logger) @@ -57,9 +58,12 @@ func (b *ServicesBuilder) WithWorkerID(workerID int) *ServicesBuilder { return b } -// WithJWTSecret sets the JWT secret for Auth -func (b *ServicesBuilder) WithJWTSecret(jwtSecret string) *ServicesBuilder { +// WithJWTSecret sets the JWT and optional extra secret for Auth +func (b *ServicesBuilder) WithJWTSecret(jwtSecret string, extraSecret ...string) *ServicesBuilder { b.jwtSecret = jwtSecret + if len(extraSecret) > 0 { + b.extraSecret = extraSecret[0] + } return b } @@ -82,7 +86,7 @@ func (b *ServicesBuilder) Build() (*ServicesBuilder, error) { if b.jwtSecret == "" { return nil, errors.New("JWT secret is required") } - b.Auth = auth.NewAuth(b.log, b.jwtSecret, b.pgconn) + b.Auth = auth.NewAuth(b.log, "jwt_iat", b.jwtSecret, b.extraSecret, b.pgconn) } // Return the fully constructed service diff --git a/backend/pkg/common/middleware/http.go b/backend/pkg/common/middleware/http.go index 0c4630514..5c6d6a1e6 100644 --- a/backend/pkg/common/middleware/http.go +++ b/backend/pkg/common/middleware/http.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "github.com/gorilla/mux" "net/http" "openreplay/backend/internal/http/util" "openreplay/backend/pkg/common" @@ -41,7 +40,8 @@ func AuthMiddleware( services *common.ServicesBuilder, // Injected services (Auth, Keys, etc.) log logger.Logger, // Logger for logging events excludedPaths map[string]map[string]bool, // Map of excluded paths with methods - getPermissions func(path string) []string, // Function to retrieve permissions for a path + getPermissions func(path string) []string, + authOptionsSelector func(r *http.Request) *auth.Options, // Function to retrieve permissions for a path ) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -56,14 +56,34 @@ func AuthMiddleware( return } - // Check if the route is dynamic and get the path template - pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() - if err != nil { - log.Error(r.Context(), "failed to get path template: %s", err) + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + log.Warn(r.Context(), "Authorization header missing") + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Get AuthOptions for the request + options := auth.Options{ + JwtColumn: services.Auth.JWTCol(), // Default JWT column from ServicesBuilder + Secret: services.Auth.Secret(), // Default secret from ServicesBuilder + } + + if authOptionsSelector != nil { + selectorOptions := authOptionsSelector(r) + if selectorOptions != nil { + // Override defaults with values from selectorOptions + if selectorOptions.JwtColumn != "" { + options.JwtColumn = selectorOptions.JwtColumn + } + if selectorOptions.Secret != "" { + options.Secret = selectorOptions.Secret + } + } } // Check if this request is authorized - user, err := services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), pathTemplate != "") + user, err := services.Auth.IsAuthorized(authHeader, getPermissions(r.URL.Path), options) if err != nil { log.Warn(r.Context(), "Unauthorized request: %s", err) w.WriteHeader(http.StatusUnauthorized)