98 lines
2.7 KiB
Go
98 lines
2.7 KiB
Go
package jwt
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
ctxStore "github.com/docker/distribution/context"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
"openreplay/backend/pkg/db/postgres/pool"
|
|
"openreplay/backend/pkg/logger"
|
|
"openreplay/backend/pkg/server/api"
|
|
"openreplay/backend/pkg/server/user"
|
|
)
|
|
|
|
type authImpl struct {
|
|
log logger.Logger
|
|
jwtSecret string
|
|
pgconn pool.Pool
|
|
}
|
|
|
|
func NewAuth(log logger.Logger, jwtSecret string, pgConn pool.Pool) api.RouterMiddleware {
|
|
return &authImpl{
|
|
log: log,
|
|
jwtSecret: jwtSecret,
|
|
pgconn: pgConn,
|
|
}
|
|
}
|
|
|
|
// Middleware wraps next so it only runs for authenticated requests.
//
// The Authorization header is parsed as a bearer JWT, the token's user is
// loaded from the database, and the resulting user record is stored in the
// request context under the "userData" key before calling next. A missing
// header, an invalid token, or a failed user lookup short-circuits the chain
// with 401 Unauthorized.
func (a *authImpl) Middleware(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		header := r.Header.Get("Authorization")
		if len(header) == 0 {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		claims, err := parseJWT(header, a.jwtSecret)
		if err != nil {
			a.log.Warn(ctx, "Unauthorized request, wrong jwt token: %s", err)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// NOTE(review): IssuedAt is dereferenced unconditionally; a token
		// without an iat claim panics here unless parseJWT rejects it first.
		issuedAt := int(claims.IssuedAt.Unix())
		authenticated, err := authUser(a.pgconn, claims.UserId, claims.TenantID, issuedAt)
		if err != nil {
			a.log.Warn(ctx, "Unauthorized request, user not found: %s", err)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		values := map[string]interface{}{"userData": authenticated}
		next.ServeHTTP(w, r.WithContext(ctxStore.WithValues(ctx, values)))
	}
	return http.HandlerFunc(handler)
}
|
|
|
|
func parseJWT(authHeader, secret string) (*user.JWTClaims, error) {
|
|
if authHeader == "" {
|
|
return nil, fmt.Errorf("authorization header missing")
|
|
}
|
|
tokenParts := strings.Split(authHeader, "Bearer ")
|
|
if len(tokenParts) != 2 {
|
|
return nil, fmt.Errorf("invalid authorization header")
|
|
}
|
|
tokenString := tokenParts[1]
|
|
|
|
claims := &user.JWTClaims{}
|
|
token, err := jwt.ParseWithClaims(tokenString, claims,
|
|
func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(secret), nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
fmt.Printf("token err: %v\n", err)
|
|
return nil, fmt.Errorf("invalid token")
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
// authUser loads the user identified by userID and checks that the token is
// still current.
//
// Only rows with deleted_at IS NULL are considered. The token is current when
// its issued-at timestamp jwtIAT is within one second of the user's stored
// spot_jwt_iat (epoch seconds). A stored value of 0, or a larger difference,
// is reported as "token has been updated".
//
// The returned user has AuthMethod "jwt" and TenantID 1. tenantID is
// currently unused.
// NOTE(review): the fixed TenantID presumably reflects a single-tenant build;
// confirm before relying on tenantID here.
func authUser(conn pool.Pool, userID, tenantID, jwtIAT int) (*user.User, error) {
	query := `
		SELECT user_id, name, email, EXTRACT(epoch FROM spot_jwt_iat)::BIGINT AS spot_jwt_iat
		FROM public.users
		WHERE user_id = $1 AND deleted_at IS NULL
		LIMIT 1;`

	newUser := &user.User{TenantID: 1, AuthMethod: "jwt"}
	if err := conn.QueryRow(query, userID).Scan(&newUser.ID, &newUser.Name, &newUser.Email, &newUser.JwtIat); err != nil {
		// Keep the underlying cause (no rows vs. connection/scan failure) so
		// it reaches the logs instead of being flattened into "user not found".
		return nil, fmt.Errorf("user not found: %w", err)
	}

	// Allow one second of skew between the token's iat and the stored value.
	if newUser.JwtIat == 0 || abs(jwtIAT-newUser.JwtIat) > 1 {
		return nil, fmt.Errorf("token has been updated")
	}
	return newUser, nil
}
|
|
|
|
// abs returns the absolute value of x. As with any two's-complement negation,
// the most negative int maps to itself.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
|