diff --git a/backend/cmd/http/main.go b/backend/cmd/http/main.go index 0b815ea1e..c6034e278 100644 --- a/backend/cmd/http/main.go +++ b/backend/cmd/http/main.go @@ -2,30 +2,27 @@ package main import ( "context" - "os" - "os/signal" - "syscall" "openreplay/backend/internal/config/http" - "openreplay/backend/internal/http/router" - "openreplay/backend/internal/http/server" "openreplay/backend/internal/http/services" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/db/redis" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/metrics" databaseMetrics "openreplay/backend/pkg/metrics/database" - httpMetrics "openreplay/backend/pkg/metrics/http" + "openreplay/backend/pkg/metrics/web" "openreplay/backend/pkg/queue" + "openreplay/backend/pkg/server" + "openreplay/backend/pkg/server/api" ) func main() { ctx := context.Background() log := logger.New() cfg := http.New(log) - metrics.New(log, append(httpMetrics.List(), databaseMetrics.List()...)) + webMetrics := web.New("http") + metrics.New(log, append(webMetrics.List(), databaseMetrics.List()...)) - // Connect to queue producer := queue.NewProducer(cfg.MessageSizeLimit, true) defer producer.Close(15000) @@ -37,38 +34,21 @@ func main() { redisClient, err := redis.New(&cfg.Redis) if err != nil { - log.Warn(ctx, "can't init redis connection: %s", err) + log.Info(ctx, "no redis cache: %s", err) } defer redisClient.Close() - services, err := services.New(log, cfg, producer, pgConn, redisClient) + builder, err := services.New(log, cfg, webMetrics, producer, pgConn, redisClient) if err != nil { log.Fatal(ctx, "failed while creating services: %s", err) } - router, err := router.NewRouter(cfg, log, services) + router, err := api.NewRouter(&cfg.HTTP, log) if err != nil { log.Fatal(ctx, "failed while creating router: %s", err) } + router.AddHandlers(api.NoPrefix, builder.WebAPI, builder.MobileAPI, builder.ConditionsAPI, builder.FeatureFlagsAPI, + builder.TagsAPI, builder.UxTestsAPI) - server, err := server.New(router.GetHandler(), cfg.HTTPHost, cfg.HTTPPort, cfg.HTTPTimeout) - if err != nil { - log.Fatal(ctx, "failed while creating server: %s", err) - } - - // Run server - go func() { - if err := server.Start(); err != nil { - log.Fatal(ctx, "http server error: %s", err) - } - }() - - log.Info(ctx, "server successfully started on port %s", cfg.HTTPPort) - - // Wait stop signal to shut down server gracefully - sigchan := make(chan os.Signal, 1) - signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM) - <-sigchan - log.Info(ctx, "shutting down the server") - server.Stop() + server.Run(ctx, log, &cfg.HTTP, router) } diff --git a/backend/cmd/integrations/main.go b/backend/cmd/integrations/main.go index 2e371bba4..e9c79aa02 100644 --- a/backend/cmd/integrations/main.go +++ b/backend/cmd/integrations/main.go @@ -2,24 +2,24 @@ package main import ( "context" - "os" - "os/signal" - "syscall" config "openreplay/backend/internal/config/integrations" - "openreplay/backend/internal/http/server" "openreplay/backend/pkg/db/postgres/pool" - integration "openreplay/backend/pkg/integrations" + "openreplay/backend/pkg/integrations" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/metrics" "openreplay/backend/pkg/metrics/database" + "openreplay/backend/pkg/metrics/web" + "openreplay/backend/pkg/server" + "openreplay/backend/pkg/server/api" ) func main() { ctx := context.Background() log := logger.New() cfg := config.New(log) - metrics.New(log, append(database.List())) + webMetrics := web.New("integrations") + metrics.New(log, append(webMetrics.List(), database.List()...)) pgConn, err := pool.New(cfg.Postgres.String()) if err != nil { @@ -27,31 +27,17 @@ func main() { } defer pgConn.Close() - services, err := integration.NewServiceBuilder(log, cfg, pgConn) + builder, err := integrations.NewServiceBuilder(log, cfg, webMetrics, pgConn) if err != nil { log.Fatal(ctx, "can't init services: %s", err) } - router, err := integration.NewRouter(cfg, log, services) + router, err := api.NewRouter(&cfg.HTTP, log) if err != nil { log.Fatal(ctx, "failed while creating router: %s", err) } + router.AddHandlers(api.NoPrefix, builder.IntegrationsAPI) + router.AddMiddlewares(builder.Auth.Middleware, builder.RateLimiter.Middleware, builder.AuditTrail.Middleware) - dataIntegrationServer, err := server.New(router.GetHandler(), cfg.HTTPHost, cfg.HTTPPort, cfg.HTTPTimeout) - if err != nil { - log.Fatal(ctx, "failed while creating server: %s", err) - } - go func() { - if err := dataIntegrationServer.Start(); err != nil { - log.Fatal(ctx, "http server error: %s", err) - } - }() - log.Info(ctx, "server successfully started on port %s", cfg.HTTPPort) - - // Wait stop signal to shut down server gracefully - sigchan := make(chan os.Signal, 1) - signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM) - <-sigchan - log.Info(ctx, "shutting down the server") - dataIntegrationServer.Stop() + server.Run(ctx, log, &cfg.HTTP, router) } diff --git a/backend/cmd/spot/main.go b/backend/cmd/spot/main.go index b4204486e..637648d46 100644 --- a/backend/cmd/spot/main.go +++ b/backend/cmd/spot/main.go @@ -2,26 +2,25 @@ package main import ( "context" - "openreplay/backend/pkg/spot" - "openreplay/backend/pkg/spot/api" - "os" - "os/signal" - "syscall" spotConfig "openreplay/backend/internal/config/spot" - "openreplay/backend/internal/http/server" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" "openreplay/backend/pkg/metrics" databaseMetrics "openreplay/backend/pkg/metrics/database" spotMetrics "openreplay/backend/pkg/metrics/spot" + "openreplay/backend/pkg/metrics/web" + "openreplay/backend/pkg/server" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/spot" ) func main() { ctx := context.Background() log := logger.New() cfg := spotConfig.New(log) - metrics.New(log, append(spotMetrics.List(), databaseMetrics.List()...)) + webMetrics := web.New("spot") + metrics.New(log, append(webMetrics.List(), append(spotMetrics.List(), databaseMetrics.List()...)...)) pgConn, err := pool.New(cfg.Postgres.String()) if err != nil { @@ -29,32 +28,17 @@ func main() { } defer pgConn.Close() - services, err := spot.NewServiceBuilder(log, cfg, pgConn) + builder, err := spot.NewServiceBuilder(log, cfg, webMetrics, pgConn) if err != nil { log.Fatal(ctx, "can't init services: %s", err) } - router, err := api.NewRouter(cfg, log, services) + router, err := api.NewRouter(&cfg.HTTP, log) if err != nil { log.Fatal(ctx, "failed while creating router: %s", err) } + router.AddHandlers(api.NoPrefix, builder.SpotsAPI) + router.AddMiddlewares(builder.Auth.Middleware, builder.RateLimiter.Middleware, builder.AuditTrail.Middleware) - spotServer, err := server.New(router.GetHandler(), cfg.HTTPHost, cfg.HTTPPort, cfg.HTTPTimeout) - if err != nil { - log.Fatal(ctx, "failed while creating server: %s", err) - } - - go func() { - if err := spotServer.Start(); err != nil { - log.Fatal(ctx, "http server error: %s", err) - } - }() - log.Info(ctx, "server successfully started on port %s", cfg.HTTPPort) - - // Wait stop signal to shut down server gracefully - sigchan := make(chan os.Signal, 1) - signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM) - <-sigchan - log.Info(ctx, "shutting down the server") - spotServer.Stop() + server.Run(ctx, log, &cfg.HTTP, router) } diff --git a/backend/internal/config/common/config.go b/backend/internal/config/common/config.go index 073981351..dd21d2ae0 100644 --- a/backend/internal/config/common/config.go +++ b/backend/internal/config/common/config.go @@ -1,6 +1,9 @@ package common -import "strings" +import ( + "strings" + "time" +) // Common config for all services @@ -70,3 +73,13 @@ type ElasticSearch struct { func (cfg *ElasticSearch) GetURLs() []string { return strings.Split(cfg.URLs, ",") } + +type HTTP struct { + HTTPHost string `env:"HTTP_HOST,default="` + HTTPPort string `env:"HTTP_PORT,required"` + HTTPTimeout time.Duration `env:"HTTP_TIMEOUT,default=60s"` + JsonSizeLimit int64 `env:"JSON_SIZE_LIMIT,default=131072"` // 128KB, 1000 for HTTP service + UseAccessControlHeaders bool `env:"USE_CORS,default=false"` + JWTSecret string `env:"JWT_SECRET,required"` + JWTSpotSecret string `env:"JWT_SPOT_SECRET,required"` +} diff --git a/backend/internal/config/http/config.go b/backend/internal/config/http/config.go index 720b6f1f7..4a42be542 100644 --- a/backend/internal/config/http/config.go +++ b/backend/internal/config/http/config.go @@ -15,31 +15,27 @@ type Config struct { common.Postgres redis.Redis objectstorage.ObjectsConfig - HTTPHost string `env:"HTTP_HOST,default="` - HTTPPort string `env:"HTTP_PORT,required"` - HTTPTimeout time.Duration `env:"HTTP_TIMEOUT,default=60s"` - TopicRawWeb string `env:"TOPIC_RAW_WEB,required"` - TopicRawMobile string `env:"TOPIC_RAW_IOS,required"` - TopicRawImages string `env:"TOPIC_RAW_IMAGES,required"` - TopicCanvasImages string `env:"TOPIC_CANVAS_IMAGES,required"` - BeaconSizeLimit int64 `env:"BEACON_SIZE_LIMIT,required"` - CompressionThreshold int64 `env:"COMPRESSION_THRESHOLD,default=20000"` - JsonSizeLimit int64 `env:"JSON_SIZE_LIMIT,default=1000"` - FileSizeLimit int64 `env:"FILE_SIZE_LIMIT,default=10000000"` - TokenSecret string `env:"TOKEN_SECRET,required"` - UAParserFile string `env:"UAPARSER_FILE,required"` - MaxMinDBFile string `env:"MAXMINDDB_FILE,required"` - UseProfiler bool `env:"PROFILER_ENABLED,default=false"` - UseAccessControlHeaders bool `env:"USE_CORS,default=false"` - ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` - RecordCanvas bool `env:"RECORD_CANVAS,default=false"` - CanvasQuality string `env:"CANVAS_QUALITY,default=low"` - CanvasFps int `env:"CANVAS_FPS,default=1"` - MobileQuality string `env:"MOBILE_QUALITY,default=low"` // (low, standard, high) - MobileFps int `env:"MOBILE_FPS,default=1"` - IsFeatureFlagEnabled bool `env:"IS_FEATURE_FLAG_ENABLED,default=false"` - IsUsabilityTestEnabled bool `env:"IS_USABILITY_TEST_ENABLED,default=false"` - WorkerID uint16 + common.HTTP + TopicRawWeb string `env:"TOPIC_RAW_WEB,required"` + TopicRawMobile string `env:"TOPIC_RAW_IOS,required"` + TopicRawImages string `env:"TOPIC_RAW_IMAGES,required"` + TopicCanvasImages string `env:"TOPIC_CANVAS_IMAGES,required"` + BeaconSizeLimit int64 `env:"BEACON_SIZE_LIMIT,required"` + CompressionThreshold int64 `env:"COMPRESSION_THRESHOLD,default=20000"` + FileSizeLimit int64 `env:"FILE_SIZE_LIMIT,default=10000000"` + TokenSecret string `env:"TOKEN_SECRET,required"` + UAParserFile string `env:"UAPARSER_FILE,required"` + MaxMinDBFile string `env:"MAXMINDDB_FILE,required"` + UseProfiler bool `env:"PROFILER_ENABLED,default=false"` + ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` + RecordCanvas bool `env:"RECORD_CANVAS,default=false"` + CanvasQuality string `env:"CANVAS_QUALITY,default=low"` + CanvasFps int `env:"CANVAS_FPS,default=1"` + MobileQuality string `env:"MOBILE_QUALITY,default=low"` // (low, standard, high) + MobileFps int `env:"MOBILE_FPS,default=1"` + IsFeatureFlagEnabled bool `env:"IS_FEATURE_FLAG_ENABLED,default=false"` + IsUsabilityTestEnabled bool `env:"IS_USABILITY_TEST_ENABLED,default=false"` + WorkerID uint16 } func New(log logger.Logger) *Config { diff --git a/backend/internal/config/integrations/config.go b/backend/internal/config/integrations/config.go index a1507bbb4..9df273c1f 100644 --- a/backend/internal/config/integrations/config.go +++ b/backend/internal/config/integrations/config.go @@ -16,14 +16,9 @@ type Config struct { common.Postgres redis.Redis objectstorage.ObjectsConfig - HTTPHost string `env:"HTTP_HOST,default="` - HTTPPort string `env:"HTTP_PORT,required"` - HTTPTimeout time.Duration `env:"HTTP_TIMEOUT,default=60s"` - JsonSizeLimit int64 `env:"JSON_SIZE_LIMIT,default=131072"` // 128KB - UseAccessControlHeaders bool `env:"USE_CORS,default=false"` - ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` - JWTSecret string `env:"JWT_SECRET,required"` - WorkerID uint16 + common.HTTP + ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` + WorkerID uint16 } func New(log logger.Logger) *Config { diff --git a/backend/internal/config/spot/config.go b/backend/internal/config/spot/config.go index 7f65dcc83..1f2562600 100644 --- a/backend/internal/config/spot/config.go +++ b/backend/internal/config/spot/config.go @@ -16,18 +16,12 @@ type Config struct { common.Postgres redis.Redis objectstorage.ObjectsConfig - FSDir string `env:"FS_DIR,required"` - SpotsDir string `env:"SPOTS_DIR,default=spots"` - HTTPHost string `env:"HTTP_HOST,default="` - HTTPPort string `env:"HTTP_PORT,required"` - HTTPTimeout time.Duration `env:"HTTP_TIMEOUT,default=60s"` - JsonSizeLimit int64 `env:"JSON_SIZE_LIMIT,default=131072"` // 128KB - UseAccessControlHeaders bool `env:"USE_CORS,default=false"` - ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` - JWTSecret string `env:"JWT_SECRET,required"` - JWTSpotSecret string `env:"JWT_SPOT_SECRET,required"` - MinimumStreamDuration int `env:"MINIMUM_STREAM_DURATION,default=15000"` // 15s - WorkerID uint16 + common.HTTP + FSDir string `env:"FS_DIR,required"` + SpotsDir string `env:"SPOTS_DIR,default=spots"` + ProjectExpiration time.Duration `env:"PROJECT_EXPIRATION,default=10m"` + MinimumStreamDuration int `env:"MINIMUM_STREAM_DURATION,default=15000"` // 15s + WorkerID uint16 } func New(log logger.Logger) *Config { diff --git a/backend/internal/http/geoip/geoip.go b/backend/internal/http/geoip/geoip.go index b3a2fa134..c580c5409 100644 --- a/backend/internal/http/geoip/geoip.go +++ b/backend/internal/http/geoip/geoip.go @@ -2,7 +2,10 @@ package geoip import ( "errors" + "github.com/tomasen/realip" "net" + "net/http" + "openreplay/backend/pkg/logger" "strings" "github.com/oschwald/maxminddb-golang" @@ -46,18 +49,23 @@ func UnpackGeoRecord(pkg string) *GeoRecord { type GeoParser interface { Parse(ip net.IP) (*GeoRecord, error) + ExtractGeoData(r *http.Request) *GeoRecord } type geoParser struct { - r *maxminddb.Reader + log logger.Logger + r *maxminddb.Reader } -func New(file string) (GeoParser, error) { +func New(log logger.Logger, file string) (GeoParser, error) { r, err := maxminddb.Open(file) if err != nil { return nil, err } - return &geoParser{r}, nil + return &geoParser{ + log: log, + r: r, + }, nil } func (geoIP *geoParser) Parse(ip net.IP) (*GeoRecord, error) { @@ -82,3 +90,12 @@ func (geoIP *geoParser) Parse(ip net.IP) (*GeoRecord, error) { res.City = record.City.Names["en"] return res, nil } + +func (geoIP *geoParser) ExtractGeoData(r *http.Request) *GeoRecord { + ip := net.ParseIP(realip.FromRequest(r)) + geoRec, err := geoIP.Parse(ip) + if err != nil { + geoIP.log.Warn(r.Context(), "failed to parse geo data: %v", err) + } + return geoRec +} diff --git a/backend/internal/http/router/conditions.go b/backend/internal/http/router/conditions.go deleted file mode 100644 index c4dc67bcc..000000000 --- a/backend/internal/http/router/conditions.go +++ /dev/null @@ -1,11 +0,0 @@ -package router - -import ( - "errors" - "net/http" - "time" -) - -func (e *Router) getConditions(w http.ResponseWriter, r *http.Request) { - e.ResponseWithError(r.Context(), w, http.StatusNotImplemented, errors.New("no support"), time.Now(), r.URL.Path, 0) -} diff --git a/backend/internal/http/router/handlers-mobile.go b/backend/internal/http/router/handlers-mobile.go deleted file mode 100644 index b9b5366dd..000000000 --- a/backend/internal/http/router/handlers-mobile.go +++ /dev/null @@ -1,271 +0,0 @@ -package router - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math/rand" - "net/http" - "openreplay/backend/internal/http/ios" - "openreplay/backend/internal/http/uuid" - "openreplay/backend/pkg/db/postgres" - "openreplay/backend/pkg/messages" - "openreplay/backend/pkg/sessions" - "openreplay/backend/pkg/token" - "strconv" - "time" -) - -func (e *Router) startMobileSessionHandler(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) - return - } - body := http.MaxBytesReader(w, r.Body, e.cfg.JsonSizeLimit) - defer body.Close() - - req := &StartMobileSessionRequest{} - if err := json.NewDecoder(body).Decode(req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, 0) - return - } - - // Add tracker version to context - r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) - - if req.ProjectKey == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("projectKey value required"), startTime, r.URL.Path, 0) - return - } - - p, err := e.services.Projects.GetProjectByKey(*req.ProjectKey) - if err != nil { - if postgres.IsNoRowsErr(err) { - logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) - e.ResponseWithError(r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, 0) - } else { - e.log.Error(r.Context(), "failed to get project by key: %s, err: %s", *req.ProjectKey, err) - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, 0) - } - return - } - - // Add projectID to context - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) - - // Check if the project supports mobile sessions - if !p.IsMobile() { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("project doesn't support mobile sessions"), startTime, r.URL.Path, 0) - return - } - - if !checkMobileTrackerVersion(req.TrackerVersion) { - e.ResponseWithError(r.Context(), w, http.StatusUpgradeRequired, errors.New("tracker version not supported"), startTime, r.URL.Path, 0) - return - } - - userUUID := uuid.GetUUID(req.UserUUID) - tokenData, err := e.services.Tokenizer.Parse(req.Token) - - if err != nil { // Starting the new one - dice := byte(rand.Intn(100)) // [0, 100) - // Use condition rate if it's set - if req.Condition != "" { - rate, err := e.services.Conditions.GetRate(p.ProjectID, req.Condition, int(p.SampleRate)) - if err != nil { - e.log.Warn(r.Context(), "can't get condition rate, condition: %s, err: %s", req.Condition, err) - } else { - p.SampleRate = byte(rate) - } - } - if dice >= p.SampleRate { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, fmt.Errorf("capture rate miss, rate: %d", p.SampleRate), startTime, r.URL.Path, 0) - return - } - - ua := e.services.UaParser.ParseFromHTTPRequest(r) - if ua == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, 0) - return - } - sessionID, err := e.services.Flaker.Compose(uint64(startTime.UnixMilli())) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) - return - } - - expTime := startTime.Add(time.Duration(p.MaxSessionDuration) * time.Millisecond) - tokenData = &token.TokenData{sessionID, 0, expTime.UnixMilli()} - - // Add sessionID to context - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionID))) - - geoInfo := e.ExtractGeoData(r) - deviceType, platform, os := ios.GetIOSDeviceType(req.UserDevice), "ios", "IOS" - if req.Platform != "" && req.Platform != "ios" { - deviceType = req.UserDeviceType - platform = req.Platform - os = "Android" - } - - if !req.DoNotRecord { - if err := e.services.Sessions.Add(&sessions.Session{ - SessionID: sessionID, - Platform: platform, - Timestamp: req.Timestamp, - Timezone: req.Timezone, - ProjectID: p.ProjectID, - TrackerVersion: req.TrackerVersion, - RevID: req.RevID, - UserUUID: userUUID, - UserOS: os, - UserOSVersion: req.UserOSVersion, - UserDevice: ios.MapIOSDevice(req.UserDevice), - UserDeviceType: deviceType, - UserCountry: geoInfo.Country, - UserState: geoInfo.State, - UserCity: geoInfo.City, - UserDeviceMemorySize: req.DeviceMemory, - UserDeviceHeapSize: req.DeviceMemory, - ScreenWidth: req.Width, - ScreenHeight: req.Height, - }); err != nil { - e.log.Warn(r.Context(), "failed to add mobile session to DB: %s", err) - } - - sessStart := &messages.MobileSessionStart{ - Timestamp: req.Timestamp, - ProjectID: uint64(p.ProjectID), - TrackerVersion: req.TrackerVersion, - RevID: req.RevID, - UserUUID: userUUID, - UserOS: os, - UserOSVersion: req.UserOSVersion, - UserDevice: ios.MapIOSDevice(req.UserDevice), - UserDeviceType: deviceType, - UserCountry: geoInfo.Pack(), - } - - if err := e.services.Producer.Produce(e.cfg.TopicRawMobile, tokenData.ID, sessStart.Encode()); err != nil { - e.log.Error(r.Context(), "failed to send mobile sessionStart event to queue: %s", err) - } - } - } - - e.ResponseWithJSON(r.Context(), w, &StartMobileSessionResponse{ - Token: e.services.Tokenizer.Compose(*tokenData), - UserUUID: userUUID, - SessionID: strconv.FormatUint(tokenData.ID, 10), - BeaconSizeLimit: e.cfg.BeaconSizeLimit, - ImageQuality: e.cfg.MobileQuality, - FrameRate: e.cfg.MobileFps, - ProjectID: strconv.FormatUint(uint64(p.ProjectID), 10), - Features: e.features, - }, startTime, r.URL.Path, 0) -} - -func (e *Router) pushMobileMessagesHandler(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - e.pushMessages(w, r, sessionData.ID, e.cfg.TopicRawMobile) -} - -func (e *Router) pushMobileLateMessagesHandler(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - - if err != nil && err != token.EXPIRED { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) - return - } - // Check timestamps here? - e.pushMessages(w, r, sessionData.ID, e.cfg.TopicRawMobile) -} - -func (e *Router) mobileImagesUploadHandler(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) - return - } - r.Body = http.MaxBytesReader(w, r.Body, e.cfg.FileSizeLimit) - defer r.Body.Close() - - err = r.ParseMultipartForm(5 * 1e6) // ~5Mb - if err == http.ErrNotMultipart || err == http.ErrMissingBoundary { - e.ResponseWithError(r.Context(), w, http.StatusUnsupportedMediaType, err, startTime, r.URL.Path, 0) - return - } else if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) // TODO: send error here only on staging - return - } - - if r.MultipartForm == nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, errors.New("multipart not parsed"), startTime, r.URL.Path, 0) - return - } - - if len(r.MultipartForm.Value["projectKey"]) == 0 { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("projectKey parameter missing"), startTime, r.URL.Path, 0) // status for missing/wrong parameter? - return - } - - for _, fileHeaderList := range r.MultipartForm.File { - for _, fileHeader := range fileHeaderList { - file, err := fileHeader.Open() - if err != nil { - continue - } - - data, err := io.ReadAll(file) - if err != nil { - file.Close() - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) - return - } - file.Close() - - if err := e.services.Producer.Produce(e.cfg.TopicRawImages, sessionData.ID, data); err != nil { - e.log.Warn(r.Context(), "failed to send image to queue: %s", err) - } - } - } - - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, 0) -} diff --git a/backend/internal/http/router/handlers-web.go b/backend/internal/http/router/handlers-web.go deleted file mode 100644 index fe379592b..000000000 --- a/backend/internal/http/router/handlers-web.go +++ /dev/null @@ -1,754 +0,0 @@ -package router - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math/rand" - "net/http" - "strconv" - "strings" - "time" - - "github.com/gorilla/mux" - - "github.com/Masterminds/semver" - "github.com/klauspost/compress/gzip" - "openreplay/backend/internal/http/util" - "openreplay/backend/internal/http/uuid" - "openreplay/backend/pkg/db/postgres" - "openreplay/backend/pkg/featureflags" - "openreplay/backend/pkg/flakeid" - . "openreplay/backend/pkg/messages" - "openreplay/backend/pkg/sessions" - "openreplay/backend/pkg/token" - "openreplay/backend/pkg/uxtesting" -) - -func (e *Router) readBody(w http.ResponseWriter, r *http.Request, limit int64) ([]byte, error) { - body := http.MaxBytesReader(w, r.Body, limit) - var ( - bodyBytes []byte - err error - ) - - // Check if body is gzipped and decompress it - if r.Header.Get("Content-Encoding") == "gzip" { - reader, err := gzip.NewReader(body) - if err != nil { - return nil, fmt.Errorf("can't create gzip reader: %s", err) - } - bodyBytes, err = io.ReadAll(reader) - if err != nil { - return nil, fmt.Errorf("can't read gzip body: %s", err) - } - if err := reader.Close(); err != nil { - e.log.Warn(r.Context(), "can't close gzip reader: %s", err) - } - } else { - bodyBytes, err = io.ReadAll(body) - } - - // Close body - if closeErr := body.Close(); closeErr != nil { - e.log.Warn(r.Context(), "error while closing request body: %s", closeErr) - } - if err != nil { - return nil, err - } - return bodyBytes, nil -} - -func checkMobileTrackerVersion(ver string) bool { - c, err := semver.NewConstraint(">=1.0.9") - if err != nil { - return false - } - // Check for beta version - parts := strings.Split(ver, "-") - if len(parts) > 1 { - ver = parts[0] - } - v, err := semver.NewVersion(ver) - if err != nil { - return false - } - return c.Check(v) -} - -func getSessionTimestamp(req *StartSessionRequest, startTimeMili int64) (ts uint64) { - ts = uint64(req.Timestamp) - if req.IsOffline { - return - } - c, err := semver.NewConstraint(">=4.1.6") - if err != nil { - return - } - ver := req.TrackerVersion - parts := strings.Split(ver, "-") - if len(parts) > 1 { - ver = parts[0] - } - v, err := semver.NewVersion(ver) - if err != nil { - return - } - if c.Check(v) { - ts = uint64(startTimeMili) - if req.BufferDiff > 0 && req.BufferDiff < 5*60*1000 { - ts -= req.BufferDiff - } - } - return -} - -func (e *Router) startSessionHandlerWeb(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check request body - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) - return - } - - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - // Parse request body - req := &StartSessionRequest{} - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - // Add tracker version to context - r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) - - // Handler's logic - if req.ProjectKey == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("ProjectKey value required"), startTime, r.URL.Path, bodySize) - return - } - - p, err := e.services.Projects.GetProjectByKey(*req.ProjectKey) - if err != nil { - if postgres.IsNoRowsErr(err) { - logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) - e.ResponseWithError(r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, bodySize) - } else { - e.log.Error(r.Context(), "failed to get project by key: %s, err: %s", *req.ProjectKey, err) - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, bodySize) - } - return - } - - // Add projectID to context - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) - - // Check if the project supports mobile sessions - if !p.IsWeb() { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("project doesn't support web sessions"), startTime, r.URL.Path, bodySize) - return - } - - ua := e.services.UaParser.ParseFromHTTPRequest(r) - if ua == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, bodySize) - return - } - - geoInfo := e.ExtractGeoData(r) - - userUUID := uuid.GetUUID(req.UserUUID) - tokenData, err := e.services.Tokenizer.Parse(req.Token) - if err != nil || req.Reset { // Starting the new one - dice := byte(rand.Intn(100)) - // Use condition rate if it's set - if req.Condition != "" { - rate, err := e.services.Conditions.GetRate(p.ProjectID, req.Condition, int(p.SampleRate)) - if err != nil { - e.log.Warn(r.Context(), "can't get condition rate, condition: %s, err: %s", req.Condition, err) - } else { - p.SampleRate = byte(rate) - } - } - if dice >= p.SampleRate { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, fmt.Errorf("capture rate miss, rate: %d", p.SampleRate), startTime, r.URL.Path, bodySize) - return - } - - startTimeMili := startTime.UnixMilli() - sessionID, err := e.services.Flaker.Compose(uint64(startTimeMili)) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - - expTime := startTime.Add(time.Duration(p.MaxSessionDuration) * time.Millisecond) - tokenData = &token.TokenData{ - ID: sessionID, - Delay: startTimeMili - req.Timestamp, - ExpTime: expTime.UnixMilli(), - } - - // Add sessionID to context - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionID))) - - if recordSession(req) { - sessionStart := &SessionStart{ - Timestamp: getSessionTimestamp(req, startTimeMili), - ProjectID: uint64(p.ProjectID), - TrackerVersion: req.TrackerVersion, - RevID: req.RevID, - UserUUID: userUUID, - UserAgent: r.Header.Get("User-Agent"), - UserOS: ua.OS, - UserOSVersion: ua.OSVersion, - UserBrowser: ua.Browser, - UserBrowserVersion: ua.BrowserVersion, - UserDevice: ua.Device, - UserDeviceType: ua.DeviceType, - UserCountry: geoInfo.Pack(), - UserDeviceMemorySize: req.DeviceMemory, - UserDeviceHeapSize: req.JsHeapSizeLimit, - UserID: req.UserID, - } - - // Save sessionStart to db - if err := e.services.Sessions.Add(&sessions.Session{ - SessionID: sessionID, - Platform: "web", - Timestamp: sessionStart.Timestamp, - Timezone: req.Timezone, - ProjectID: uint32(sessionStart.ProjectID), - TrackerVersion: sessionStart.TrackerVersion, - RevID: sessionStart.RevID, - UserUUID: sessionStart.UserUUID, - UserOS: sessionStart.UserOS, - UserOSVersion: sessionStart.UserOSVersion, - UserDevice: sessionStart.UserDevice, - UserCountry: geoInfo.Country, - UserState: geoInfo.State, - UserCity: geoInfo.City, - UserAgent: sessionStart.UserAgent, - UserBrowser: sessionStart.UserBrowser, - UserBrowserVersion: sessionStart.UserBrowserVersion, - UserDeviceType: sessionStart.UserDeviceType, - UserDeviceMemorySize: sessionStart.UserDeviceMemorySize, - UserDeviceHeapSize: sessionStart.UserDeviceHeapSize, - UserID: &sessionStart.UserID, - ScreenWidth: req.Width, - ScreenHeight: req.Height, - }); err != nil { - e.log.Warn(r.Context(), "can't insert sessionStart to DB: %s", err) - } - - // Send sessionStart message to kafka - if err := e.services.Producer.Produce(e.cfg.TopicRawWeb, tokenData.ID, sessionStart.Encode()); err != nil { - e.log.Error(r.Context(), "can't send sessionStart to queue: %s", err) - } - } - } else { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", tokenData.ID))) - } - - // Save information about session beacon size - e.addBeaconSize(tokenData.ID, p.BeaconSize) - - startResponse := &StartSessionResponse{ - Token: e.services.Tokenizer.Compose(*tokenData), - UserUUID: userUUID, - UserOS: ua.OS, - UserDevice: ua.Device, - UserBrowser: ua.Browser, - UserCountry: geoInfo.Country, - UserState: geoInfo.State, - UserCity: geoInfo.City, - SessionID: strconv.FormatUint(tokenData.ID, 10), - ProjectID: strconv.FormatUint(uint64(p.ProjectID), 10), - BeaconSizeLimit: e.getBeaconSize(tokenData.ID), - CompressionThreshold: e.getCompressionThreshold(), - StartTimestamp: int64(flakeid.ExtractTimestamp(tokenData.ID)), - Delay: tokenData.Delay, - CanvasEnabled: e.cfg.RecordCanvas, - CanvasImageQuality: e.cfg.CanvasQuality, - CanvasFrameRate: e.cfg.CanvasFps, - Features: e.features, - } - modifyResponse(req, startResponse) - - e.ResponseWithJSON(r.Context(), w, startResponse, startTime, r.URL.Path, bodySize) -} - -func (e *Router) pushMessagesHandlerWeb(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Get debug header with batch info - if batch := r.URL.Query().Get("batch"); batch != "" { - r = r.WithContext(context.WithValue(r.Context(), "batch", batch)) - } - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - tokenJustExpired := false - if err != nil { - if errors.Is(err, token.JUST_EXPIRED) { - tokenJustExpired = true - } else { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - // Check request body - if r.Body == nil { - errCode := http.StatusBadRequest - if tokenJustExpired { - errCode = http.StatusUnauthorized - } - e.ResponseWithError(r.Context(), w, errCode, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) - return - } - - bodyBytes, err := e.readBody(w, r, e.getBeaconSize(sessionData.ID)) - if err != nil { - errCode := http.StatusRequestEntityTooLarge - if tokenJustExpired { - errCode = http.StatusUnauthorized - } - e.ResponseWithError(r.Context(), w, errCode, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - // Send processed messages to queue as array of bytes - err = e.services.Producer.Produce(e.cfg.TopicRawWeb, sessionData.ID, bodyBytes) - if err != nil { - e.log.Error(r.Context(), "can't send messages batch to queue: %s", err) - errCode := http.StatusInternalServerError - if tokenJustExpired { - errCode = http.StatusUnauthorized - } - e.ResponseWithError(r.Context(), w, errCode, errors.New("can't save message, try again"), startTime, r.URL.Path, bodySize) - return - } - - if tokenJustExpired { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, errors.New("token expired"), startTime, r.URL.Path, bodySize) - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) notStartedHandlerWeb(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) - return - } - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - req := &NotStartedRequest{} - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - // Add tracker version to context - r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) - - // Handler's logic - if req.ProjectKey == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("projectKey value required"), startTime, r.URL.Path, bodySize) - return - } - p, err := e.services.Projects.GetProjectByKey(*req.ProjectKey) - if err != nil { - if postgres.IsNoRowsErr(err) { - logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) - e.ResponseWithError(r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, bodySize) - } else { - e.log.Error(r.Context(), "can't find a project: %s", err) - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, bodySize) - } - return - } - - // Add projectID to context - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) - - ua := e.services.UaParser.ParseFromHTTPRequest(r) - if ua == nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, bodySize) - return - } - geoInfo := e.ExtractGeoData(r) - err = e.services.Sessions.AddUnStarted(&sessions.UnStartedSession{ - ProjectKey: *req.ProjectKey, - TrackerVersion: req.TrackerVersion, - DoNotTrack: req.DoNotTrack, - Platform: "web", - UserAgent: r.Header.Get("User-Agent"), - UserOS: ua.OS, - UserOSVersion: ua.OSVersion, - UserBrowser: ua.Browser, - UserBrowserVersion: ua.BrowserVersion, - UserDevice: ua.Device, - UserDeviceType: ua.DeviceType, - UserCountry: geoInfo.Country, - UserState: geoInfo.State, - UserCity: geoInfo.City, - }) - if err != nil { - e.log.Warn(r.Context(), "can't insert un-started session: %s", err) - } - // response ok anyway - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) featureFlagsHandlerWeb(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) - return - } - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - // Parse request body - req := &featureflags.FeatureFlagsRequest{} - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - computedFlags, err := e.services.FeatureFlags.ComputeFlagsForSession(req) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - - resp := &featureflags.FeatureFlagsResponse{ - Flags: computedFlags, - } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) -} - -func (e *Router) getUXTestInfo(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - sess, err := e.services.Sessions.Get(sessionData.ID) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, err, startTime, r.URL.Path, bodySize) - return - } - - // Add projectID to context - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", sess.ProjectID))) - - // Get taskID - vars := mux.Vars(r) - id := vars["id"] - - // Get task info - info, err := e.services.UXTesting.GetInfo(id) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - if sess.ProjectID != info.ProjectID { - e.ResponseWithError(r.Context(), w, http.StatusForbidden, errors.New("project mismatch"), startTime, r.URL.Path, bodySize) - return - } - type TaskInfoResponse struct { - Task *uxtesting.UXTestInfo `json:"test"` - } - e.ResponseWithJSON(r.Context(), w, &TaskInfoResponse{Task: info}, startTime, r.URL.Path, bodySize) -} - -func (e *Router) sendUXTestSignal(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - // Parse request body - req := &uxtesting.TestSignal{} - - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - req.SessionID = sessionData.ID - - // Save test signal - if err := e.services.UXTesting.SetTestSignal(req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) sendUXTaskSignal(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - // Parse request body - req := &uxtesting.TaskSignal{} - - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - req.SessionID = sessionData.ID - - // Save test signal - if err := e.services.UXTesting.SetTaskSignal(req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) getUXUploadUrl(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - key := fmt.Sprintf("%d/ux_webcam_record.webm", sessionData.ID) - url, err := e.services.ObjStorage.GetPreSignedUploadUrl(key) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - type UrlResponse struct { - URL string `json:"url"` - } - e.ResponseWithJSON(r.Context(), w, &UrlResponse{URL: url}, startTime, r.URL.Path, bodySize) -} - -type ScreenshotMessage struct { - Name string - Data []byte -} - -func (e *Router) imagesUploaderHandlerWeb(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { // Should accept expired token? - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) - return - } - - // Add sessionID and projectID to context - if info, err := e.services.Sessions.Get(sessionData.ID); err == nil { - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) - } - - if r.Body == nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) - return - } - r.Body = http.MaxBytesReader(w, r.Body, e.cfg.FileSizeLimit) - defer r.Body.Close() - - // Parse the multipart form - err = r.ParseMultipartForm(10 << 20) // Max upload size 10 MB - if err == http.ErrNotMultipart || err == http.ErrMissingBoundary { - e.ResponseWithError(r.Context(), w, http.StatusUnsupportedMediaType, err, startTime, r.URL.Path, 0) - return - } else if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) - return - } - - // Iterate over uploaded files - for _, fileHeaderList := range r.MultipartForm.File { - for _, fileHeader := range fileHeaderList { - file, err := fileHeader.Open() - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) - return - } - - // Read the file content - fileBytes, err := io.ReadAll(file) - if err != nil { - file.Close() - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) - return - } - file.Close() - - fileName := util.SafeString(fileHeader.Filename) - - // Create a message to send to Kafka - msg := ScreenshotMessage{ - Name: fileName, - Data: fileBytes, - } - data, err := json.Marshal(&msg) - if err != nil { - e.log.Warn(r.Context(), "can't marshal screenshot message, err: %s", err) - continue - } - - // Send the message to queue - if err := e.services.Producer.Produce(e.cfg.TopicCanvasImages, sessionData.ID, data); err != nil { - e.log.Warn(r.Context(), "can't send screenshot message to queue, err: %s", err) - } - } - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, 0) -} - -func (e *Router) getTags(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - sessionData, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if sessionData != nil { - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - } - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - sessInfo, err := e.services.Sessions.Get(sessionData.ID) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Add sessionID and projectID to context - r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) - r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", sessInfo.ProjectID))) - - // Get tags - tags, err := e.services.Tags.Get(sessInfo.ProjectID) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - type UrlResponse struct { - Tags interface{} `json:"tags"` - } - e.ResponseWithJSON(r.Context(), w, &UrlResponse{Tags: tags}, startTime, r.URL.Path, bodySize) -} diff --git a/backend/internal/http/router/handlers.go b/backend/internal/http/router/handlers.go deleted file mode 100644 index 16a04ffeb..000000000 --- a/backend/internal/http/router/handlers.go +++ /dev/null @@ -1,41 +0,0 @@ -package router - -import ( - "io" - "net/http" - "time" - - gzip "github.com/klauspost/pgzip" -) - -func (e *Router) pushMessages(w http.ResponseWriter, r *http.Request, sessionID uint64, topicName string) { - start := time.Now() - body := http.MaxBytesReader(w, r.Body, e.cfg.BeaconSizeLimit) - defer body.Close() - - var reader io.ReadCloser - var err error - - switch r.Header.Get("Content-Encoding") { - case "gzip": - reader, err = gzip.NewReader(body) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) - return - } - defer reader.Close() - default: - reader = body - } - buf, err := io.ReadAll(reader) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) - return - } - if err := e.services.Producer.Produce(topicName, sessionID, buf); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) - return - } - w.WriteHeader(http.StatusOK) - e.log.Info(r.Context(), "response ok") -} diff --git a/backend/internal/http/router/response.go b/backend/internal/http/router/response.go deleted file mode 100644 index 512d442e9..000000000 --- a/backend/internal/http/router/response.go +++ /dev/null @@ -1,50 +0,0 @@ -package router - -import ( - "context" - "encoding/json" - "net/http" - "time" - - metrics "openreplay/backend/pkg/metrics/http" -) - -func recordMetrics(requestStart time.Time, url string, code, bodySize int) { - if bodySize > 0 { - metrics.RecordRequestSize(float64(bodySize), url, code) - } - metrics.IncreaseTotalRequests() - metrics.RecordRequestDuration(float64(time.Now().Sub(requestStart).Milliseconds()), url, code) -} - -func (e *Router) ResponseOK(ctx context.Context, w http.ResponseWriter, requestStart time.Time, url string, bodySize int) { - w.WriteHeader(http.StatusOK) - e.log.Info(ctx, "response ok") - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -func (e *Router) ResponseWithJSON(ctx context.Context, w http.ResponseWriter, res interface{}, requestStart time.Time, url string, bodySize int) { - e.log.Info(ctx, "response ok") - body, err := json.Marshal(res) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.Header().Set("Content-Type", "application/json") - w.Write(body) - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -type response struct { - Error string `json:"error"` -} - -func (e *Router) ResponseWithError(ctx context.Context, w http.ResponseWriter, code int, err error, requestStart time.Time, url string, bodySize int) { - e.log.Error(ctx, "response error, code: %d, error: %s", code, err) - body, err := json.Marshal(&response{err.Error()}) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.WriteHeader(code) - w.Write(body) - recordMetrics(requestStart, url, code, bodySize) -} diff --git a/backend/internal/http/router/router.go b/backend/internal/http/router/router.go deleted file mode 100644 index 6bc158c44..000000000 --- a/backend/internal/http/router/router.go +++ /dev/null @@ -1,178 +0,0 @@ -package router - -import ( - "fmt" - "github.com/docker/distribution/context" - "github.com/tomasen/realip" - "net" - "net/http" - "openreplay/backend/internal/http/geoip" - "openreplay/backend/pkg/logger" - "sync" - "time" - - "github.com/gorilla/mux" - http3 "openreplay/backend/internal/config/http" - http2 "openreplay/backend/internal/http/services" - "openreplay/backend/internal/http/util" -) - -type BeaconSize struct { - size int64 - time time.Time -} - -type Router struct { - log logger.Logger - cfg *http3.Config - router *mux.Router - mutex *sync.RWMutex - services *http2.ServicesBuilder - beaconSizeCache map[uint64]*BeaconSize // Cache for session's beaconSize - compressionThreshold int64 - features map[string]bool -} - -func NewRouter(cfg *http3.Config, log logger.Logger, services *http2.ServicesBuilder) (*Router, error) { - switch { - case cfg == nil: - return nil, fmt.Errorf("config is empty") - case services == nil: - return nil, fmt.Errorf("services is empty") - case log == nil: - return nil, fmt.Errorf("logger is empty") - } - e := &Router{ - log: log, - cfg: cfg, - mutex: &sync.RWMutex{}, - services: services, - beaconSizeCache: make(map[uint64]*BeaconSize), - compressionThreshold: cfg.CompressionThreshold, - features: map[string]bool{ - "feature-flags": cfg.IsFeatureFlagEnabled, - "usability-test": cfg.IsUsabilityTestEnabled, - }, - } - e.init() - go e.clearBeaconSizes() - return e, nil -} - -func (e *Router) addBeaconSize(sessionID uint64, size int64) { - if size <= 0 { - return - } - e.mutex.Lock() - defer e.mutex.Unlock() - e.beaconSizeCache[sessionID] = &BeaconSize{ - size: size, - time: time.Now(), - } -} - -func (e *Router) getBeaconSize(sessionID uint64) int64 { - e.mutex.RLock() - defer e.mutex.RUnlock() - if beaconSize, ok := e.beaconSizeCache[sessionID]; ok { - beaconSize.time = time.Now() - return beaconSize.size - } - return e.cfg.BeaconSizeLimit -} - -func (e *Router) getCompressionThreshold() int64 { - return e.compressionThreshold -} - -func (e *Router) clearBeaconSizes() { - for { - time.Sleep(time.Minute * 2) - now := time.Now() - e.mutex.Lock() - for sid, bs := range e.beaconSizeCache { - if now.Sub(bs.time) > time.Minute*3 { - delete(e.beaconSizeCache, sid) - } - } - e.mutex.Unlock() - } -} - -func (e *Router) ExtractGeoData(r *http.Request) *geoip.GeoRecord { - ip := net.ParseIP(realip.FromRequest(r)) - geoRec, err := e.services.GeoIP.Parse(ip) - if err != nil { - e.log.Warn(r.Context(), "failed to parse geo data: %v", err) - } - return geoRec -} - -func (e *Router) init() { - e.router = mux.NewRouter() - - // Root path - e.router.HandleFunc("/", e.root) - - handlers := map[string]func(http.ResponseWriter, *http.Request){ - "/v1/web/not-started": e.notStartedHandlerWeb, - "/v1/web/start": e.startSessionHandlerWeb, - "/v1/web/i": e.pushMessagesHandlerWeb, - "/v1/web/feature-flags": e.featureFlagsHandlerWeb, - "/v1/web/images": e.imagesUploaderHandlerWeb, - "/v1/mobile/start": e.startMobileSessionHandler, - "/v1/mobile/i": e.pushMobileMessagesHandler, - "/v1/mobile/late": e.pushMobileLateMessagesHandler, - "/v1/mobile/images": e.mobileImagesUploadHandler, - "/v1/web/uxt/signals/test": e.sendUXTestSignal, - "/v1/web/uxt/signals/task": e.sendUXTaskSignal, - } - getHandlers := map[string]func(http.ResponseWriter, *http.Request){ - "/v1/web/uxt/test/{id}": e.getUXTestInfo, - "/v1/web/uxt/upload-url": e.getUXUploadUrl, - "/v1/web/tags": e.getTags, - "/v1/web/conditions/{project}": e.getConditions, - "/v1/mobile/conditions/{project}": e.getConditions, - } - prefix := "/ingest" - - for path, handler := range handlers { - e.router.HandleFunc(path, handler).Methods("POST", "OPTIONS") - e.router.HandleFunc(prefix+path, handler).Methods("POST", "OPTIONS") - } - for path, handler := range getHandlers { - e.router.HandleFunc(path, handler).Methods("GET", "OPTIONS") - e.router.HandleFunc(prefix+path, handler).Methods("GET", "OPTIONS") - } - - // CORS middleware - e.router.Use(e.corsMiddleware) -} - -func (e *Router) root(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func (e *Router) corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if e.cfg.UseAccessControlHeaders { - // Prepare headers for preflight requests - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") - } - if r.Method == http.MethodOptions { - w.Header().Set("Cache-Control", "max-age=86400") - w.WriteHeader(http.StatusOK) - return - } - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) - - // Serve request - next.ServeHTTP(w, r) - }) -} - -func (e *Router) GetHandler() http.Handler { - return e.router -} diff --git a/backend/internal/http/server/server.go b/backend/internal/http/server/server.go deleted file mode 100644 index e1fdce74d..000000000 --- a/backend/internal/http/server/server.go +++ /dev/null @@ -1,43 +0,0 @@ -package server - -import ( - "context" - "errors" - "fmt" - "golang.org/x/net/http2" - "net/http" - "time" -) - -type Server struct { - server *http.Server -} - -func New(handler http.Handler, host, port string, timeout time.Duration) (*Server, error) { - switch { - case port == "": - return nil, errors.New("empty server port") - case handler == nil: - return nil, errors.New("empty handler") - case timeout < 1: - return nil, fmt.Errorf("invalid timeout %d", timeout) - } - server := &http.Server{ - Addr: fmt.Sprintf("%s:%s", host, port), - Handler: handler, - ReadTimeout: timeout, - WriteTimeout: timeout, - } - http2.ConfigureServer(server, nil) - return &Server{ - server: server, - }, nil -} - -func (s *Server) Start() error { - return s.server.ListenAndServe() -} - -func (s *Server) Stop() { - s.server.Shutdown(context.Background()) -} diff --git a/backend/internal/http/services/services.go b/backend/internal/http/services/services.go index 2ce1d8168..8866edb8e 100644 --- a/backend/internal/http/services/services.go +++ b/backend/internal/http/services/services.go @@ -5,44 +5,44 @@ import ( "openreplay/backend/internal/http/geoip" "openreplay/backend/internal/http/uaparser" "openreplay/backend/pkg/conditions" + conditionsAPI "openreplay/backend/pkg/conditions/api" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/db/redis" "openreplay/backend/pkg/featureflags" + featureflagsAPI "openreplay/backend/pkg/featureflags/api" "openreplay/backend/pkg/flakeid" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/objectstorage" + "openreplay/backend/pkg/metrics/web" "openreplay/backend/pkg/objectstorage/store" "openreplay/backend/pkg/projects" "openreplay/backend/pkg/queue/types" + "openreplay/backend/pkg/server/api" "openreplay/backend/pkg/sessions" + mobilesessions "openreplay/backend/pkg/sessions/api/mobile" + websessions "openreplay/backend/pkg/sessions/api/web" "openreplay/backend/pkg/tags" + tagsAPI "openreplay/backend/pkg/tags/api" "openreplay/backend/pkg/token" "openreplay/backend/pkg/uxtesting" + uxtestingAPI "openreplay/backend/pkg/uxtesting/api" ) type ServicesBuilder struct { - Projects projects.Projects - Sessions sessions.Sessions - FeatureFlags featureflags.FeatureFlags - Producer types.Producer - Flaker *flakeid.Flaker - UaParser *uaparser.UAParser - GeoIP geoip.GeoParser - Tokenizer *token.Tokenizer - ObjStorage objectstorage.ObjectStorage - UXTesting uxtesting.UXTesting - Tags tags.Tags - Conditions conditions.Conditions + WebAPI api.Handlers + MobileAPI api.Handlers + ConditionsAPI api.Handlers + FeatureFlagsAPI api.Handlers + TagsAPI api.Handlers + UxTestsAPI api.Handlers } -func New(log logger.Logger, cfg *http.Config, producer types.Producer, pgconn pool.Pool, redis *redis.Client) (*ServicesBuilder, error) { +func New(log logger.Logger, cfg *http.Config, metrics web.Web, producer types.Producer, pgconn pool.Pool, redis *redis.Client) (*ServicesBuilder, error) { projs := projects.New(log, pgconn, redis) - // ObjectStorage client to generate pre-signed upload urls objStore, err := store.NewStore(&cfg.ObjectsConfig) if err != nil { return nil, err } - geoModule, err := geoip.New(cfg.MaxMinDBFile) + geoModule, err := geoip.New(log, cfg.MaxMinDBFile) if err != nil { return nil, err } @@ -50,18 +50,32 @@ func New(log logger.Logger, cfg *http.Config, producer types.Producer, pgconn po if err != nil { return nil, err } - return &ServicesBuilder{ - Projects: projs, - Sessions: sessions.New(log, pgconn, projs, redis), - FeatureFlags: featureflags.New(pgconn), - Producer: producer, - Tokenizer: token.NewTokenizer(cfg.TokenSecret), - UaParser: uaModule, - GeoIP: geoModule, - Flaker: flakeid.NewFlaker(cfg.WorkerID), - ObjStorage: objStore, - UXTesting: uxtesting.New(pgconn), - Tags: tags.New(log, pgconn), - Conditions: conditions.New(pgconn), - }, nil + tokenizer := token.NewTokenizer(cfg.TokenSecret) + conditions := conditions.New(pgconn) + flaker := flakeid.NewFlaker(cfg.WorkerID) + sessions := sessions.New(log, pgconn, projs, redis) + featureFlags := featureflags.New(pgconn) + tags := tags.New(log, pgconn) + uxTesting := uxtesting.New(pgconn) + responser := api.NewResponser(metrics) + builder := &ServicesBuilder{} + if builder.WebAPI, err = websessions.NewHandlers(cfg, log, responser, producer, projs, sessions, uaModule, geoModule, tokenizer, conditions, flaker); err != nil { + return nil, err + } + if builder.MobileAPI, err = mobilesessions.NewHandlers(cfg, log, responser, producer, projs, sessions, uaModule, geoModule, tokenizer, conditions, flaker); err != nil { + return nil, err + } + if builder.ConditionsAPI, err = conditionsAPI.NewHandlers(log, responser, tokenizer, conditions); err != nil { + return nil, err + } + if builder.FeatureFlagsAPI, err = featureflagsAPI.NewHandlers(log, responser, cfg.JsonSizeLimit, tokenizer, sessions, featureFlags); err != nil { + return nil, err + } + if builder.TagsAPI, err = tagsAPI.NewHandlers(log, responser, tokenizer, sessions, tags); err != nil { + return nil, err + } + if builder.UxTestsAPI, err = uxtestingAPI.NewHandlers(log, responser, cfg.JsonSizeLimit, tokenizer, sessions, uxTesting, objStore); err != nil { + return nil, err + } + return builder, nil } diff --git a/backend/pkg/conditions/api/handlers.go b/backend/pkg/conditions/api/handlers.go new file mode 100644 index 000000000..13179eb57 --- /dev/null +++ b/backend/pkg/conditions/api/handlers.go @@ -0,0 +1,34 @@ +package api + +import ( + "net/http" + "time" + + "openreplay/backend/pkg/conditions" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/token" +) + +type handlersImpl struct { + log logger.Logger + responser *api.Responser +} + +func NewHandlers(log logger.Logger, responser *api.Responser, tokenizer *token.Tokenizer, conditions conditions.Conditions) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/web/conditions/{project}", e.getConditions, "GET"}, + {"/v1/mobile/conditions/{project}", e.getConditions, "GET"}, + } +} + +func (e *handlersImpl) getConditions(w http.ResponseWriter, r *http.Request) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotImplemented, nil, time.Now(), r.URL.Path, 0) +} diff --git a/backend/pkg/db/postgres/bulk.go b/backend/pkg/db/postgres/bulk.go index e474ee417..0b7fc03e4 100644 --- a/backend/pkg/db/postgres/bulk.go +++ b/backend/pkg/db/postgres/bulk.go @@ -4,9 +4,9 @@ import ( "bytes" "errors" "fmt" - "openreplay/backend/pkg/db/postgres/pool" "time" + "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/metrics/database" ) diff --git a/backend/pkg/featureflags/api/handlers.go b/backend/pkg/featureflags/api/handlers.go new file mode 100644 index 000000000..ded3a180a --- /dev/null +++ b/backend/pkg/featureflags/api/handlers.go @@ -0,0 +1,92 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "openreplay/backend/pkg/featureflags" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessions" + "openreplay/backend/pkg/token" +) + +type handlersImpl struct { + log logger.Logger + responser *api.Responser + jsonSizeLimit int64 + tokenizer *token.Tokenizer + sessions sessions.Sessions + featureFlags featureflags.FeatureFlags +} + +func NewHandlers(log logger.Logger, responser *api.Responser, jsonSizeLimit int64, tokenizer *token.Tokenizer, sessions sessions.Sessions, + featureFlags featureflags.FeatureFlags) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + jsonSizeLimit: jsonSizeLimit, + tokenizer: tokenizer, + sessions: sessions, + featureFlags: featureFlags, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/web/feature-flags", e.featureFlagsHandlerWeb, "POST"}, + } +} + +func (e *handlersImpl) featureFlagsHandlerWeb(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) + return + } + bodyBytes, err := api.ReadCompressedBody(e.log, w, r, e.jsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + // Parse request body + req := &featureflags.FeatureFlagsRequest{} + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + computedFlags, err := e.featureFlags.ComputeFlagsForSession(req) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + + resp := &featureflags.FeatureFlagsResponse{ + Flags: computedFlags, + } + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) +} diff --git a/backend/pkg/integrations/api/handlers.go b/backend/pkg/integrations/api/handlers.go new file mode 100644 index 000000000..24e3312a4 --- /dev/null +++ b/backend/pkg/integrations/api/handlers.go @@ -0,0 +1,206 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gorilla/mux" + + integrationsCfg "openreplay/backend/internal/config/integrations" + "openreplay/backend/pkg/integrations/service" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" +) + +type handlersImpl struct { + log logger.Logger + responser *api.Responser + integrations service.Service + jsonSizeLimit int64 +} + +func NewHandlers(log logger.Logger, cfg *integrationsCfg.Config, responser *api.Responser, integrations service.Service) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + integrations: integrations, + jsonSizeLimit: cfg.JsonSizeLimit, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/integrations/{name}/{project}", e.createIntegration, "POST"}, + {"/v1/integrations/{name}/{project}", e.getIntegration, "GET"}, + {"/v1/integrations/{name}/{project}", e.updateIntegration, "PATCH"}, + {"/v1/integrations/{name}/{project}", e.deleteIntegration, "DELETE"}, + {"/v1/integrations/{name}/{project}/data/{session}", e.getIntegrationData, "GET"}, + } +} + +func getIntegrationsArgs(r *http.Request) (string, uint64, error) { + vars := mux.Vars(r) + name := vars["name"] + if name == "" { + return "", 0, fmt.Errorf("empty integration name") + } + project := vars["project"] + if project == "" { + return "", 0, fmt.Errorf("project id is empty") + } + projID, err := strconv.ParseUint(project, 10, 64) + if err != nil || projID <= 0 { + return "", 0, fmt.Errorf("invalid project id") + } + return name, projID, nil +} + +func getIntegrationSession(r *http.Request) (uint64, error) { + vars := mux.Vars(r) + session := vars["session"] + if session == "" { + return 0, fmt.Errorf("session id is empty") + } + sessID, err := strconv.ParseUint(session, 10, 64) + if err != nil || sessID <= 0 { + return 0, fmt.Errorf("invalid session id") + } + return sessID, nil +} + +type IntegrationRequest struct { + IntegrationData map[string]string `json:"data"` +} + +func (e *handlersImpl) createIntegration(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + integration, project, err := getIntegrationsArgs(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + req := &IntegrationRequest{} + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + if err := e.integrations.AddIntegration(project, integration, req.IntegrationData); err != nil { + if strings.Contains(err.Error(), "failed to validate") { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnprocessableEntity, err, startTime, r.URL.Path, bodySize) + } else { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + } + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) getIntegration(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + integration, project, err := getIntegrationsArgs(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + intParams, err := e.integrations.GetIntegration(project, integration) + if err != nil { + if strings.Contains(err.Error(), "no rows in result set") { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) + } else { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + } + return + } + e.responser.ResponseWithJSON(e.log, r.Context(), w, intParams, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) updateIntegration(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + integration, project, err := getIntegrationsArgs(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + req := &IntegrationRequest{} + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + if err := e.integrations.UpdateIntegration(project, integration, req.IntegrationData); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) deleteIntegration(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + integration, project, err := getIntegrationsArgs(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + if err := e.integrations.DeleteIntegration(project, integration); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) getIntegrationData(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + integration, project, err := getIntegrationsArgs(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + session, err := getIntegrationSession(r) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + url, err := e.integrations.GetSessionDataURL(project, integration, session) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + resp := map[string]string{"url": url} + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) +} diff --git a/backend/pkg/integrations/builder.go b/backend/pkg/integrations/builder.go index b151cc0e4..a8ed6c9fd 100644 --- a/backend/pkg/integrations/builder.go +++ b/backend/pkg/integrations/builder.go @@ -1,36 +1,51 @@ -package data_integration +package integrations import ( + "openreplay/backend/pkg/integrations/service" + "openreplay/backend/pkg/metrics/web" + "openreplay/backend/pkg/server/tracer" + "time" + "openreplay/backend/internal/config/integrations" "openreplay/backend/pkg/db/postgres/pool" - "openreplay/backend/pkg/flakeid" + integrationsAPI "openreplay/backend/pkg/integrations/api" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/objectstorage" "openreplay/backend/pkg/objectstorage/store" - "openreplay/backend/pkg/spot/auth" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/server/auth" + "openreplay/backend/pkg/server/limiter" ) type ServiceBuilder struct { - Flaker *flakeid.Flaker - ObjStorage objectstorage.ObjectStorage - Auth auth.Auth - Integrator Service + Auth auth.Auth + RateLimiter *limiter.UserRateLimiter + AuditTrail tracer.Tracer + IntegrationsAPI api.Handlers } -func NewServiceBuilder(log logger.Logger, cfg *integrations.Config, pgconn pool.Pool) (*ServiceBuilder, error) { +func NewServiceBuilder(log logger.Logger, cfg *integrations.Config, webMetrics web.Web, pgconn pool.Pool) (*ServiceBuilder, error) { objStore, err := store.NewStore(&cfg.ObjectsConfig) if err != nil { return nil, err } - integrator, err := NewService(log, pgconn, objStore) + integrator, err := service.NewService(log, pgconn, objStore) if err != nil { return nil, err } - flaker := flakeid.NewFlaker(cfg.WorkerID) - return &ServiceBuilder{ - Flaker: flaker, - ObjStorage: objStore, - Auth: auth.NewAuth(log, cfg.JWTSecret, "", pgconn), - Integrator: integrator, - }, nil + responser := api.NewResponser(webMetrics) + handlers, err := integrationsAPI.NewHandlers(log, cfg, responser, integrator) + if err != nil { + return nil, err + } + auditrail, err := tracer.NewTracer(log, pgconn) + if err != nil { + return nil, err + } + builder := &ServiceBuilder{ + Auth: auth.NewAuth(log, cfg.JWTSecret, "", pgconn, nil), + RateLimiter: limiter.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), + AuditTrail: auditrail, + IntegrationsAPI: handlers, + } + return builder, nil } diff --git a/backend/pkg/integrations/handlers.go b/backend/pkg/integrations/handlers.go deleted file mode 100644 index 1e784e86b..000000000 --- a/backend/pkg/integrations/handlers.go +++ /dev/null @@ -1,233 +0,0 @@ -package data_integration - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "time" - - "github.com/gorilla/mux" - - metrics "openreplay/backend/pkg/metrics/heuristics" -) - -func getIntegrationsArgs(r *http.Request) (string, uint64, error) { - vars := mux.Vars(r) - name := vars["name"] - if name == "" { - return "", 0, fmt.Errorf("empty integration name") - } - project := vars["project"] - if project == "" { - return "", 0, fmt.Errorf("project id is empty") - } - projID, err := strconv.ParseUint(project, 10, 64) - if err != nil || projID <= 0 { - return "", 0, fmt.Errorf("invalid project id") - } - return name, projID, nil -} - -func getIntegrationSession(r *http.Request) (uint64, error) { - vars := mux.Vars(r) - session := vars["session"] - if session == "" { - return 0, fmt.Errorf("session id is empty") - } - sessID, err := strconv.ParseUint(session, 10, 64) - if err != nil || sessID <= 0 { - return 0, fmt.Errorf("invalid session id") - } - return sessID, nil -} - -type IntegrationRequest struct { - IntegrationData map[string]string `json:"data"` -} - -func (e *Router) createIntegration(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - integration, project, err := getIntegrationsArgs(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - req := &IntegrationRequest{} - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - if err := e.services.Integrator.AddIntegration(project, integration, req.IntegrationData); err != nil { - if strings.Contains(err.Error(), "failed to validate") { - e.ResponseWithError(r.Context(), w, http.StatusUnprocessableEntity, err, startTime, r.URL.Path, bodySize) - } else { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - } - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) getIntegration(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - integration, project, err := getIntegrationsArgs(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - intParams, err := e.services.Integrator.GetIntegration(project, integration) - if err != nil { - if strings.Contains(err.Error(), "no rows in result set") { - e.ResponseWithError(r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) - } else { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - } - return - } - e.ResponseWithJSON(r.Context(), w, intParams, startTime, r.URL.Path, bodySize) -} - -func (e *Router) updateIntegration(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) - return - } - bodySize = len(bodyBytes) - - integration, project, err := getIntegrationsArgs(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - req := &IntegrationRequest{} - if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - if err := e.services.Integrator.UpdateIntegration(project, integration, req.IntegrationData); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) deleteIntegration(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - integration, project, err := getIntegrationsArgs(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - if err := e.services.Integrator.DeleteIntegration(project, integration); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) -} - -func (e *Router) getIntegrationData(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - integration, project, err := getIntegrationsArgs(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - session, err := getIntegrationSession(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - url, err := e.services.Integrator.GetSessionDataURL(project, integration, session) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - resp := map[string]string{"url": url} - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) -} - -func recordMetrics(requestStart time.Time, url string, code, bodySize int) { - if bodySize > 0 { - metrics.RecordRequestSize(float64(bodySize), url, code) - } - metrics.IncreaseTotalRequests() - metrics.RecordRequestDuration(float64(time.Now().Sub(requestStart).Milliseconds()), url, code) -} - -func (e *Router) readBody(w http.ResponseWriter, r *http.Request, limit int64) ([]byte, error) { - body := http.MaxBytesReader(w, r.Body, limit) - bodyBytes, err := io.ReadAll(body) - - // Close body - if closeErr := body.Close(); closeErr != nil { - e.log.Warn(r.Context(), "error while closing request body: %s", closeErr) - } - if err != nil { - return nil, err - } - return bodyBytes, nil -} - -func (e *Router) ResponseOK(ctx context.Context, w http.ResponseWriter, requestStart time.Time, url string, bodySize int) { - w.WriteHeader(http.StatusOK) - e.log.Info(ctx, "response ok") - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -func (e *Router) ResponseWithJSON(ctx context.Context, w http.ResponseWriter, res interface{}, requestStart time.Time, url string, bodySize int) { - e.log.Info(ctx, "response ok") - body, err := json.Marshal(res) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.Header().Set("Content-Type", "application/json") - w.Write(body) - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -type response struct { - Error string `json:"error"` -} - -func (e *Router) ResponseWithError(ctx context.Context, w http.ResponseWriter, code int, err error, requestStart time.Time, url string, bodySize int) { - e.log.Error(ctx, "response error, code: %d, error: %s", code, err) - body, err := json.Marshal(&response{err.Error()}) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.WriteHeader(code) - w.Write(body) - recordMetrics(requestStart, url, code, bodySize) -} diff --git a/backend/pkg/integrations/router.go b/backend/pkg/integrations/router.go deleted file mode 100644 index a405c6065..000000000 --- a/backend/pkg/integrations/router.go +++ /dev/null @@ -1,170 +0,0 @@ -package data_integration - -import ( - "bytes" - "fmt" - "io" - "net/http" - "time" - - "github.com/docker/distribution/context" - "github.com/gorilla/mux" - - integration "openreplay/backend/internal/config/integrations" - "openreplay/backend/internal/http/util" - "openreplay/backend/pkg/logger" - limiter "openreplay/backend/pkg/spot/api" - "openreplay/backend/pkg/spot/auth" -) - -type Router struct { - log logger.Logger - cfg *integration.Config - router *mux.Router - services *ServiceBuilder - limiter *limiter.UserRateLimiter -} - -func NewRouter(cfg *integration.Config, log logger.Logger, services *ServiceBuilder) (*Router, error) { - switch { - case cfg == nil: - return nil, fmt.Errorf("config is empty") - case services == nil: - return nil, fmt.Errorf("services is empty") - case log == nil: - return nil, fmt.Errorf("logger is empty") - } - e := &Router{ - log: log, - cfg: cfg, - services: services, - limiter: limiter.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), - } - e.init() - return e, nil -} - -func (e *Router) init() { - e.router = mux.NewRouter() - - // Root route - e.router.HandleFunc("/", e.ping) - - e.router.HandleFunc("/v1/integrations/{name}/{project}", e.createIntegration).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/integrations/{name}/{project}", e.getIntegration).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/integrations/{name}/{project}", e.updateIntegration).Methods("PATCH", "OPTIONS") - e.router.HandleFunc("/v1/integrations/{name}/{project}", e.deleteIntegration).Methods("DELETE", "OPTIONS") - e.router.HandleFunc("/v1/integrations/{name}/{project}/data/{session}", e.getIntegrationData).Methods("GET", "OPTIONS") - - // CORS middleware - e.router.Use(e.corsMiddleware) - e.router.Use(e.authMiddleware) - e.router.Use(e.rateLimitMiddleware) - e.router.Use(e.actionMiddleware) -} - -func (e *Router) ping(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func (e *Router) corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - if e.cfg.UseAccessControlHeaders { - // Prepare headers for preflight requests - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") - } - if r.Method == http.MethodOptions { - w.Header().Set("Cache-Control", "max-age=86400") - w.WriteHeader(http.StatusOK) - return - } - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) - - next.ServeHTTP(w, r) - }) -} - -func (e *Router) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - - // Check if the request is authorized - user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), nil, false) - if err != nil { - e.log.Warn(r.Context(), "Unauthorized request: %s", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user})) - next.ServeHTTP(w, r) - }) -} - -func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - user := r.Context().Value("userData").(*auth.User) - rl := e.limiter.GetRateLimiter(user.ID) - - if !rl.Allow() { - http.Error(w, "Too Many Requests", http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) -} - -type statusWriter struct { - http.ResponseWriter - statusCode int -} - -func (w *statusWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *statusWriter) Write(b []byte) (int, error) { - if w.statusCode == 0 { - w.statusCode = http.StatusOK - } - return w.ResponseWriter.Write(b) -} - -func (e *Router) actionMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - // Read body and restore the io.ReadCloser to its original state - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - // Use custom response writer to get the status code - sw := &statusWriter{ResponseWriter: w} - // Serve the request - next.ServeHTTP(sw, r) - e.logRequest(r, bodyBytes, sw.statusCode) - }) -} - -func (e *Router) logRequest(r *http.Request, bodyBytes []byte, statusCode int) { - e.log.Info(r.Context(), "Request: %s %s %s %d", r.Method, r.URL.Path, bodyBytes, statusCode) -} - -func (e *Router) GetHandler() http.Handler { - return e.router -} diff --git a/backend/pkg/integrations/service.go b/backend/pkg/integrations/service/service.go similarity index 99% rename from backend/pkg/integrations/service.go rename to backend/pkg/integrations/service/service.go index 5bcac2ec5..41c7943d5 100644 --- a/backend/pkg/integrations/service.go +++ b/backend/pkg/integrations/service/service.go @@ -1,4 +1,4 @@ -package data_integration +package service import ( "bytes" diff --git a/backend/pkg/metrics/http/metrics.go b/backend/pkg/metrics/http/metrics.go deleted file mode 100644 index 7a835d7f6..000000000 --- a/backend/pkg/metrics/http/metrics.go +++ /dev/null @@ -1,55 +0,0 @@ -package http - -import ( - "github.com/prometheus/client_golang/prometheus" - "openreplay/backend/pkg/metrics/common" - "strconv" -) - -var httpRequestSize = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "http", - Name: "request_size_bytes", - Help: "A histogram displaying the size of each HTTP request in bytes.", - Buckets: common.DefaultSizeBuckets, - }, - []string{"url", "response_code"}, -) - -func RecordRequestSize(size float64, url string, code int) { - httpRequestSize.WithLabelValues(url, strconv.Itoa(code)).Observe(size) -} - -var httpRequestDuration = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "http", - Name: "request_duration_seconds", - Help: "A histogram displaying the duration of each HTTP request in seconds.", - Buckets: common.DefaultDurationBuckets, - }, - []string{"url", "response_code"}, -) - -func RecordRequestDuration(durMillis float64, url string, code int) { - httpRequestDuration.WithLabelValues(url, strconv.Itoa(code)).Observe(durMillis / 1000.0) -} - -var httpTotalRequests = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "http", - Name: "requests_total", - Help: "A counter displaying the number all HTTP requests.", - }, -) - -func IncreaseTotalRequests() { - httpTotalRequests.Inc() -} - -func List() []prometheus.Collector { - return []prometheus.Collector{ - httpRequestSize, - httpRequestDuration, - httpTotalRequests, - } -} diff --git a/backend/pkg/metrics/spot/spot.go b/backend/pkg/metrics/spot/spot.go index df5420a97..617559f67 100644 --- a/backend/pkg/metrics/spot/spot.go +++ b/backend/pkg/metrics/spot/spot.go @@ -1,53 +1,11 @@ package spot import ( - "strconv" - "github.com/prometheus/client_golang/prometheus" "openreplay/backend/pkg/metrics/common" ) -var spotRequestSize = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "spot", - Name: "request_size_bytes", - Help: "A histogram displaying the size of each HTTP request in bytes.", - Buckets: common.DefaultSizeBuckets, - }, - []string{"url", "response_code"}, -) - -func RecordRequestSize(size float64, url string, code int) { - spotRequestSize.WithLabelValues(url, strconv.Itoa(code)).Observe(size) -} - -var spotRequestDuration = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "spot", - Name: "request_duration_seconds", - Help: "A histogram displaying the duration of each HTTP request in seconds.", - Buckets: common.DefaultDurationBuckets, - }, - []string{"url", "response_code"}, -) - -func RecordRequestDuration(durMillis float64, url string, code int) { - spotRequestDuration.WithLabelValues(url, strconv.Itoa(code)).Observe(durMillis / 1000.0) -} - -var spotTotalRequests = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "spot", - Name: "requests_total", - Help: "A counter displaying the number all HTTP requests.", - }, -) - -func IncreaseTotalRequests() { - spotTotalRequests.Inc() -} - var spotOriginalVideoSize = prometheus.NewHistogram( prometheus.HistogramOpts{ Namespace: "spot", @@ -177,9 +135,6 @@ func RecordTranscodedVideoUploadDuration(durMillis float64) { func List() []prometheus.Collector { return []prometheus.Collector{ - spotRequestSize, - spotRequestDuration, - spotTotalRequests, spotOriginalVideoSize, spotCroppedVideoSize, spotVideosTotal, diff --git a/backend/pkg/metrics/web/metrics.go b/backend/pkg/metrics/web/metrics.go new file mode 100644 index 000000000..a86b67787 --- /dev/null +++ b/backend/pkg/metrics/web/metrics.go @@ -0,0 +1,84 @@ +package web + +import ( + "strconv" + + "github.com/prometheus/client_golang/prometheus" + + "openreplay/backend/pkg/metrics/common" +) + +type Web interface { + RecordRequestSize(size float64, url string, code int) + RecordRequestDuration(durMillis float64, url string, code int) + IncreaseTotalRequests() + List() []prometheus.Collector +} + +type webImpl struct { + httpRequestSize *prometheus.HistogramVec + httpRequestDuration *prometheus.HistogramVec + httpTotalRequests prometheus.Counter +} + +func New(serviceName string) Web { + return &webImpl{ + httpRequestSize: newRequestSizeMetric(serviceName), + httpRequestDuration: newRequestDurationMetric(serviceName), + httpTotalRequests: newTotalRequestsMetric(serviceName), + } +} + +func (w *webImpl) List() []prometheus.Collector { + return []prometheus.Collector{ + w.httpRequestSize, + w.httpRequestDuration, + w.httpTotalRequests, + } +} + +func newRequestSizeMetric(serviceName string) *prometheus.HistogramVec { + return prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: serviceName, + Name: "request_size_bytes", + Help: "A histogram displaying the size of each HTTP request in bytes.", + Buckets: common.DefaultSizeBuckets, + }, + []string{"url", "response_code"}, + ) +} + +func (w *webImpl) RecordRequestSize(size float64, url string, code int) { + w.httpRequestSize.WithLabelValues(url, strconv.Itoa(code)).Observe(size) +} + +func newRequestDurationMetric(serviceName string) *prometheus.HistogramVec { + return prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: serviceName, + Name: "request_duration_seconds", + Help: "A histogram displaying the duration of each HTTP request in seconds.", + Buckets: common.DefaultDurationBuckets, + }, + []string{"url", "response_code"}, + ) +} + +func (w *webImpl) RecordRequestDuration(durMillis float64, url string, code int) { + w.httpRequestDuration.WithLabelValues(url, strconv.Itoa(code)).Observe(durMillis / 1000.0) +} + +func newTotalRequestsMetric(serviceName string) prometheus.Counter { + return prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: serviceName, + Name: "requests_total", + Help: "A counter displaying the number all HTTP requests.", + }, + ) +} + +func (w *webImpl) IncreaseTotalRequests() { + w.httpTotalRequests.Inc() +} diff --git a/backend/pkg/server/api/body-reader.go b/backend/pkg/server/api/body-reader.go new file mode 100644 index 000000000..1cfcc92d4 --- /dev/null +++ b/backend/pkg/server/api/body-reader.go @@ -0,0 +1,59 @@ +package api + +import ( + "fmt" + "io" + "net/http" + + "github.com/klauspost/compress/gzip" + + "openreplay/backend/pkg/logger" +) + +func ReadBody(log logger.Logger, w http.ResponseWriter, r *http.Request, limit int64) ([]byte, error) { + body := http.MaxBytesReader(w, r.Body, limit) + bodyBytes, err := io.ReadAll(body) + + // Close body + if closeErr := body.Close(); closeErr != nil { + log.Warn(r.Context(), "error while closing request body: %s", closeErr) + } + if err != nil { + return nil, err + } + return bodyBytes, nil +} + +func ReadCompressedBody(log logger.Logger, w http.ResponseWriter, r *http.Request, limit int64) ([]byte, error) { + body := http.MaxBytesReader(w, r.Body, limit) + var ( + bodyBytes []byte + err error + ) + + // Check if body is gzipped and decompress it + if r.Header.Get("Content-Encoding") == "gzip" { + reader, err := gzip.NewReader(body) + if err != nil { + return nil, fmt.Errorf("can't create gzip reader: %s", err) + } + bodyBytes, err = io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("can't read gzip body: %s", err) + } + if err := reader.Close(); err != nil { + log.Warn(r.Context(), "can't close gzip reader: %s", err) + } + } else { + bodyBytes, err = io.ReadAll(body) + } + + // Close body + if closeErr := body.Close(); closeErr != nil { + log.Warn(r.Context(), "error while closing request body: %s", closeErr) + } + if err != nil { + return nil, err + } + return bodyBytes, nil +} diff --git a/backend/pkg/server/api/handlers.go b/backend/pkg/server/api/handlers.go new file mode 100644 index 000000000..c7e6f0811 --- /dev/null +++ b/backend/pkg/server/api/handlers.go @@ -0,0 +1,13 @@ +package api + +import "net/http" + +type Description struct { + Path string + Handler http.HandlerFunc + Method string +} + +type Handlers interface { + GetAll() []*Description +} diff --git a/backend/pkg/server/api/middleware.go b/backend/pkg/server/api/middleware.go new file mode 100644 index 000000000..423e7e0d9 --- /dev/null +++ b/backend/pkg/server/api/middleware.go @@ -0,0 +1,41 @@ +package api + +import ( + "net/http" + + ctxStore "github.com/docker/distribution/context" + "openreplay/backend/internal/http/util" +) + +func (e *routerImpl) health(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func (e *routerImpl) healthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) + }) +} + +func (e *routerImpl) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if e.cfg.UseAccessControlHeaders { + // Prepare headers for preflight requests + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") + } + if r.Method == http.MethodOptions { + w.Header().Set("Cache-Control", "max-age=86400") + w.WriteHeader(http.StatusOK) + return + } + + r = r.WithContext(ctxStore.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) + next.ServeHTTP(w, r) + }) +} diff --git a/backend/pkg/server/api/responser.go b/backend/pkg/server/api/responser.go new file mode 100644 index 000000000..5611e9856 --- /dev/null +++ b/backend/pkg/server/api/responser.go @@ -0,0 +1,61 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/metrics/web" +) + +type Responser struct { + metrics web.Web +} + +func NewResponser(webMetrics web.Web) *Responser { + return &Responser{ + metrics: webMetrics, + } +} + +type response struct { + Error string `json:"error"` +} + +func (r *Responser) ResponseOK(log logger.Logger, ctx context.Context, w http.ResponseWriter, requestStart time.Time, url string, bodySize int) { + w.WriteHeader(http.StatusOK) + log.Info(ctx, "response ok") + r.recordMetrics(requestStart, url, http.StatusOK, bodySize) +} + +func (r *Responser) ResponseWithJSON(log logger.Logger, ctx context.Context, w http.ResponseWriter, res interface{}, requestStart time.Time, url string, bodySize int) { + log.Info(ctx, "response ok") + body, err := json.Marshal(res) + if err != nil { + log.Error(ctx, "can't marshal response: %s", err) + } + w.Header().Set("Content-Type", "application/json") + w.Write(body) + r.recordMetrics(requestStart, url, http.StatusOK, bodySize) +} + +func (r *Responser) ResponseWithError(log logger.Logger, ctx context.Context, w http.ResponseWriter, code int, err error, requestStart time.Time, url string, bodySize int) { + log.Error(ctx, "response error, code: %d, error: %s", code, err) + body, err := json.Marshal(&response{err.Error()}) + if err != nil { + log.Error(ctx, "can't marshal response: %s", err) + } + w.WriteHeader(code) + w.Write(body) + r.recordMetrics(requestStart, url, code, bodySize) +} + +func (r *Responser) recordMetrics(requestStart time.Time, url string, code, bodySize int) { + if bodySize > 0 { + r.metrics.RecordRequestSize(float64(bodySize), url, code) + } + r.metrics.IncreaseTotalRequests() + r.metrics.RecordRequestDuration(float64(time.Now().Sub(requestStart).Milliseconds()), url, code) +} diff --git a/backend/pkg/server/api/router.go b/backend/pkg/server/api/router.go new file mode 100644 index 000000000..13f615e5e --- /dev/null +++ b/backend/pkg/server/api/router.go @@ -0,0 +1,69 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/gorilla/mux" + + "openreplay/backend/internal/config/common" + "openreplay/backend/pkg/logger" +) + +type Router interface { + AddHandlers(prefix string, handlers ...Handlers) + AddMiddlewares(middlewares ...func(http.Handler) http.Handler) + Get() http.Handler +} + +type routerImpl struct { + log logger.Logger + cfg *common.HTTP + router *mux.Router +} + +func NewRouter(cfg *common.HTTP, log logger.Logger) (Router, error) { + switch { + case cfg == nil: + return nil, fmt.Errorf("config is empty") + case log == nil: + return nil, fmt.Errorf("logger is empty") + } + e := &routerImpl{ + log: log, + cfg: cfg, + router: mux.NewRouter(), + } + e.initRouter() + return e, nil +} + +func (e *routerImpl) initRouter() { + e.router.HandleFunc("/", e.health) + // Default middlewares + e.router.Use(e.healthMiddleware) + e.router.Use(e.corsMiddleware) +} + +const NoPrefix = "" + +func (e *routerImpl) AddHandlers(prefix string, handlers ...Handlers) { + for _, handlersSet := range handlers { + for _, handler := range handlersSet.GetAll() { + e.router.HandleFunc(handler.Path, handler.Handler).Methods(handler.Method, "OPTIONS") + if prefix != NoPrefix { + e.router.HandleFunc(prefix+handler.Path, handler.Handler).Methods(handler.Method, "OPTIONS") + } + } + } +} + +func (e *routerImpl) AddMiddlewares(middlewares ...func(http.Handler) http.Handler) { + for _, middleware := range middlewares { + e.router.Use(middleware) + } +} + +func (e *routerImpl) Get() http.Handler { + return e.router +} diff --git a/backend/pkg/spot/auth/auth.go b/backend/pkg/server/auth/auth.go similarity index 75% rename from backend/pkg/spot/auth/auth.go rename to backend/pkg/server/auth/auth.go index 498e16e0a..fe817bce0 100644 --- a/backend/pkg/spot/auth/auth.go +++ b/backend/pkg/server/auth/auth.go @@ -2,16 +2,20 @@ package auth import ( "fmt" + "net/http" "strings" "github.com/golang-jwt/jwt/v5" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/keys" + "openreplay/backend/pkg/server/user" ) type Auth interface { - IsAuthorized(authHeader string, permissions []string, isExtension bool) (*User, error) + IsAuthorized(authHeader string, permissions []string, isExtension bool) (*user.User, error) + Middleware(next http.Handler) http.Handler } type authImpl struct { @@ -19,18 +23,20 @@ type authImpl struct { secret string spotSecret string pgconn pool.Pool + keys keys.Keys } -func NewAuth(log logger.Logger, jwtSecret, jwtSpotSecret string, conn pool.Pool) Auth { +func NewAuth(log logger.Logger, jwtSecret, jwtSpotSecret string, conn pool.Pool, keys keys.Keys) Auth { return &authImpl{ log: log, secret: jwtSecret, spotSecret: jwtSpotSecret, pgconn: conn, + keys: keys, } } -func parseJWT(authHeader, secret string) (*JWTClaims, error) { +func parseJWT(authHeader, secret string) (*user.JWTClaims, error) { if authHeader == "" { return nil, fmt.Errorf("authorization header missing") } @@ -40,7 +46,7 @@ func parseJWT(authHeader, secret string) (*JWTClaims, error) { } tokenString := tokenParts[1] - claims := &JWTClaims{} + claims := &user.JWTClaims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { return []byte(secret), nil diff --git a/backend/pkg/spot/auth/authorizer.go b/backend/pkg/server/auth/authorizer.go similarity index 77% rename from backend/pkg/spot/auth/authorizer.go rename to backend/pkg/server/auth/authorizer.go index 63067fee9..2f2a2571e 100644 --- a/backend/pkg/spot/auth/authorizer.go +++ b/backend/pkg/server/auth/authorizer.go @@ -1,6 +1,8 @@ package auth -func (a *authImpl) IsAuthorized(authHeader string, permissions []string, isExtension bool) (*User, error) { +import "openreplay/backend/pkg/server/user" + +func (a *authImpl) IsAuthorized(authHeader string, permissions []string, isExtension bool) (*user.User, error) { secret := a.secret if isExtension { secret = a.spotSecret diff --git a/backend/pkg/server/auth/middleware.go b/backend/pkg/server/auth/middleware.go new file mode 100644 index 000000000..a6a9f7fcb --- /dev/null +++ b/backend/pkg/server/auth/middleware.go @@ -0,0 +1,65 @@ +package auth + +import ( + "net/http" + + "github.com/gorilla/mux" + + ctxStore "github.com/docker/distribution/context" +) + +func (e *authImpl) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, err := e.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), e.isExtensionRequest(r)) + if err != nil { + if !e.isSpotWithKeyRequest(r) { + e.log.Warn(r.Context(), "Unauthorized request, wrong jwt token: %s", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + user, err = e.keys.IsValid(r.URL.Query().Get("key")) + if err != nil { + e.log.Warn(r.Context(), "Unauthorized request, wrong public key: %s", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + } + + r = r.WithContext(ctxStore.WithValues(r.Context(), map[string]interface{}{"userData": user})) + next.ServeHTTP(w, r) + }) +} + +func (e *authImpl) isExtensionRequest(r *http.Request) bool { + pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() + if err != nil { + e.log.Error(r.Context(), "failed to get path template: %s", err) + } else { + if pathTemplate == "/v1/ping" || + (pathTemplate == "/v1/spots" && r.Method == "POST") || + (pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") { + return true + } + } + return false +} + +func (e *authImpl) isSpotWithKeyRequest(r *http.Request) bool { + if e.keys == nil { + return false + } + pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() + if err != nil { + return false + } + getSpotPrefix := "/v1/spots/{id}" // GET + addCommentPrefix := "/v1/spots/{id}/comment" // POST + getStatusPrefix := "/v1/spots/{id}/status" // GET + if (pathTemplate == getSpotPrefix && r.Method == "GET") || + (pathTemplate == addCommentPrefix && r.Method == "POST") || + (pathTemplate == getStatusPrefix && r.Method == "GET") { + return true + } + return false +} diff --git a/backend/pkg/spot/api/permissions.go b/backend/pkg/server/auth/permissions.go similarity index 82% rename from backend/pkg/spot/api/permissions.go rename to backend/pkg/server/auth/permissions.go index f8392bb70..6edc34199 100644 --- a/backend/pkg/spot/api/permissions.go +++ b/backend/pkg/server/auth/permissions.go @@ -1,4 +1,4 @@ -package api +package auth func getPermissions(urlPath string) []string { return nil diff --git a/backend/pkg/spot/auth/storage.go b/backend/pkg/server/auth/storage.go similarity index 57% rename from backend/pkg/spot/auth/storage.go rename to backend/pkg/server/auth/storage.go index 0647af1be..9097f0239 100644 --- a/backend/pkg/spot/auth/storage.go +++ b/backend/pkg/server/auth/storage.go @@ -3,10 +3,11 @@ package auth import ( "fmt" "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/server/user" "strings" ) -func authUser(conn pool.Pool, userID, tenantID, jwtIAT int, isExtension bool) (*User, error) { +func authUser(conn pool.Pool, userID, tenantID, jwtIAT int, isExtension bool) (*user.User, error) { sql := ` SELECT user_id, name, email, EXTRACT(epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat FROM public.users @@ -15,12 +16,19 @@ func authUser(conn pool.Pool, userID, tenantID, jwtIAT int, isExtension bool) (* if !isExtension { sql = strings.ReplaceAll(sql, "spot_jwt_iat", "jwt_iat") } - user := &User{TenantID: 1, AuthMethod: "jwt"} - if err := conn.QueryRow(sql, userID).Scan(&user.ID, &user.Name, &user.Email, &user.JwtIat); err != nil { + newUser := &user.User{TenantID: 1, AuthMethod: "jwt"} + if err := conn.QueryRow(sql, userID).Scan(&newUser.ID, &newUser.Name, &newUser.Email, &newUser.JwtIat); err != nil { return nil, fmt.Errorf("user not found") } - if user.JwtIat == 0 || abs(jwtIAT-user.JwtIat) > 1 { + if newUser.JwtIat == 0 || abs(jwtIAT-newUser.JwtIat) > 1 { return nil, fmt.Errorf("token has been updated") } - return user, nil + return newUser, nil +} + +func abs(x int) int { + if x < 0 { + return -x + } + return x } diff --git a/backend/pkg/spot/service/public_key.go b/backend/pkg/server/keys/public_key.go similarity index 89% rename from backend/pkg/spot/service/public_key.go rename to backend/pkg/server/keys/public_key.go index 28eab5aee..c13631f93 100644 --- a/backend/pkg/spot/service/public_key.go +++ b/backend/pkg/server/keys/public_key.go @@ -1,15 +1,15 @@ -package service +package keys import ( "context" "fmt" + "openreplay/backend/pkg/server/user" "time" "github.com/rs/xid" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/spot/auth" ) type Key struct { @@ -22,9 +22,9 @@ type Key struct { } type Keys interface { - Set(spotID, expiration uint64, user *auth.User) (*Key, error) - Get(spotID uint64, user *auth.User) (*Key, error) - IsValid(key string) (*auth.User, error) + Set(spotID, expiration uint64, user *user.User) (*Key, error) + Get(spotID uint64, user *user.User) (*Key, error) + IsValid(key string) (*user.User, error) } type keysImpl struct { @@ -32,7 +32,7 @@ type keysImpl struct { conn pool.Pool } -func (k *keysImpl) Set(spotID, expiration uint64, user *auth.User) (*Key, error) { +func (k *keysImpl) Set(spotID, expiration uint64, user *user.User) (*Key, error) { switch { case spotID == 0: return nil, fmt.Errorf("spotID is required") @@ -89,7 +89,7 @@ func (k *keysImpl) Set(spotID, expiration uint64, user *auth.User) (*Key, error) return key, nil } -func (k *keysImpl) Get(spotID uint64, user *auth.User) (*Key, error) { +func (k *keysImpl) Get(spotID uint64, user *user.User) (*Key, error) { switch { case spotID == 0: return nil, fmt.Errorf("spotID is required") @@ -114,7 +114,7 @@ func (k *keysImpl) Get(spotID uint64, user *auth.User) (*Key, error) { return key, nil } -func (k *keysImpl) IsValid(key string) (*auth.User, error) { +func (k *keysImpl) IsValid(key string) (*user.User, error) { if key == "" { return nil, fmt.Errorf("key is required") } @@ -133,7 +133,7 @@ func (k *keysImpl) IsValid(key string) (*auth.User, error) { return nil, fmt.Errorf("key is expired") } // Get user info by userID - user := &auth.User{ID: userID, AuthMethod: "public-key"} + user := &user.User{ID: userID, AuthMethod: "public-key"} // We don't need tenantID here if err := k.conn.QueryRow(getUserSQL, userID).Scan(&user.TenantID, &user.Name, &user.Email); err != nil { k.log.Error(context.Background(), "failed to get user: %v", err) diff --git a/backend/pkg/spot/service/user.go b/backend/pkg/server/keys/user.go similarity index 87% rename from backend/pkg/spot/service/user.go rename to backend/pkg/server/keys/user.go index 1f2b16c33..6db03c3a1 100644 --- a/backend/pkg/spot/service/user.go +++ b/backend/pkg/server/keys/user.go @@ -1,3 +1,3 @@ -package service +package keys var getUserSQL = `SELECT 1, name, email FROM public.users WHERE user_id = $1 AND deleted_at IS NULL LIMIT 1` diff --git a/backend/pkg/spot/api/limiter.go b/backend/pkg/server/limiter/limiter.go similarity index 99% rename from backend/pkg/spot/api/limiter.go rename to backend/pkg/server/limiter/limiter.go index 004f8be86..b72a4d7b3 100644 --- a/backend/pkg/spot/api/limiter.go +++ b/backend/pkg/server/limiter/limiter.go @@ -1,4 +1,4 @@ -package api +package limiter import ( "sync" diff --git a/backend/pkg/server/limiter/middleware.go b/backend/pkg/server/limiter/middleware.go new file mode 100644 index 000000000..4d50161ce --- /dev/null +++ b/backend/pkg/server/limiter/middleware.go @@ -0,0 +1,24 @@ +package limiter + +import ( + "net/http" + "openreplay/backend/pkg/server/user" +) + +func (rl *UserRateLimiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userContext := r.Context().Value("userData") + if userContext == nil { + next.ServeHTTP(w, r) + return + } + authUser := userContext.(*user.User) + rl := rl.GetRateLimiter(authUser.ID) + + if !rl.Allow() { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/backend/pkg/server/server.go b/backend/pkg/server/server.go new file mode 100644 index 000000000..eb7208742 --- /dev/null +++ b/backend/pkg/server/server.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "golang.org/x/net/http2" + + "openreplay/backend/internal/config/common" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" +) + +type Server struct { + server *http.Server +} + +func New(handler http.Handler, host, port string, timeout time.Duration) (*Server, error) { + switch { + case port == "": + return nil, errors.New("empty server port") + case handler == nil: + return nil, errors.New("empty handler") + case timeout < 1: + return nil, fmt.Errorf("invalid timeout %d", timeout) + } + server := &http.Server{ + Addr: fmt.Sprintf("%s:%s", host, port), + Handler: handler, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + if err := http2.ConfigureServer(server, nil); err != nil { + return nil, fmt.Errorf("error configuring server: %s", err) + } + return &Server{ + server: server, + }, nil +} + +func (s *Server) Start() error { + return s.server.ListenAndServe() +} + +func (s *Server) Stop() { + if err := s.server.Shutdown(context.Background()); err != nil { + fmt.Printf("error shutting down server: %s\n", err) + } +} + +func Run(ctx context.Context, log logger.Logger, cfg *common.HTTP, router api.Router) { + webServer, err := New(router.Get(), cfg.HTTPHost, cfg.HTTPPort, cfg.HTTPTimeout) + if err != nil { + log.Fatal(ctx, "failed while creating server: %s", err) + } + go func() { + if err := webServer.Start(); err != nil { + log.Fatal(ctx, "http server error: %s", err) + } + }() + log.Info(ctx, "server successfully started on port %s", cfg.HTTPPort) + + // Wait stop signal to shut down server gracefully + sigchan := make(chan os.Signal, 1) + signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM) + <-sigchan + log.Info(ctx, "shutting down the server") + webServer.Stop() +} diff --git a/backend/pkg/server/tracer/middleware.go b/backend/pkg/server/tracer/middleware.go new file mode 100644 index 000000000..45b3bc2ea --- /dev/null +++ b/backend/pkg/server/tracer/middleware.go @@ -0,0 +1,11 @@ +package tracer + +import ( + "net/http" +) + +func (t *tracerImpl) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) +} diff --git a/backend/pkg/server/tracer/tracer.go b/backend/pkg/server/tracer/tracer.go new file mode 100644 index 000000000..a3c980d24 --- /dev/null +++ b/backend/pkg/server/tracer/tracer.go @@ -0,0 +1,23 @@ +package tracer + +import ( + "net/http" + + db "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/logger" +) + +type Tracer interface { + Middleware(next http.Handler) http.Handler + Close() error +} + +type tracerImpl struct{} + +func NewTracer(log logger.Logger, conn db.Pool) (Tracer, error) { + return &tracerImpl{}, nil +} + +func (t *tracerImpl) Close() error { + return nil +} diff --git a/backend/pkg/spot/auth/model.go b/backend/pkg/server/user/model.go similarity index 89% rename from backend/pkg/spot/auth/model.go rename to backend/pkg/server/user/model.go index ef2f09d75..e429e2d7d 100644 --- a/backend/pkg/spot/auth/model.go +++ b/backend/pkg/server/user/model.go @@ -1,4 +1,4 @@ -package auth +package user import "github.com/golang-jwt/jwt/v5" @@ -25,10 +25,3 @@ func (u *User) HasPermission(perm string) bool { _, ok := u.Permissions[perm] return ok } - -func abs(x int) int { - if x < 0 { - return -x - } - return x -} diff --git a/backend/pkg/sessions/api/beacon-cache.go b/backend/pkg/sessions/api/beacon-cache.go new file mode 100644 index 000000000..7219cdd6c --- /dev/null +++ b/backend/pkg/sessions/api/beacon-cache.go @@ -0,0 +1,63 @@ +package api + +import ( + "sync" + "time" +) + +type BeaconSize struct { + size int64 + time time.Time +} + +type BeaconCache struct { + mutex *sync.RWMutex + beaconSizeCache map[uint64]*BeaconSize + defaultLimit int64 +} + +func NewBeaconCache(limit int64) *BeaconCache { + cache := &BeaconCache{ + mutex: &sync.RWMutex{}, + beaconSizeCache: make(map[uint64]*BeaconSize), + defaultLimit: limit, + } + go cache.cleaner() + return cache +} + +func (e *BeaconCache) Add(sessionID uint64, size int64) { + if size <= 0 { + return + } + e.mutex.Lock() + defer e.mutex.Unlock() + e.beaconSizeCache[sessionID] = &BeaconSize{ + size: size, + time: time.Now(), + } +} + +func (e *BeaconCache) Get(sessionID uint64) int64 { + e.mutex.RLock() + defer e.mutex.RUnlock() + if beaconSize, ok := e.beaconSizeCache[sessionID]; ok { + beaconSize.time = time.Now() + return beaconSize.size + } + return e.defaultLimit +} + +func (e *BeaconCache) cleaner() { + for { + time.Sleep(time.Minute * 2) + now := time.Now() + e.mutex.Lock() + for sid, bs := range e.beaconSizeCache { + if now.Sub(bs.time) > time.Minute*3 { + delete(e.beaconSizeCache, sid) + } + } + e.mutex.Unlock() + } +} diff --git a/backend/pkg/sessions/api/mobile/handlers.go b/backend/pkg/sessions/api/mobile/handlers.go new file mode 100644 index 000000000..6fd46801b --- /dev/null +++ b/backend/pkg/sessions/api/mobile/handlers.go @@ -0,0 +1,379 @@ +package mobile + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Masterminds/semver" + gzip "github.com/klauspost/pgzip" + + httpCfg "openreplay/backend/internal/config/http" + "openreplay/backend/internal/http/geoip" + "openreplay/backend/internal/http/ios" + "openreplay/backend/internal/http/uaparser" + "openreplay/backend/internal/http/uuid" + "openreplay/backend/pkg/conditions" + "openreplay/backend/pkg/db/postgres" + "openreplay/backend/pkg/flakeid" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/messages" + "openreplay/backend/pkg/projects" + "openreplay/backend/pkg/queue/types" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessions" + "openreplay/backend/pkg/token" +) + +func checkMobileTrackerVersion(ver string) bool { + c, err := semver.NewConstraint(">=1.0.9") + if err != nil { + return false + } + // Check for beta version + parts := strings.Split(ver, "-") + if len(parts) > 1 { + ver = parts[0] + } + v, err := semver.NewVersion(ver) + if err != nil { + return false + } + return c.Check(v) +} + +type handlersImpl struct { + log logger.Logger + cfg *httpCfg.Config + responser *api.Responser + producer types.Producer + projects projects.Projects + sessions sessions.Sessions + uaParser *uaparser.UAParser + geoIP geoip.GeoParser + tokenizer *token.Tokenizer + conditions conditions.Conditions + flaker *flakeid.Flaker + features map[string]bool +} + +func NewHandlers(cfg *httpCfg.Config, log logger.Logger, responser *api.Responser, producer types.Producer, projects projects.Projects, + sessions sessions.Sessions, uaParser *uaparser.UAParser, geoIP geoip.GeoParser, tokenizer *token.Tokenizer, + conditions conditions.Conditions, flaker *flakeid.Flaker) (api.Handlers, error) { + return &handlersImpl{ + log: log, + cfg: cfg, + responser: responser, + producer: producer, + projects: projects, + sessions: sessions, + uaParser: uaParser, + geoIP: geoIP, + tokenizer: tokenizer, + conditions: conditions, + flaker: flaker, + features: map[string]bool{ + "feature-flags": cfg.IsFeatureFlagEnabled, + "usability-test": cfg.IsUsabilityTestEnabled, + }, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/mobile/start", e.startMobileSessionHandler, "POST"}, + {"/v1/mobile/i", e.pushMobileMessagesHandler, "POST"}, + {"/v1/mobile/late", e.pushMobileLateMessagesHandler, "POST"}, + {"/v1/mobile/images", e.mobileImagesUploadHandler, "POST"}, + } +} + +func (e *handlersImpl) startMobileSessionHandler(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) + return + } + body := http.MaxBytesReader(w, r.Body, e.cfg.JsonSizeLimit) + defer body.Close() + + req := &StartMobileSessionRequest{} + if err := json.NewDecoder(body).Decode(req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, 0) + return + } + + // Add tracker version to context + r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) + + if req.ProjectKey == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("projectKey value required"), startTime, r.URL.Path, 0) + return + } + + p, err := e.projects.GetProjectByKey(*req.ProjectKey) + if err != nil { + if postgres.IsNoRowsErr(err) { + logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, 0) + } else { + e.log.Error(r.Context(), "failed to get project by key: %s, err: %s", *req.ProjectKey, err) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, 0) + } + return + } + + // Add projectID to context + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) + + // Check if the project supports mobile sessions + if !p.IsMobile() { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("project doesn't support mobile sessions"), startTime, r.URL.Path, 0) + return + } + + if !checkMobileTrackerVersion(req.TrackerVersion) { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUpgradeRequired, errors.New("tracker version not supported"), startTime, r.URL.Path, 0) + return + } + + userUUID := uuid.GetUUID(req.UserUUID) + tokenData, err := e.tokenizer.Parse(req.Token) + + if err != nil { // Starting the new one + dice := byte(rand.Intn(100)) // [0, 100) + // Use condition rate if it's set + if req.Condition != "" { + rate, err := e.conditions.GetRate(p.ProjectID, req.Condition, int(p.SampleRate)) + if err != nil { + e.log.Warn(r.Context(), "can't get condition rate, condition: %s, err: %s", req.Condition, err) + } else { + p.SampleRate = byte(rate) + } + } + if dice >= p.SampleRate { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, fmt.Errorf("capture rate miss, rate: %d", p.SampleRate), startTime, r.URL.Path, 0) + return + } + + ua := e.uaParser.ParseFromHTTPRequest(r) + if ua == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, 0) + return + } + sessionID, err := e.flaker.Compose(uint64(startTime.UnixMilli())) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) + return + } + + expTime := startTime.Add(time.Duration(p.MaxSessionDuration) * time.Millisecond) + tokenData = &token.TokenData{sessionID, 0, expTime.UnixMilli()} + + // Add sessionID to context + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionID))) + + geoInfo := e.geoIP.ExtractGeoData(r) + deviceType, platform, os := ios.GetIOSDeviceType(req.UserDevice), "ios", "IOS" + if req.Platform != "" && req.Platform != "ios" { + deviceType = req.UserDeviceType + platform = req.Platform + os = "Android" + } + + if !req.DoNotRecord { + if err := e.sessions.Add(&sessions.Session{ + SessionID: sessionID, + Platform: platform, + Timestamp: req.Timestamp, + Timezone: req.Timezone, + ProjectID: p.ProjectID, + TrackerVersion: req.TrackerVersion, + RevID: req.RevID, + UserUUID: userUUID, + UserOS: os, + UserOSVersion: req.UserOSVersion, + UserDevice: ios.MapIOSDevice(req.UserDevice), + UserDeviceType: deviceType, + UserCountry: geoInfo.Country, + UserState: geoInfo.State, + UserCity: geoInfo.City, + UserDeviceMemorySize: req.DeviceMemory, + UserDeviceHeapSize: req.DeviceMemory, + ScreenWidth: req.Width, + ScreenHeight: req.Height, + }); err != nil { + e.log.Warn(r.Context(), "failed to add mobile session to DB: %s", err) + } + + sessStart := &messages.MobileSessionStart{ + Timestamp: req.Timestamp, + ProjectID: uint64(p.ProjectID), + TrackerVersion: req.TrackerVersion, + RevID: req.RevID, + UserUUID: userUUID, + UserOS: os, + UserOSVersion: req.UserOSVersion, + UserDevice: ios.MapIOSDevice(req.UserDevice), + UserDeviceType: deviceType, + UserCountry: geoInfo.Pack(), + } + + if err := e.producer.Produce(e.cfg.TopicRawMobile, tokenData.ID, sessStart.Encode()); err != nil { + e.log.Error(r.Context(), "failed to send mobile sessionStart event to queue: %s", err) + } + } + } + + e.responser.ResponseWithJSON(e.log, r.Context(), w, &StartMobileSessionResponse{ + Token: e.tokenizer.Compose(*tokenData), + UserUUID: userUUID, + SessionID: strconv.FormatUint(tokenData.ID, 10), + BeaconSizeLimit: e.cfg.BeaconSizeLimit, + ImageQuality: e.cfg.MobileQuality, + FrameRate: e.cfg.MobileFps, + ProjectID: strconv.FormatUint(uint64(p.ProjectID), 10), + Features: e.features, + }, startTime, r.URL.Path, 0) +} + +func (e *handlersImpl) pushMobileMessagesHandler(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + e.pushMessages(w, r, sessionData.ID, e.cfg.TopicRawMobile) +} + +func (e *handlersImpl) pushMobileLateMessagesHandler(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + + if err != nil && err != token.EXPIRED { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) + return + } + // Check timestamps here? + e.pushMessages(w, r, sessionData.ID, e.cfg.TopicRawMobile) +} + +func (e *handlersImpl) mobileImagesUploadHandler(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) + return + } + r.Body = http.MaxBytesReader(w, r.Body, e.cfg.FileSizeLimit) + defer r.Body.Close() + + err = r.ParseMultipartForm(5 * 1e6) // ~5Mb + if err == http.ErrNotMultipart || err == http.ErrMissingBoundary { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnsupportedMediaType, err, startTime, r.URL.Path, 0) + return + } else if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) // TODO: send error here only on staging + return + } + + if r.MultipartForm == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, errors.New("multipart not parsed"), startTime, r.URL.Path, 0) + return + } + + if len(r.MultipartForm.Value["projectKey"]) == 0 { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("projectKey parameter missing"), startTime, r.URL.Path, 0) // status for missing/wrong parameter? + return + } + + for _, fileHeaderList := range r.MultipartForm.File { + for _, fileHeader := range fileHeaderList { + file, err := fileHeader.Open() + if err != nil { + continue + } + + data, err := io.ReadAll(file) + if err != nil { + file.Close() + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) + return + } + file.Close() + + if err := e.producer.Produce(e.cfg.TopicRawImages, sessionData.ID, data); err != nil { + e.log.Warn(r.Context(), "failed to send image to queue: %s", err) + } + } + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, 0) +} + +func (e *handlersImpl) pushMessages(w http.ResponseWriter, r *http.Request, sessionID uint64, topicName string) { + start := time.Now() + body := http.MaxBytesReader(w, r.Body, e.cfg.BeaconSizeLimit) + defer body.Close() + + var reader io.ReadCloser + var err error + + switch r.Header.Get("Content-Encoding") { + case "gzip": + reader, err = gzip.NewReader(body) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) + return + } + defer reader.Close() + default: + reader = body + } + buf, err := io.ReadAll(reader) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) + return + } + if err := e.producer.Produce(topicName, sessionID, buf); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, start, r.URL.Path, 0) + return + } + w.WriteHeader(http.StatusOK) + e.log.Info(r.Context(), "response ok") +} diff --git a/backend/internal/http/router/model.go b/backend/pkg/sessions/api/mobile/model.go similarity index 87% rename from backend/internal/http/router/model.go rename to backend/pkg/sessions/api/mobile/model.go index 6649b26cc..b42b9d25a 100644 --- a/backend/internal/http/router/model.go +++ b/backend/pkg/sessions/api/mobile/model.go @@ -1,10 +1,4 @@ -package router - -type NotStartedRequest struct { - ProjectKey *string `json:"projectKey"` - TrackerVersion string `json:"trackerVersion"` - DoNotTrack bool `json:"DoNotTrack"` -} +package mobile type StartMobileSessionRequest struct { Token string `json:"token"` diff --git a/backend/pkg/sessions/api/web/handlers.go b/backend/pkg/sessions/api/web/handlers.go new file mode 100644 index 000000000..6d50854c0 --- /dev/null +++ b/backend/pkg/sessions/api/web/handlers.go @@ -0,0 +1,512 @@ +package web + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Masterminds/semver" + + httpCfg "openreplay/backend/internal/config/http" + "openreplay/backend/internal/http/geoip" + "openreplay/backend/internal/http/uaparser" + "openreplay/backend/internal/http/util" + "openreplay/backend/internal/http/uuid" + "openreplay/backend/pkg/conditions" + "openreplay/backend/pkg/db/postgres" + "openreplay/backend/pkg/flakeid" + "openreplay/backend/pkg/logger" + . "openreplay/backend/pkg/messages" + "openreplay/backend/pkg/projects" + "openreplay/backend/pkg/queue/types" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessions" + beacons "openreplay/backend/pkg/sessions/api" + "openreplay/backend/pkg/token" +) + +type handlersImpl struct { + log logger.Logger + cfg *httpCfg.Config + responser *api.Responser + producer types.Producer + projects projects.Projects + sessions sessions.Sessions + uaParser *uaparser.UAParser + geoIP geoip.GeoParser + tokenizer *token.Tokenizer + conditions conditions.Conditions + flaker *flakeid.Flaker + beaconSizeCache *beacons.BeaconCache + features map[string]bool +} + +func NewHandlers(cfg *httpCfg.Config, log logger.Logger, responser *api.Responser, producer types.Producer, projects projects.Projects, + sessions sessions.Sessions, uaParser *uaparser.UAParser, geoIP geoip.GeoParser, tokenizer *token.Tokenizer, + conditions conditions.Conditions, flaker *flakeid.Flaker) (api.Handlers, error) { + return &handlersImpl{ + log: log, + cfg: cfg, + responser: responser, + producer: producer, + projects: projects, + sessions: sessions, + uaParser: uaParser, + geoIP: geoIP, + tokenizer: tokenizer, + conditions: conditions, + flaker: flaker, + beaconSizeCache: beacons.NewBeaconCache(cfg.BeaconSizeLimit), + features: map[string]bool{ + "feature-flags": cfg.IsFeatureFlagEnabled, + "usability-test": cfg.IsUsabilityTestEnabled, + }, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/web/not-started", e.notStartedHandlerWeb, "POST"}, + {"/v1/web/start", e.startSessionHandlerWeb, "POST"}, + {"/v1/web/i", e.pushMessagesHandlerWeb, "POST"}, + {"/v1/web/images", e.imagesUploaderHandlerWeb, "POST"}, + } +} + +func getSessionTimestamp(req *StartSessionRequest, startTimeMili int64) (ts uint64) { + ts = uint64(req.Timestamp) + if req.IsOffline { + return + } + c, err := semver.NewConstraint(">=4.1.6") + if err != nil { + return + } + ver := req.TrackerVersion + parts := strings.Split(ver, "-") + if len(parts) > 1 { + ver = parts[0] + } + v, err := semver.NewVersion(ver) + if err != nil { + return + } + if c.Check(v) { + ts = uint64(startTimeMili) + if req.BufferDiff > 0 && req.BufferDiff < 5*60*1000 { + ts -= req.BufferDiff + } + } + return +} + +func (e *handlersImpl) startSessionHandlerWeb(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check request body + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) + return + } + + bodyBytes, err := api.ReadCompressedBody(e.log, w, r, e.cfg.JsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + // Parse request body + req := &StartSessionRequest{} + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + // Add tracker version to context + r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) + + // Handler's logic + if req.ProjectKey == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("ProjectKey value required"), startTime, r.URL.Path, bodySize) + return + } + + p, err := e.projects.GetProjectByKey(*req.ProjectKey) + if err != nil { + if postgres.IsNoRowsErr(err) { + logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, bodySize) + } else { + e.log.Error(r.Context(), "failed to get project by key: %s, err: %s", *req.ProjectKey, err) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, bodySize) + } + return + } + + // Add projectID to context + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) + + // Check if the project supports mobile sessions + if !p.IsWeb() { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("project doesn't support web sessions"), startTime, r.URL.Path, bodySize) + return + } + + ua := e.uaParser.ParseFromHTTPRequest(r) + if ua == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, bodySize) + return + } + + geoInfo := e.geoIP.ExtractGeoData(r) + + userUUID := uuid.GetUUID(req.UserUUID) + tokenData, err := e.tokenizer.Parse(req.Token) + if err != nil || req.Reset { // Starting the new one + dice := byte(rand.Intn(100)) + // Use condition rate if it's set + if req.Condition != "" { + rate, err := e.conditions.GetRate(p.ProjectID, req.Condition, int(p.SampleRate)) + if err != nil { + e.log.Warn(r.Context(), "can't get condition rate, condition: %s, err: %s", req.Condition, err) + } else { + p.SampleRate = byte(rate) + } + } + if dice >= p.SampleRate { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, fmt.Errorf("capture rate miss, rate: %d", p.SampleRate), startTime, r.URL.Path, bodySize) + return + } + + startTimeMili := startTime.UnixMilli() + sessionID, err := e.flaker.Compose(uint64(startTimeMili)) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + + expTime := startTime.Add(time.Duration(p.MaxSessionDuration) * time.Millisecond) + tokenData = &token.TokenData{ + ID: sessionID, + Delay: startTimeMili - req.Timestamp, + ExpTime: expTime.UnixMilli(), + } + + // Add sessionID to context + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionID))) + + if recordSession(req) { + sessionStart := &SessionStart{ + Timestamp: getSessionTimestamp(req, startTimeMili), + ProjectID: uint64(p.ProjectID), + TrackerVersion: req.TrackerVersion, + RevID: req.RevID, + UserUUID: userUUID, + UserAgent: r.Header.Get("User-Agent"), + UserOS: ua.OS, + UserOSVersion: ua.OSVersion, + UserBrowser: ua.Browser, + UserBrowserVersion: ua.BrowserVersion, + UserDevice: ua.Device, + UserDeviceType: ua.DeviceType, + UserCountry: geoInfo.Pack(), + UserDeviceMemorySize: req.DeviceMemory, + UserDeviceHeapSize: req.JsHeapSizeLimit, + UserID: req.UserID, + } + + // Save sessionStart to db + if err := e.sessions.Add(&sessions.Session{ + SessionID: sessionID, + Platform: "web", + Timestamp: sessionStart.Timestamp, + Timezone: req.Timezone, + ProjectID: uint32(sessionStart.ProjectID), + TrackerVersion: sessionStart.TrackerVersion, + RevID: sessionStart.RevID, + UserUUID: sessionStart.UserUUID, + UserOS: sessionStart.UserOS, + UserOSVersion: sessionStart.UserOSVersion, + UserDevice: sessionStart.UserDevice, + UserCountry: geoInfo.Country, + UserState: geoInfo.State, + UserCity: geoInfo.City, + UserAgent: sessionStart.UserAgent, + UserBrowser: sessionStart.UserBrowser, + UserBrowserVersion: sessionStart.UserBrowserVersion, + UserDeviceType: sessionStart.UserDeviceType, + UserDeviceMemorySize: sessionStart.UserDeviceMemorySize, + UserDeviceHeapSize: sessionStart.UserDeviceHeapSize, + UserID: &sessionStart.UserID, + ScreenWidth: req.Width, + ScreenHeight: req.Height, + }); err != nil { + e.log.Warn(r.Context(), "can't insert sessionStart to DB: %s", err) + } + + // Send sessionStart message to kafka + if err := e.producer.Produce(e.cfg.TopicRawWeb, tokenData.ID, sessionStart.Encode()); err != nil { + e.log.Error(r.Context(), "can't send sessionStart to queue: %s", err) + } + } + } else { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", tokenData.ID))) + } + + // Save information about session beacon size + e.beaconSizeCache.Add(tokenData.ID, p.BeaconSize) + + startResponse := &StartSessionResponse{ + Token: e.tokenizer.Compose(*tokenData), + UserUUID: userUUID, + UserOS: ua.OS, + UserDevice: ua.Device, + UserBrowser: ua.Browser, + UserCountry: geoInfo.Country, + UserState: geoInfo.State, + UserCity: geoInfo.City, + SessionID: strconv.FormatUint(tokenData.ID, 10), + ProjectID: strconv.FormatUint(uint64(p.ProjectID), 10), + BeaconSizeLimit: e.beaconSizeCache.Get(tokenData.ID), + CompressionThreshold: e.cfg.CompressionThreshold, + StartTimestamp: int64(flakeid.ExtractTimestamp(tokenData.ID)), + Delay: tokenData.Delay, + CanvasEnabled: e.cfg.RecordCanvas, + CanvasImageQuality: e.cfg.CanvasQuality, + CanvasFrameRate: e.cfg.CanvasFps, + Features: e.features, + } + modifyResponse(req, startResponse) + + e.responser.ResponseWithJSON(e.log, r.Context(), w, startResponse, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) pushMessagesHandlerWeb(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Get debug header with batch info + if batch := r.URL.Query().Get("batch"); batch != "" { + r = r.WithContext(context.WithValue(r.Context(), "batch", batch)) + } + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + tokenJustExpired := false + if err != nil { + if errors.Is(err, token.JUST_EXPIRED) { + tokenJustExpired = true + } else { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + // Check request body + if r.Body == nil { + errCode := http.StatusBadRequest + if tokenJustExpired { + errCode = http.StatusUnauthorized + } + e.responser.ResponseWithError(e.log, r.Context(), w, errCode, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) + return + } + + bodyBytes, err := api.ReadCompressedBody(e.log, w, r, e.beaconSizeCache.Get(sessionData.ID)) + if err != nil { + errCode := http.StatusRequestEntityTooLarge + if tokenJustExpired { + errCode = http.StatusUnauthorized + } + e.responser.ResponseWithError(e.log, r.Context(), w, errCode, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + // Send processed messages to queue as array of bytes + err = e.producer.Produce(e.cfg.TopicRawWeb, sessionData.ID, bodyBytes) + if err != nil { + e.log.Error(r.Context(), "can't send messages batch to queue: %s", err) + errCode := http.StatusInternalServerError + if tokenJustExpired { + errCode = http.StatusUnauthorized + } + e.responser.ResponseWithError(e.log, r.Context(), w, errCode, errors.New("can't save message, try again"), startTime, r.URL.Path, bodySize) + return + } + + if tokenJustExpired { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, errors.New("token expired"), startTime, r.URL.Path, bodySize) + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) notStartedHandlerWeb(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, bodySize) + return + } + bodyBytes, err := api.ReadCompressedBody(e.log, w, r, e.cfg.JsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + req := &NotStartedRequest{} + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + // Add tracker version to context + r = r.WithContext(context.WithValue(r.Context(), "tracker", req.TrackerVersion)) + + // Handler's logic + if req.ProjectKey == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("projectKey value required"), startTime, r.URL.Path, bodySize) + return + } + p, err := e.projects.GetProjectByKey(*req.ProjectKey) + if err != nil { + if postgres.IsNoRowsErr(err) { + logErr := fmt.Errorf("project doesn't exist or is not active, key: %s", *req.ProjectKey) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, logErr, startTime, r.URL.Path, bodySize) + } else { + e.log.Error(r.Context(), "can't find a project: %s", err) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, errors.New("can't find a project"), startTime, r.URL.Path, bodySize) + } + return + } + + // Add projectID to context + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", p.ProjectID))) + + ua := e.uaParser.ParseFromHTTPRequest(r) + if ua == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, fmt.Errorf("browser not recognized, user-agent: %s", r.Header.Get("User-Agent")), startTime, r.URL.Path, bodySize) + return + } + geoInfo := e.geoIP.ExtractGeoData(r) + err = e.sessions.AddUnStarted(&sessions.UnStartedSession{ + ProjectKey: *req.ProjectKey, + TrackerVersion: req.TrackerVersion, + DoNotTrack: req.DoNotTrack, + Platform: "web", + UserAgent: r.Header.Get("User-Agent"), + UserOS: ua.OS, + UserOSVersion: ua.OSVersion, + UserBrowser: ua.Browser, + UserBrowserVersion: ua.BrowserVersion, + UserDevice: ua.Device, + UserDeviceType: ua.DeviceType, + UserCountry: geoInfo.Country, + UserState: geoInfo.State, + UserCity: geoInfo.City, + }) + if err != nil { + e.log.Warn(r.Context(), "can't insert un-started session: %s", err) + } + // response ok anyway + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +type ScreenshotMessage struct { + Name string + Data []byte +} + +func (e *handlersImpl) imagesUploaderHandlerWeb(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { // Should accept expired token? + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, 0) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + if r.Body == nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, errors.New("request body is empty"), startTime, r.URL.Path, 0) + return + } + r.Body = http.MaxBytesReader(w, r.Body, e.cfg.FileSizeLimit) + defer r.Body.Close() + + // Parse the multipart form + err = r.ParseMultipartForm(10 << 20) // Max upload size 10 MB + if err == http.ErrNotMultipart || err == http.ErrMissingBoundary { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnsupportedMediaType, err, startTime, r.URL.Path, 0) + return + } else if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) + return + } + + // Iterate over uploaded files + for _, fileHeaderList := range r.MultipartForm.File { + for _, fileHeader := range fileHeaderList { + file, err := fileHeader.Open() + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) + return + } + + // Read the file content + fileBytes, err := io.ReadAll(file) + if err != nil { + file.Close() + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, 0) + return + } + file.Close() + + fileName := util.SafeString(fileHeader.Filename) + + // Create a message to send to Kafka + msg := ScreenshotMessage{ + Name: fileName, + Data: fileBytes, + } + data, err := json.Marshal(&msg) + if err != nil { + e.log.Warn(r.Context(), "can't marshal screenshot message, err: %s", err) + continue + } + + // Send the message to queue + if err := e.producer.Produce(e.cfg.TopicCanvasImages, sessionData.ID, data); err != nil { + e.log.Warn(r.Context(), "can't send screenshot message to queue, err: %s", err) + } + } + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, 0) +} diff --git a/backend/internal/http/router/web-start.go b/backend/pkg/sessions/api/web/model.go similarity index 92% rename from backend/internal/http/router/web-start.go rename to backend/pkg/sessions/api/web/model.go index 3b21936c7..493178a3f 100644 --- a/backend/internal/http/router/web-start.go +++ b/backend/pkg/sessions/api/web/model.go @@ -1,4 +1,10 @@ -package router +package web + +type NotStartedRequest struct { + ProjectKey *string `json:"projectKey"` + TrackerVersion string `json:"trackerVersion"` + DoNotTrack bool `json:"DoNotTrack"` +} type StartSessionRequest struct { Token string `json:"token"` diff --git a/backend/pkg/spot/api/handlers.go b/backend/pkg/spot/api/handlers.go index 42765e213..2b6f66c09 100644 --- a/backend/pkg/spot/api/handlers.go +++ b/backend/pkg/spot/api/handlers.go @@ -2,12 +2,10 @@ package api import ( "bytes" - "context" "encoding/base64" "encoding/json" "errors" "fmt" - "io" "net/http" "strconv" "strings" @@ -15,59 +13,106 @@ import ( "github.com/gorilla/mux" - metrics "openreplay/backend/pkg/metrics/spot" + spotConfig "openreplay/backend/internal/config/spot" + "openreplay/backend/pkg/logger" "openreplay/backend/pkg/objectstorage" - "openreplay/backend/pkg/spot/auth" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/server/keys" + "openreplay/backend/pkg/server/user" "openreplay/backend/pkg/spot/service" + "openreplay/backend/pkg/spot/transcoder" ) -func (e *Router) createSpot(w http.ResponseWriter, r *http.Request) { +type handlersImpl struct { + log logger.Logger + responser *api.Responser + jsonSizeLimit int64 + spots service.Spots + objStorage objectstorage.ObjectStorage + transcoder transcoder.Transcoder + keys keys.Keys +} + +func NewHandlers(log logger.Logger, cfg *spotConfig.Config, responser *api.Responser, spots service.Spots, objStore objectstorage.ObjectStorage, transcoder transcoder.Transcoder, keys keys.Keys) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + jsonSizeLimit: cfg.JsonSizeLimit, + spots: spots, + objStorage: objStore, + transcoder: transcoder, + keys: keys, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/spots", e.createSpot, "POST"}, + {"/v1/spots/{id}", e.getSpot, "GET"}, + {"/v1/spots/{id}", e.updateSpot, "PATCH"}, + {"/v1/spots", e.getSpots, "GET"}, + {"/v1/spots", e.deleteSpots, "DELETE"}, + {"/v1/spots/{id}/comment", e.addComment, "POST"}, + {"/v1/spots/{id}/uploaded", e.uploadedSpot, "POST"}, + {"/v1/spots/{id}/video", e.getSpotVideo, "GET"}, + {"/v1/spots/{id}/public-key", e.getPublicKey, "GET"}, + {"/v1/spots/{id}/public-key", e.updatePublicKey, "PATCH"}, + {"/v1/spots/{id}/status", e.spotStatus, "GET"}, + {"/v1/ping", e.ping, "GET"}, + } +} + +func (e *handlersImpl) ping(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func (e *handlersImpl) createSpot(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) return } bodySize = len(bodyBytes) req := &CreateSpotRequest{} if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } // Creat a spot - currUser := r.Context().Value("userData").(*auth.User) - newSpot, err := e.services.Spots.Add(currUser, req.Name, req.Comment, req.Duration, req.Crop) + currUser := r.Context().Value("userData").(*user.User) + newSpot, err := e.spots.Add(currUser, req.Name, req.Comment, req.Duration, req.Crop) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } // Parse and upload preview image previewImage, err := getSpotPreview(req.Preview) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } previewName := fmt.Sprintf("%d/preview.jpeg", newSpot.ID) - if err = e.services.ObjStorage.Upload(bytes.NewReader(previewImage), previewName, "image/jpeg", objectstorage.NoCompression); err != nil { + if err = e.objStorage.Upload(bytes.NewReader(previewImage), previewName, "image/jpeg", objectstorage.NoCompression); err != nil { e.log.Error(r.Context(), "can't upload preview image: %s", err) - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, errors.New("can't upload preview image"), startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, errors.New("can't upload preview image"), startTime, r.URL.Path, bodySize) return } mobURL, err := e.getUploadMobURL(newSpot.ID) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } videoURL, err := e.getUploadVideoURL(newSpot.ID) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } @@ -76,7 +121,7 @@ func (e *Router) createSpot(w http.ResponseWriter, r *http.Request) { MobURL: mobURL, VideoURL: videoURL, } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) } func getSpotPreview(preview string) ([]byte, error) { @@ -93,18 +138,18 @@ func getSpotPreview(preview string) ([]byte, error) { return data, nil } -func (e *Router) getUploadMobURL(spotID uint64) (string, error) { +func (e *handlersImpl) getUploadMobURL(spotID uint64) (string, error) { mobKey := fmt.Sprintf("%d/events.mob", spotID) - mobURL, err := e.services.ObjStorage.GetPreSignedUploadUrl(mobKey) + mobURL, err := e.objStorage.GetPreSignedUploadUrl(mobKey) if err != nil { return "", fmt.Errorf("can't get mob URL: %s", err) } return mobURL, nil } -func (e *Router) getUploadVideoURL(spotID uint64) (string, error) { +func (e *handlersImpl) getUploadVideoURL(spotID uint64) (string, error) { mobKey := fmt.Sprintf("%d/video.webm", spotID) - mobURL, err := e.services.ObjStorage.GetPreSignedUploadUrl(mobKey) + mobURL, err := e.objStorage.GetPreSignedUploadUrl(mobKey) if err != nil { return "", fmt.Errorf("can't get video URL: %s", err) } @@ -143,51 +188,51 @@ func getSpotsRequest(r *http.Request) (*GetSpotsRequest, error) { return req, nil } -func (e *Router) getPreviewURL(spotID uint64) (string, error) { +func (e *handlersImpl) getPreviewURL(spotID uint64) (string, error) { previewKey := fmt.Sprintf("%d/preview.jpeg", spotID) - previewURL, err := e.services.ObjStorage.GetPreSignedDownloadUrl(previewKey) + previewURL, err := e.objStorage.GetPreSignedDownloadUrl(previewKey) if err != nil { return "", fmt.Errorf("can't get preview URL: %s", err) } return previewURL, nil } -func (e *Router) getMobURL(spotID uint64) (string, error) { +func (e *handlersImpl) getMobURL(spotID uint64) (string, error) { mobKey := fmt.Sprintf("%d/events.mob", spotID) - mobURL, err := e.services.ObjStorage.GetPreSignedDownloadUrl(mobKey) + mobURL, err := e.objStorage.GetPreSignedDownloadUrl(mobKey) if err != nil { return "", fmt.Errorf("can't get mob URL: %s", err) } return mobURL, nil } -func (e *Router) getVideoURL(spotID uint64) (string, error) { +func (e *handlersImpl) getVideoURL(spotID uint64) (string, error) { mobKey := fmt.Sprintf("%d/video.webm", spotID) // TODO: later return url to m3u8 file - mobURL, err := e.services.ObjStorage.GetPreSignedDownloadUrl(mobKey) + mobURL, err := e.objStorage.GetPreSignedDownloadUrl(mobKey) if err != nil { return "", fmt.Errorf("can't get video URL: %s", err) } return mobURL, nil } -func (e *Router) getSpot(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) getSpot(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - res, err := e.services.Spots.GetByID(user, id) + user := r.Context().Value("userData").(*user.User) + res, err := e.spots.GetByID(user, id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } if res == nil { - e.ResponseWithError(r.Context(), w, http.StatusNotFound, fmt.Errorf("spot not found"), startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, fmt.Errorf("spot not found"), startTime, r.URL.Path, bodySize) return } @@ -197,12 +242,12 @@ func (e *Router) getSpot(w http.ResponseWriter, r *http.Request) { } mobURL, err := e.getMobURL(id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } videoURL, err := e.getVideoURL(id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } @@ -216,60 +261,60 @@ func (e *Router) getSpot(w http.ResponseWriter, r *http.Request) { MobURL: mobURL, VideoURL: videoURL, } - playlist, err := e.services.Transcoder.GetSpotStreamPlaylist(id) + playlist, err := e.transcoder.GetSpotStreamPlaylist(id) if err != nil { e.log.Warn(r.Context(), "can't get stream playlist: %s", err) } else { spotInfo.StreamFile = base64.StdEncoding.EncodeToString(playlist) } - e.ResponseWithJSON(r.Context(), w, &GetSpotResponse{Spot: spotInfo}, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, &GetSpotResponse{Spot: spotInfo}, startTime, r.URL.Path, bodySize) } -func (e *Router) updateSpot(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) updateSpot(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) return } bodySize = len(bodyBytes) req := &UpdateSpotRequest{} if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - _, err = e.services.Spots.UpdateName(user, id, req.Name) + user := r.Context().Value("userData").(*user.User) + _, err = e.spots.UpdateName(user, id, req.Name) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) } -func (e *Router) getSpots(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) getSpots(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 req, err := getSpotsRequest(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) + user := r.Context().Value("userData").(*user.User) opts := &service.GetOpts{ NameFilter: req.Query, Order: req.Order, Page: req.Page, Limit: req.Limit} switch req.FilterBy { @@ -278,9 +323,9 @@ func (e *Router) getSpots(w http.ResponseWriter, r *http.Request) { default: opts.TenantID = user.TenantID } - spots, total, tenantHasSpots, err := e.services.Spots.Get(user, opts) + spots, total, tenantHasSpots, err := e.spots.Get(user, opts) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } res := make([]ShortInfo, 0, len(spots)) @@ -298,82 +343,82 @@ func (e *Router) getSpots(w http.ResponseWriter, r *http.Request) { PreviewURL: previewUrl, }) } - e.ResponseWithJSON(r.Context(), w, &GetSpotsResponse{Spots: res, Total: total, TenantHasSpots: tenantHasSpots}, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, &GetSpotsResponse{Spots: res, Total: total, TenantHasSpots: tenantHasSpots}, startTime, r.URL.Path, bodySize) } -func (e *Router) deleteSpots(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) deleteSpots(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) return } bodySize = len(bodyBytes) req := &DeleteSpotRequest{} if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } spotsToDelete := make([]uint64, 0, len(req.SpotIDs)) for _, idStr := range req.SpotIDs { id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, fmt.Errorf("invalid spot id: %s", idStr), startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, fmt.Errorf("invalid spot id: %s", idStr), startTime, r.URL.Path, bodySize) return } spotsToDelete = append(spotsToDelete, id) } - user := r.Context().Value("userData").(*auth.User) - if err := e.services.Spots.Delete(user, spotsToDelete); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + user := r.Context().Value("userData").(*user.User) + if err := e.spots.Delete(user, spotsToDelete); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) } -func (e *Router) addComment(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) addComment(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) return } bodySize = len(bodyBytes) req := &AddCommentRequest{} if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - updatedSpot, err := e.services.Spots.AddComment(user, id, &service.Comment{UserName: req.UserName, Text: req.Comment}) + user := r.Context().Value("userData").(*user.User) + updatedSpot, err := e.spots.AddComment(user, id, &service.Comment{UserName: req.UserName, Text: req.Comment}) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } mobURL, err := e.getMobURL(id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } videoURL, err := e.getVideoURL(id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } @@ -385,70 +430,70 @@ func (e *Router) addComment(w http.ResponseWriter, r *http.Request) { MobURL: mobURL, VideoURL: videoURL, } - e.ResponseWithJSON(r.Context(), w, &GetSpotResponse{Spot: spotInfo}, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, &GetSpotResponse{Spot: spotInfo}, startTime, r.URL.Path, bodySize) } -func (e *Router) uploadedSpot(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) uploadedSpot(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - spot, err := e.services.Spots.GetByID(user, id) // check if spot exists + user := r.Context().Value("userData").(*user.User) + spot, err := e.spots.GetByID(user, id) // check if spot exists if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } e.log.Info(r.Context(), "uploaded spot %+v, from user: %+v", spot, user) - if err := e.services.Transcoder.Process(spot); err != nil { + if err := e.transcoder.Process(spot); err != nil { e.log.Error(r.Context(), "can't add transcoding task: %s", err) } - e.ResponseOK(r.Context(), w, startTime, r.URL.Path, bodySize) + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) } -func (e *Router) getSpotVideo(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) getSpotVideo(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } key := fmt.Sprintf("%d/video.webm", id) - videoURL, err := e.services.ObjStorage.GetPreSignedDownloadUrl(key) + videoURL, err := e.objStorage.GetPreSignedDownloadUrl(key) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } resp := map[string]interface{}{ "url": videoURL, } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) } -func (e *Router) getSpotStream(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) getSpotStream(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } // Example data to serve as the file content - streamPlaylist, err := e.services.Transcoder.GetSpotStreamPlaylist(id) + streamPlaylist, err := e.transcoder.GetSpotStreamPlaylist(id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } @@ -462,144 +507,90 @@ func (e *Router) getSpotStream(w http.ResponseWriter, r *http.Request) { // Write the content of the buffer to the response writer if _, err := buffer.WriteTo(w); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } } -func (e *Router) getPublicKey(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) getPublicKey(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - key, err := e.services.Keys.Get(id, user) + user := r.Context().Value("userData").(*user.User) + key, err := e.keys.Get(id, user) if err != nil { if strings.Contains(err.Error(), "not found") { - e.ResponseWithError(r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusNotFound, err, startTime, r.URL.Path, bodySize) } else { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) } return } resp := map[string]interface{}{ "key": key, } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) } -func (e *Router) updatePublicKey(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) updatePublicKey(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - bodyBytes, err := e.readBody(w, r, e.cfg.JsonSizeLimit) + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) return } bodySize = len(bodyBytes) req := &UpdateSpotPublicKeyRequest{} if err := json.Unmarshal(bodyBytes, req); err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - key, err := e.services.Keys.Set(id, req.Expiration, user) + user := r.Context().Value("userData").(*user.User) + key, err := e.keys.Set(id, req.Expiration, user) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } resp := map[string]interface{}{ "key": key, } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) } -func (e *Router) spotStatus(w http.ResponseWriter, r *http.Request) { +func (e *handlersImpl) spotStatus(w http.ResponseWriter, r *http.Request) { startTime := time.Now() bodySize := 0 id, err := getSpotID(r) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) return } - user := r.Context().Value("userData").(*auth.User) - status, err := e.services.Spots.GetStatus(user, id) + user := r.Context().Value("userData").(*user.User) + status, err := e.spots.GetStatus(user, id) if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) return } resp := map[string]interface{}{ "status": status, } - e.ResponseWithJSON(r.Context(), w, resp, startTime, r.URL.Path, bodySize) -} - -func recordMetrics(requestStart time.Time, url string, code, bodySize int) { - if bodySize > 0 { - metrics.RecordRequestSize(float64(bodySize), url, code) - } - metrics.IncreaseTotalRequests() - metrics.RecordRequestDuration(float64(time.Now().Sub(requestStart).Milliseconds()), url, code) -} - -func (e *Router) readBody(w http.ResponseWriter, r *http.Request, limit int64) ([]byte, error) { - body := http.MaxBytesReader(w, r.Body, limit) - bodyBytes, err := io.ReadAll(body) - - // Close body - if closeErr := body.Close(); closeErr != nil { - e.log.Warn(r.Context(), "error while closing request body: %s", closeErr) - } - if err != nil { - return nil, err - } - return bodyBytes, nil -} - -func (e *Router) ResponseOK(ctx context.Context, w http.ResponseWriter, requestStart time.Time, url string, bodySize int) { - w.WriteHeader(http.StatusOK) - e.log.Info(ctx, "response ok") - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -func (e *Router) ResponseWithJSON(ctx context.Context, w http.ResponseWriter, res interface{}, requestStart time.Time, url string, bodySize int) { - e.log.Info(ctx, "response ok") - body, err := json.Marshal(res) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.Header().Set("Content-Type", "application/json") - w.Write(body) - recordMetrics(requestStart, url, http.StatusOK, bodySize) -} - -type response struct { - Error string `json:"error"` -} - -func (e *Router) ResponseWithError(ctx context.Context, w http.ResponseWriter, code int, err error, requestStart time.Time, url string, bodySize int) { - e.log.Error(ctx, "response error, code: %d, error: %s", code, err) - body, err := json.Marshal(&response{err.Error()}) - if err != nil { - e.log.Error(ctx, "can't marshal response: %s", err) - } - w.WriteHeader(code) - w.Write(body) - recordMetrics(requestStart, url, code, bodySize) + e.responser.ResponseWithJSON(e.log, r.Context(), w, resp, startTime, r.URL.Path, bodySize) } diff --git a/backend/pkg/spot/api/router.go b/backend/pkg/spot/api/router.go deleted file mode 100644 index a6fda7b6e..000000000 --- a/backend/pkg/spot/api/router.go +++ /dev/null @@ -1,213 +0,0 @@ -package api - -import ( - "bytes" - "fmt" - "io" - "net/http" - "openreplay/backend/pkg/spot" - "openreplay/backend/pkg/spot/auth" - "sync" - "time" - - "github.com/docker/distribution/context" - "github.com/gorilla/mux" - - spotConfig "openreplay/backend/internal/config/spot" - "openreplay/backend/internal/http/util" - "openreplay/backend/pkg/logger" -) - -type Router struct { - log logger.Logger - cfg *spotConfig.Config - router *mux.Router - mutex *sync.RWMutex - services *spot.ServicesBuilder - limiter *UserRateLimiter -} - -func NewRouter(cfg *spotConfig.Config, log logger.Logger, services *spot.ServicesBuilder) (*Router, error) { - switch { - case cfg == nil: - return nil, fmt.Errorf("config is empty") - case services == nil: - return nil, fmt.Errorf("services is empty") - case log == nil: - return nil, fmt.Errorf("logger is empty") - } - e := &Router{ - log: log, - cfg: cfg, - mutex: &sync.RWMutex{}, - services: services, - limiter: NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), - } - e.init() - return e, nil -} - -func (e *Router) init() { - e.router = mux.NewRouter() - - // Root route - e.router.HandleFunc("/", e.ping) - - // Spot routes - e.router.HandleFunc("/v1/spots", e.createSpot).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}", e.getSpot).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}", e.updateSpot).Methods("PATCH", "OPTIONS") - e.router.HandleFunc("/v1/spots", e.getSpots).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots", e.deleteSpots).Methods("DELETE", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/comment", e.addComment).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/uploaded", e.uploadedSpot).Methods("POST", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/video", e.getSpotVideo).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/public-key", e.getPublicKey).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/public-key", e.updatePublicKey).Methods("PATCH", "OPTIONS") - e.router.HandleFunc("/v1/spots/{id}/status", e.spotStatus).Methods("GET", "OPTIONS") - e.router.HandleFunc("/v1/ping", e.ping).Methods("GET", "OPTIONS") - - // CORS middleware - e.router.Use(e.corsMiddleware) - e.router.Use(e.authMiddleware) - e.router.Use(e.rateLimitMiddleware) - e.router.Use(e.actionMiddleware) -} - -func (e *Router) ping(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func (e *Router) corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - if e.cfg.UseAccessControlHeaders { - // Prepare headers for preflight requests - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET,PATCH,DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type,Authorization,Content-Encoding") - } - if r.Method == http.MethodOptions { - w.Header().Set("Cache-Control", "max-age=86400") - w.WriteHeader(http.StatusOK) - return - } - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"httpMethod": r.Method, "url": util.SafeString(r.URL.Path)})) - - next.ServeHTTP(w, r) - }) -} - -func (e *Router) authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - isExtension := false - pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() - if err != nil { - e.log.Error(r.Context(), "failed to get path template: %s", err) - } else { - if pathTemplate == "/v1/ping" || - (pathTemplate == "/v1/spots" && r.Method == "POST") || - (pathTemplate == "/v1/spots/{id}/uploaded" && r.Method == "POST") { - isExtension = true - } - } - - // Check if the request is authorized - user, err := e.services.Auth.IsAuthorized(r.Header.Get("Authorization"), getPermissions(r.URL.Path), isExtension) - if err != nil { - e.log.Warn(r.Context(), "Unauthorized request: %s", err) - if !isSpotWithKeyRequest(r) { - w.WriteHeader(http.StatusUnauthorized) - return - } - - user, err = e.services.Keys.IsValid(r.URL.Query().Get("key")) - if err != nil { - e.log.Warn(r.Context(), "Wrong public key: %s", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - } - - r = r.WithContext(context.WithValues(r.Context(), map[string]interface{}{"userData": user})) - next.ServeHTTP(w, r) - }) -} - -func isSpotWithKeyRequest(r *http.Request) bool { - pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() - if err != nil { - return false - } - getSpotPrefix := "/v1/spots/{id}" // GET - addCommentPrefix := "/v1/spots/{id}/comment" // POST - getStatusPrefix := "/v1/spots/{id}/status" // GET - if (pathTemplate == getSpotPrefix && r.Method == "GET") || - (pathTemplate == addCommentPrefix && r.Method == "POST") || - (pathTemplate == getStatusPrefix && r.Method == "GET") { - return true - } - return false -} - -func (e *Router) rateLimitMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - user := r.Context().Value("userData").(*auth.User) - rl := e.limiter.GetRateLimiter(user.ID) - - if !rl.Allow() { - http.Error(w, "Too Many Requests", http.StatusTooManyRequests) - return - } - next.ServeHTTP(w, r) - }) -} - -type statusWriter struct { - http.ResponseWriter - statusCode int -} - -func (w *statusWriter) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *statusWriter) Write(b []byte) (int, error) { - if w.statusCode == 0 { - w.statusCode = http.StatusOK // Default status code is 200 - } - return w.ResponseWriter.Write(b) -} - -func (e *Router) actionMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - next.ServeHTTP(w, r) - } - // Read body and restore the io.ReadCloser to its original state - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - // Use custom response writer to get the status code - sw := &statusWriter{ResponseWriter: w} - // Serve the request - next.ServeHTTP(sw, r) - e.logRequest(r, bodyBytes, sw.statusCode) - }) -} - -func (e *Router) GetHandler() http.Handler { - return e.router -} diff --git a/backend/pkg/spot/api/tracer.go b/backend/pkg/spot/api/tracer.go deleted file mode 100644 index 14e006fe0..000000000 --- a/backend/pkg/spot/api/tracer.go +++ /dev/null @@ -1,7 +0,0 @@ -package api - -import ( - "net/http" -) - -func (e *Router) logRequest(r *http.Request, bodyBytes []byte, statusCode int) {} diff --git a/backend/pkg/spot/builder.go b/backend/pkg/spot/builder.go index 047318844..14ae61365 100644 --- a/backend/pkg/spot/builder.go +++ b/backend/pkg/spot/builder.go @@ -1,39 +1,53 @@ package spot import ( + "openreplay/backend/pkg/metrics/web" + "openreplay/backend/pkg/server/tracer" + "time" + "openreplay/backend/internal/config/spot" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/flakeid" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/objectstorage" "openreplay/backend/pkg/objectstorage/store" - "openreplay/backend/pkg/spot/auth" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/server/auth" + "openreplay/backend/pkg/server/keys" + "openreplay/backend/pkg/server/limiter" + spotAPI "openreplay/backend/pkg/spot/api" "openreplay/backend/pkg/spot/service" "openreplay/backend/pkg/spot/transcoder" ) type ServicesBuilder struct { - Flaker *flakeid.Flaker - ObjStorage objectstorage.ObjectStorage - Auth auth.Auth - Spots service.Spots - Keys service.Keys - Transcoder transcoder.Transcoder + Auth auth.Auth + RateLimiter *limiter.UserRateLimiter + AuditTrail tracer.Tracer + SpotsAPI api.Handlers } -func NewServiceBuilder(log logger.Logger, cfg *spot.Config, pgconn pool.Pool) (*ServicesBuilder, error) { +func NewServiceBuilder(log logger.Logger, cfg *spot.Config, webMetrics web.Web, pgconn pool.Pool) (*ServicesBuilder, error) { objStore, err := store.NewStore(&cfg.ObjectsConfig) if err != nil { return nil, err } flaker := flakeid.NewFlaker(cfg.WorkerID) spots := service.NewSpots(log, pgconn, flaker) + transcoder := transcoder.NewTranscoder(cfg, log, objStore, pgconn, spots) + keys := keys.NewKeys(log, pgconn) + auditrail, err := tracer.NewTracer(log, pgconn) + if err != nil { + return nil, err + } + responser := api.NewResponser(webMetrics) + handlers, err := spotAPI.NewHandlers(log, cfg, responser, spots, objStore, transcoder, keys) + if err != nil { + return nil, err + } return &ServicesBuilder{ - Flaker: flaker, - ObjStorage: objStore, - Auth: auth.NewAuth(log, cfg.JWTSecret, cfg.JWTSpotSecret, pgconn), - Spots: spots, - Keys: service.NewKeys(log, pgconn), - Transcoder: transcoder.NewTranscoder(cfg, log, objStore, pgconn, spots), + Auth: auth.NewAuth(log, cfg.JWTSecret, cfg.JWTSpotSecret, pgconn, keys), + RateLimiter: limiter.NewUserRateLimiter(10, 30, 1*time.Minute, 5*time.Minute), + AuditTrail: auditrail, + SpotsAPI: handlers, }, nil } diff --git a/backend/pkg/spot/service/spot.go b/backend/pkg/spot/service/spot.go index a2ef2bca9..79545c7bb 100644 --- a/backend/pkg/spot/service/spot.go +++ b/backend/pkg/spot/service/spot.go @@ -4,12 +4,12 @@ import ( "context" "encoding/json" "fmt" + "openreplay/backend/pkg/server/user" "time" "openreplay/backend/pkg/db/postgres/pool" "openreplay/backend/pkg/flakeid" "openreplay/backend/pkg/logger" - "openreplay/backend/pkg/spot/auth" ) const MaxNameLength = 64 @@ -58,14 +58,14 @@ type Update struct { } type Spots interface { - Add(user *auth.User, name, comment string, duration int, crop []int) (*Spot, error) - GetByID(user *auth.User, spotID uint64) (*Spot, error) - Get(user *auth.User, opts *GetOpts) ([]*Spot, uint64, bool, error) - UpdateName(user *auth.User, spotID uint64, newName string) (*Spot, error) - AddComment(user *auth.User, spotID uint64, comment *Comment) (*Spot, error) - Delete(user *auth.User, spotIds []uint64) error + Add(user *user.User, name, comment string, duration int, crop []int) (*Spot, error) + GetByID(user *user.User, spotID uint64) (*Spot, error) + Get(user *user.User, opts *GetOpts) ([]*Spot, uint64, bool, error) + UpdateName(user *user.User, spotID uint64, newName string) (*Spot, error) + AddComment(user *user.User, spotID uint64, comment *Comment) (*Spot, error) + Delete(user *user.User, spotIds []uint64) error SetStatus(spotID uint64, status string) error - GetStatus(user *auth.User, spotID uint64) (string, error) + GetStatus(user *user.User, spotID uint64) (string, error) } func NewSpots(log logger.Logger, pgconn pool.Pool, flaker *flakeid.Flaker) Spots { @@ -76,7 +76,7 @@ func NewSpots(log logger.Logger, pgconn pool.Pool, flaker *flakeid.Flaker) Spots } } -func (s *spotsImpl) Add(user *auth.User, name, comment string, duration int, crop []int) (*Spot, error) { +func (s *spotsImpl) Add(user *user.User, name, comment string, duration int, crop []int) (*Spot, error) { switch { case user == nil: return nil, fmt.Errorf("user is required") @@ -142,7 +142,7 @@ func (s *spotsImpl) add(spot *Spot) error { return nil } -func (s *spotsImpl) GetByID(user *auth.User, spotID uint64) (*Spot, error) { +func (s *spotsImpl) GetByID(user *user.User, spotID uint64) (*Spot, error) { switch { case user == nil: return nil, fmt.Errorf("user is required") @@ -152,7 +152,7 @@ func (s *spotsImpl) GetByID(user *auth.User, spotID uint64) (*Spot, error) { return s.getByID(spotID, user) } -func (s *spotsImpl) getByID(spotID uint64, user *auth.User) (*Spot, error) { +func (s *spotsImpl) getByID(spotID uint64, user *user.User) (*Spot, error) { sql := `SELECT s.name, u.email, s.duration, s.crop, s.comments, s.created_at FROM spots.spots s JOIN public.users u ON s.user_id = u.user_id @@ -176,7 +176,7 @@ func (s *spotsImpl) getByID(spotID uint64, user *auth.User) (*Spot, error) { return spot, nil } -func (s *spotsImpl) Get(user *auth.User, opts *GetOpts) ([]*Spot, uint64, bool, error) { +func (s *spotsImpl) Get(user *user.User, opts *GetOpts) ([]*Spot, uint64, bool, error) { switch { case user == nil: return nil, 0, false, fmt.Errorf("user is required") @@ -200,7 +200,7 @@ func (s *spotsImpl) Get(user *auth.User, opts *GetOpts) ([]*Spot, uint64, bool, return s.getAll(user, opts) } -func (s *spotsImpl) getAll(user *auth.User, opts *GetOpts) ([]*Spot, uint64, bool, error) { +func (s *spotsImpl) getAll(user *user.User, opts *GetOpts) ([]*Spot, uint64, bool, error) { sql := `SELECT COUNT(1) OVER () AS total, s.spot_id, s.name, u.email, s.duration, s.created_at FROM spots.spots s JOIN public.users u ON s.user_id = u.user_id @@ -261,7 +261,7 @@ func (s *spotsImpl) doesTenantHasSpots(tenantID uint64) bool { return count > 0 } -func (s *spotsImpl) UpdateName(user *auth.User, spotID uint64, newName string) (*Spot, error) { +func (s *spotsImpl) UpdateName(user *user.User, spotID uint64, newName string) (*Spot, error) { switch { case user == nil: return nil, fmt.Errorf("user is required") @@ -276,7 +276,7 @@ func (s *spotsImpl) UpdateName(user *auth.User, spotID uint64, newName string) ( return s.updateName(spotID, newName, user) } -func (s *spotsImpl) updateName(spotID uint64, newName string, user *auth.User) (*Spot, error) { +func (s *spotsImpl) updateName(spotID uint64, newName string, user *user.User) (*Spot, error) { sql := `WITH updated AS ( UPDATE spots.spots SET name = $1, updated_at = $2 WHERE spot_id = $3 AND tenant_id = $4 AND deleted_at IS NULL RETURNING *) @@ -291,7 +291,7 @@ func (s *spotsImpl) updateName(spotID uint64, newName string, user *auth.User) ( return &Spot{ID: spotID, Name: newName}, nil } -func (s *spotsImpl) AddComment(user *auth.User, spotID uint64, comment *Comment) (*Spot, error) { +func (s *spotsImpl) AddComment(user *user.User, spotID uint64, comment *Comment) (*Spot, error) { switch { case user == nil: return nil, fmt.Errorf("user is required") @@ -311,7 +311,7 @@ func (s *spotsImpl) AddComment(user *auth.User, spotID uint64, comment *Comment) return s.addComment(spotID, comment, user) } -func (s *spotsImpl) addComment(spotID uint64, newComment *Comment, user *auth.User) (*Spot, error) { +func (s *spotsImpl) addComment(spotID uint64, newComment *Comment, user *user.User) (*Spot, error) { sql := `WITH updated AS ( UPDATE spots.spots SET comments = array_append(comments, $1), updated_at = $2 @@ -332,7 +332,7 @@ func (s *spotsImpl) addComment(spotID uint64, newComment *Comment, user *auth.Us return &Spot{ID: spotID}, nil } -func (s *spotsImpl) Delete(user *auth.User, spotIds []uint64) error { +func (s *spotsImpl) Delete(user *user.User, spotIds []uint64) error { switch { case user == nil: return fmt.Errorf("user is required") @@ -342,7 +342,7 @@ func (s *spotsImpl) Delete(user *auth.User, spotIds []uint64) error { return s.deleteSpots(spotIds, user) } -func (s *spotsImpl) deleteSpots(spotIds []uint64, user *auth.User) error { +func (s *spotsImpl) deleteSpots(spotIds []uint64, user *user.User) error { sql := `WITH updated AS (UPDATE spots.spots SET deleted_at = NOW() WHERE tenant_id = $1 AND spot_id IN (` args := []interface{}{user.TenantID} for i, spotID := range spotIds { @@ -378,7 +378,7 @@ func (s *spotsImpl) SetStatus(spotID uint64, status string) error { return nil } -func (s *spotsImpl) GetStatus(user *auth.User, spotID uint64) (string, error) { +func (s *spotsImpl) GetStatus(user *user.User, spotID uint64) (string, error) { switch { case user == nil: return "", fmt.Errorf("user is required") diff --git a/backend/pkg/tags/api/handlers.go b/backend/pkg/tags/api/handlers.go new file mode 100644 index 000000000..c9a92bd2e --- /dev/null +++ b/backend/pkg/tags/api/handlers.go @@ -0,0 +1,73 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "time" + + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessions" + "openreplay/backend/pkg/tags" + "openreplay/backend/pkg/token" +) + +type handlersImpl struct { + log logger.Logger + responser *api.Responser + tokenizer *token.Tokenizer + sessions sessions.Sessions + tags tags.Tags +} + +func NewHandlers(log logger.Logger, responser *api.Responser, tokenizer *token.Tokenizer, sessions sessions.Sessions, tags tags.Tags) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + tokenizer: tokenizer, + sessions: sessions, + tags: tags, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/tags", e.getTags, "GET"}, + } +} + +func (e *handlersImpl) getTags(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // TODO: move check authorization into middleware (we gonna have 2 different auth middlewares) + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + sessInfo, err := e.sessions.Get(sessionData.ID) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Add sessionID and projectID to context + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", sessInfo.ProjectID))) + + // Get tags + tags, err := e.tags.Get(sessInfo.ProjectID) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + type UrlResponse struct { + Tags interface{} `json:"tags"` + } + e.responser.ResponseWithJSON(e.log, r.Context(), w, &UrlResponse{Tags: tags}, startTime, r.URL.Path, bodySize) +} diff --git a/backend/pkg/uxtesting/api/handlers.go b/backend/pkg/uxtesting/api/handlers.go new file mode 100644 index 000000000..9188d6c88 --- /dev/null +++ b/backend/pkg/uxtesting/api/handlers.go @@ -0,0 +1,211 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/gorilla/mux" + + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/objectstorage" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/sessions" + "openreplay/backend/pkg/token" + "openreplay/backend/pkg/uxtesting" +) + +type handlersImpl struct { + log logger.Logger + responser *api.Responser + jsonSizeLimit int64 + tokenizer *token.Tokenizer + sessions sessions.Sessions + uxTesting uxtesting.UXTesting + objStorage objectstorage.ObjectStorage +} + +func NewHandlers(log logger.Logger, responser *api.Responser, jsonSizeLimit int64, tokenizer *token.Tokenizer, sessions sessions.Sessions, + uxTesting uxtesting.UXTesting, objStorage objectstorage.ObjectStorage) (api.Handlers, error) { + return &handlersImpl{ + log: log, + responser: responser, + jsonSizeLimit: jsonSizeLimit, + tokenizer: tokenizer, + sessions: sessions, + uxTesting: uxTesting, + objStorage: objStorage, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/web/uxt/signals/test", e.sendUXTestSignal, "POST"}, + {"/v1/web/uxt/signals/task", e.sendUXTaskSignal, "POST"}, + {"/v1/web/uxt/test/{id}", e.getUXTestInfo, "GET"}, + {"/v1/web/uxt/upload-url", e.getUXUploadUrl, "GET"}, + } +} + +func (e *handlersImpl) getUXTestInfo(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + sess, err := e.sessions.Get(sessionData.ID) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, err, startTime, r.URL.Path, bodySize) + return + } + + // Add projectID to context + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", sess.ProjectID))) + + // Get taskID + vars := mux.Vars(r) + id := vars["id"] + + // Get task info + info, err := e.uxTesting.GetInfo(id) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + if sess.ProjectID != info.ProjectID { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusForbidden, errors.New("project mismatch"), startTime, r.URL.Path, bodySize) + return + } + type TaskInfoResponse struct { + Task *uxtesting.UXTestInfo `json:"test"` + } + e.responser.ResponseWithJSON(e.log, r.Context(), w, &TaskInfoResponse{Task: info}, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) sendUXTestSignal(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + // Parse request body + req := &uxtesting.TestSignal{} + + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + req.SessionID = sessionData.ID + + // Save test signal + if err := e.uxTesting.SetTestSignal(req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) sendUXTaskSignal(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + bodyBytes, err := api.ReadBody(e.log, w, r, e.jsonSizeLimit) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusRequestEntityTooLarge, err, startTime, r.URL.Path, bodySize) + return + } + bodySize = len(bodyBytes) + + // Parse request body + req := &uxtesting.TaskSignal{} + + if err := json.Unmarshal(bodyBytes, req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + req.SessionID = sessionData.ID + + // Save test signal + if err := e.uxTesting.SetTaskSignal(req); err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + e.responser.ResponseOK(e.log, r.Context(), w, startTime, r.URL.Path, bodySize) +} + +func (e *handlersImpl) getUXUploadUrl(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + sessionData, err := e.tokenizer.ParseFromHTTPRequest(r) + if sessionData != nil { + r = r.WithContext(context.WithValue(r.Context(), "sessionID", fmt.Sprintf("%d", sessionData.ID))) + } + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Add sessionID and projectID to context + if info, err := e.sessions.Get(sessionData.ID); err == nil { + r = r.WithContext(context.WithValue(r.Context(), "projectID", fmt.Sprintf("%d", info.ProjectID))) + } + + key := fmt.Sprintf("%d/ux_webcam_record.webm", sessionData.ID) + url, err := e.objStorage.GetPreSignedUploadUrl(key) + if err != nil { + e.responser.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + type UrlResponse struct { + URL string `json:"url"` + } + e.responser.ResponseWithJSON(e.log, r.Context(), w, &UrlResponse{URL: url}, startTime, r.URL.Path, bodySize) +} diff --git a/ee/backend/internal/http/router/conditions.go b/ee/backend/internal/http/router/conditions.go deleted file mode 100644 index 94c103c1d..000000000 --- a/ee/backend/internal/http/router/conditions.go +++ /dev/null @@ -1,37 +0,0 @@ -package router - -import ( - "github.com/gorilla/mux" - "net/http" - "strconv" - "time" -) - -func (e *Router) getConditions(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() - bodySize := 0 - - // Check authorization - _, err := e.services.Tokenizer.ParseFromHTTPRequest(r) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) - return - } - - // Get taskID - vars := mux.Vars(r) - projID := vars["project"] - projectID, err := strconv.Atoi(projID) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) - return - } - - // Get task info - info, err := e.services.Conditions.Get(uint32(projectID)) - if err != nil { - e.ResponseWithError(r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) - return - } - e.ResponseWithJSON(r.Context(), w, info, startTime, r.URL.Path, bodySize) -} diff --git a/ee/backend/pkg/conditions/api/handlers.go b/ee/backend/pkg/conditions/api/handlers.go new file mode 100644 index 000000000..1d53e71d2 --- /dev/null +++ b/ee/backend/pkg/conditions/api/handlers.go @@ -0,0 +1,64 @@ +package api + +import ( + "net/http" + "strconv" + "time" + + "github.com/gorilla/mux" + + "openreplay/backend/pkg/conditions" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/server/api" + "openreplay/backend/pkg/token" +) + +type handlersImpl struct { + log logger.Logger + tokenizer *token.Tokenizer + conditions conditions.Conditions +} + +func NewHandlers(log logger.Logger, tokenizer *token.Tokenizer, conditions conditions.Conditions) (api.Handlers, error) { + return &handlersImpl{ + log: log, + tokenizer: tokenizer, + conditions: conditions, + }, nil +} + +func (e *handlersImpl) GetAll() []*api.Description { + return []*api.Description{ + {"/v1/web/conditions/{project}", e.getConditions, "GET"}, + {"/v1/mobile/conditions/{project}", e.getConditions, "GET"}, + } +} + +func (e *handlersImpl) getConditions(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + bodySize := 0 + + // Check authorization + _, err := e.tokenizer.ParseFromHTTPRequest(r) + if err != nil { + api.ResponseWithError(e.log, r.Context(), w, http.StatusUnauthorized, err, startTime, r.URL.Path, bodySize) + return + } + + // Get taskID + vars := mux.Vars(r) + projID := vars["project"] + projectID, err := strconv.Atoi(projID) + if err != nil { + api.ResponseWithError(e.log, r.Context(), w, http.StatusBadRequest, err, startTime, r.URL.Path, bodySize) + return + } + + // Get task info + info, err := e.conditions.Get(uint32(projectID)) + if err != nil { + api.ResponseWithError(e.log, r.Context(), w, http.StatusInternalServerError, err, startTime, r.URL.Path, bodySize) + return + } + api.ResponseWithJSON(e.log, r.Context(), w, info, startTime, r.URL.Path, bodySize) +} diff --git a/ee/backend/pkg/server/tracer/middleware.go b/ee/backend/pkg/server/tracer/middleware.go new file mode 100644 index 000000000..546e36428 --- /dev/null +++ b/ee/backend/pkg/server/tracer/middleware.go @@ -0,0 +1,96 @@ +package tracer + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + "github.com/gorilla/mux" + + "openreplay/backend/pkg/server/user" +) + +type statusWriter struct { + http.ResponseWriter + statusCode int +} + +func (w *statusWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *statusWriter) Write(b []byte) (int, error) { + if w.statusCode == 0 { + w.statusCode = http.StatusOK + } + return w.ResponseWriter.Write(b) +} + +func (t *tracerImpl) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read body and restore the io.ReadCloser to its original state + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "can't read body", http.StatusBadRequest) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // Use custom response writer to get the status code + sw := &statusWriter{ResponseWriter: w} + // Serve the request + next.ServeHTTP(sw, r) + t.logRequest(r, bodyBytes, sw.statusCode) + }) +} + +var routeMatch = map[string]string{ + "POST" + "/spot/v1/spots": "createSpot", + "GET" + "/spot/v1/spots/{id}": "getSpot", + "PATCH" + "/spot/v1/spots/{id}": "updateSpot", + "GET" + "/spot/v1/spots": "getSpots", + "DELETE" + "/spot/v1/spots": "deleteSpots", + "POST" + "/spot/v1/spots/{id}/comment": "addComment", + "GET" + "/spot/v1/spots/{id}/video": "getSpotVideo", + "PATCH" + "/spot/v1/spots/{id}/public-key": "updatePublicKey", +} + +func (t *tracerImpl) logRequest(r *http.Request, bodyBytes []byte, statusCode int) { + pathTemplate, err := mux.CurrentRoute(r).GetPathTemplate() + if err != nil { + t.log.Error(r.Context(), "failed to get path template: %s", err) + } + t.log.Debug(r.Context(), "path template: %s", pathTemplate) + if _, ok := routeMatch[r.Method+pathTemplate]; !ok { + t.log.Debug(r.Context(), "no match for route: %s %s", r.Method, pathTemplate) + return + } + // Convert the parameters to json + query := r.URL.Query() + params := make(map[string]interface{}) + for key, values := range query { + if len(values) > 1 { + params[key] = values + } else { + params[key] = values[0] + } + } + jsonData, err := json.Marshal(params) + if err != nil { + t.log.Error(r.Context(), "failed to marshal query parameters: %s", err) + } + requestData := &RequestData{ + Action: routeMatch[r.Method+pathTemplate], + Method: r.Method, + PathFormat: pathTemplate, + Endpoint: r.URL.Path, + Payload: bodyBytes, + Parameters: jsonData, + Status: statusCode, + } + userData := r.Context().Value("userData").(*user.User) + t.trace(userData, requestData) + // DEBUG + t.log.Debug(r.Context(), "request data: %v", requestData) +} diff --git a/ee/backend/pkg/server/tracer/tracer.go b/ee/backend/pkg/server/tracer/tracer.go new file mode 100644 index 000000000..b66d2cbe8 --- /dev/null +++ b/ee/backend/pkg/server/tracer/tracer.go @@ -0,0 +1,106 @@ +package tracer + +import ( + "context" + "errors" + "net/http" + + "openreplay/backend/pkg/db/postgres" + db "openreplay/backend/pkg/db/postgres/pool" + "openreplay/backend/pkg/logger" + "openreplay/backend/pkg/pool" + "openreplay/backend/pkg/server/user" +) + +type Tracer interface { + Middleware(next http.Handler) http.Handler + Close() error +} + +type tracerImpl struct { + log logger.Logger + conn db.Pool + traces postgres.Bulk + saver pool.WorkerPool +} + +func NewTracer(log logger.Logger, conn db.Pool) (Tracer, error) { + switch { + case log == nil: + return nil, errors.New("logger is required") + case conn == nil: + return nil, errors.New("connection is required") + } + tracer := &tracerImpl{ + log: log, + conn: conn, + } + if err := tracer.initBulk(); err != nil { + return nil, err + } + tracer.saver = pool.NewPool(1, 200, tracer.sendTraces) + return tracer, nil +} + +func (t *tracerImpl) initBulk() (err error) { + t.traces, err = postgres.NewBulk(t.conn, + "traces", + "(user_id, tenant_id, auth, action, method, path_format, endpoint, payload, parameters, status)", + "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", + 10, 50) + if err != nil { + return err + } + return nil +} + +type Task struct { + UserID *uint64 + TenantID uint64 + Auth *string + Data *RequestData +} + +func (t *tracerImpl) sendTraces(payload interface{}) { + rec := payload.(*Task) + t.log.Debug(context.Background(), "Sending traces, %v", rec) + if err := t.traces.Append(rec.UserID, rec.TenantID, rec.Auth, rec.Data.Action, rec.Data.Method, rec.Data.PathFormat, + rec.Data.Endpoint, rec.Data.Payload, rec.Data.Parameters, rec.Data.Status); err != nil { + t.log.Error(context.Background(), "can't append trace: %s", err) + } +} + +type RequestData struct { + Action string + Method string + PathFormat string + Endpoint string + Payload []byte + Parameters []byte + Status int +} + +func (t *tracerImpl) trace(user *user.User, data *RequestData) error { + switch { + case user == nil: + return errors.New("user is required") + case data == nil: + return errors.New("request is required") + } + trace := &Task{ + UserID: &user.ID, + TenantID: user.TenantID, + Auth: &user.AuthMethod, + Data: data, + } + t.saver.Submit(trace) + return nil +} + +func (t *tracerImpl) Close() error { + t.saver.Stop() + if err := t.traces.Send(); err != nil { + return err + } + return nil +} diff --git a/ee/backend/internal/http/router/web-start.go b/ee/backend/pkg/sessions/api/web/model.go similarity index 93% rename from ee/backend/internal/http/router/web-start.go rename to ee/backend/pkg/sessions/api/web/model.go index 7514f0164..edeb7f643 100644 --- a/ee/backend/internal/http/router/web-start.go +++ b/ee/backend/pkg/sessions/api/web/model.go @@ -1,4 +1,10 @@ -package router +package web + +type NotStartedRequest struct { + ProjectKey *string `json:"projectKey"` + TrackerVersion string `json:"trackerVersion"` + DoNotTrack bool `json:"DoNotTrack"` +} type StartSessionRequest struct { Token string `json:"token"`