diff --git a/internal/api/external.go b/internal/api/external.go index f941f55e5..89f75a4fc 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -15,6 +15,7 @@ import ( jwt "github.com/golang-jwt/jwt/v5" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" @@ -106,8 +107,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ claims.LinkingTargetID = linkingTargetUser.ID.String() } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(config.JWT.Secret)) + tokenString, err := signJwt(&config.JWT, claims) if err != nil { return "", internalServerError("Error creating state").WithInternalError(err) } @@ -491,9 +491,20 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C } config := a.config claims := ExternalProviderClaims{} - p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods)) _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(config.JWT.Secret), nil + if kid, ok := token.Header["kid"]; ok { + if kidStr, ok := kid.(string); ok { + return conf.FindPublicKeyByKid(kidStr, &config.JWT) + } + } + if alg, ok := token.Header["alg"]; ok { + if alg == jwt.SigningMethodHS256.Name { + // preserve backward compatibility for cases where the kid is not set + return []byte(config.JWT.Secret), nil + } + } + return nil, fmt.Errorf("missing kid") }) if err != nil { return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) diff --git a/internal/api/jwks.go b/internal/api/jwks.go index d03ae03fb..b8304d2dd 100644 --- a/internal/api/jwks.go +++ b/internal/api/jwks.go @@ -3,8 +3,10 @@ package api import ( "net/http" + jwt "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/jwx/v2/jwa" jwk "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/supabase/auth/internal/conf" ) type JwksResponse struct { @@ -28,3 +30,32 @@ func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Cache-Control", "public, max-age=600") return sendJSON(w, http.StatusOK, resp) } + +func signJwt(config *conf.JWTConfiguration, claims jwt.Claims) (string, error) { + signingJwk, err := conf.GetSigningJwk(config) + if err != nil { + return "", err + } + signingMethod := conf.GetSigningAlg(signingJwk) + token := jwt.NewWithClaims(signingMethod, claims) + if token.Header == nil { + token.Header = make(map[string]interface{}) + } + + if _, ok := token.Header["kid"]; !ok { + if kid := signingJwk.KeyID(); kid != "" { + token.Header["kid"] = kid + } + } + // this serializes the aud claim to a string + jwt.MarshalSingleStringAsArray = false + signingKey, err := conf.GetSigningKey(signingJwk) + if err != nil { + return "", err + } + signed, err := token.SignedString(signingKey) + if err != nil { + return "", err + } + return signed, nil +} diff --git a/internal/api/token.go b/internal/api/token.go index f9dc89829..f9a13ac00 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -349,7 +349,6 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user IsAnonymous: user.IsAnonymous, } - var token *jwt.Token var gotrueClaims jwt.Claims = claims if config.Hook.CustomAccessToken.Enabled { input := hooks.CustomAccessTokenInput{ @@ -367,30 +366,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user gotrueClaims = jwt.MapClaims(output.Claims) } - signingJwk, err := conf.GetSigningJwk(&config.JWT) - if err != nil { - return "", 0, err - } - - signingMethod := conf.GetSigningAlg(signingJwk) - token = jwt.NewWithClaims(signingMethod, gotrueClaims) - if token.Header == nil { - token.Header = make(map[string]interface{}) - } - - if _, ok := token.Header["kid"]; !ok { - if kid := signingJwk.KeyID(); kid != "" { - token.Header["kid"] = kid - } - } - - // this serializes the aud claim to a string - jwt.MarshalSingleStringAsArray = false - signingKey, err := conf.GetSigningKey(signingJwk) - if err != nil { - return "", 0, err - } - signed, err := token.SignedString(signingKey) + signed, err := signJwt(&config.JWT, gotrueClaims) if err != nil { return "", 0, err }