From bf276ab49753642793471815727559172fea4efc Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Wed, 28 Aug 2024 16:33:51 -0700 Subject: [PATCH] fix: apply shared limiters before email / sms is sent (#1748) ## What kind of change does this PR introduce? * Fixes https://github.com/supabase/auth/issues/1236 * Reduces the number of false positives arising from validation errors when counting rate limits for emails / sms sent * This change applies the shared rate limiter for email and phone functions before the actual email / sms is being sent out rather than at the start of the request * The `limitEmailOrPhoneSentHandler()` now initialises the rate limiters and sets it in the request context so we can subsequently retrieve it right before the email / sms is sent ## What is the current behavior? Please link any relevant issues here. ## What is the new behavior? Feel free to include screenshots if it includes visual changes. ## Additional context Add any other context or screenshots. --- internal/api/context.go | 19 +++++ internal/api/external.go | 6 +- internal/api/identity.go | 5 +- internal/api/invite.go | 2 +- internal/api/magic_link.go | 6 +- internal/api/mail.go | 147 ++++++++++++++++++-------------- internal/api/middleware.go | 36 ++------ internal/api/middleware_test.go | 41 +++++++-- internal/api/phone.go | 12 ++- internal/api/reauthenticate.go | 9 -- internal/api/recover.go | 6 +- internal/api/resend.go | 13 +-- internal/api/signup.go | 14 +-- internal/api/user.go | 6 +- internal/mailer/mailer.go | 1 - 15 files changed, 158 insertions(+), 165 deletions(-) diff --git a/internal/api/context.go b/internal/api/context.go index 3047f3dd6..ff01e7120 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -4,6 +4,7 @@ import ( "context" "net/url" + "github.com/didip/tollbooth/v5/limiter" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/models" ) @@ -31,6 +32,7 @@ const ( ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") + sharedLimiterKey = contextKey("shared_limiter") ) // withToken adds the JWT token to the context. @@ -241,3 +243,20 @@ func getExternalHost(ctx context.Context) *url.URL { } return obj.(*url.URL) } + +type SharedLimiter struct { + EmailLimiter *limiter.Limiter + PhoneLimiter *limiter.Limiter +} + +func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context { + return context.WithValue(ctx, sharedLimiterKey, limiter) +} + +func getLimiter(ctx context.Context) *SharedLimiter { + obj := ctx.Value(sharedLimiterKey) + if obj == nil { + return nil + } + return obj.(*SharedLimiter) +} diff --git a/internal/api/external.go b/internal/api/external.go index d0cec1536..2eff891ef 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -2,7 +2,6 @@ package api import ( "context" - "errors" "fmt" "net/http" "net/url" @@ -383,10 +382,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. emailConfirmationSent := false if decision.CandidateEmail.Email != "" { if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every minute") - } - return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) + return nil, terr } emailConfirmationSent = true } diff --git a/internal/api/identity.go b/internal/api/identity.go index 69d3f854f..53cef86f9 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -7,7 +7,6 @@ import ( "github.com/fatih/structs" "github.com/go-chi/chi/v5" "github.com/gofrs/uuid" - "github.com/pkg/errors" "github.com/supabase/auth/internal/api/provider" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" @@ -133,9 +132,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora } if !userData.Metadata.EmailVerified { if terr := a.sendConfirmation(r, tx, targetUser, models.ImplicitFlow); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "For security purposes, you can only request this once every minute") - } + return nil, terr } return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) } diff --git a/internal/api/invite.go b/internal/api/invite.go index 2e07e7135..76852f711 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -80,7 +80,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { } if err := a.sendInvite(r, tx, user); err != nil { - return internalServerError("Error inviting user").WithInternalError(err) + return err } return nil }) diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index eeabafd39..d7e941b79 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -3,7 +3,6 @@ package api import ( "bytes" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -141,10 +140,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { return a.sendMagicLink(r, tx, user, flowType) }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(user.RecoverySentAt, config.SMTP.MaxFrequency)) - } - return internalServerError("Error sending magic link").WithInternalError(err) + return err } return sendJSON(w, http.StatusOK, make(map[string]string)) diff --git a/internal/api/mail.go b/internal/api/mail.go index 35f529e25..00bb58e7e 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -5,8 +5,11 @@ import ( "strings" "time" + "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" mail "github.com/supabase/auth/internal/mailer" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "github.com/badoux/checkmail" "github.com/fatih/structs" @@ -20,7 +23,7 @@ import ( ) var ( - MaxFrequencyLimitError error = errors.New("frequency limit reached") + EmailRateLimitExceeded error = errors.New("email rate limit exceeded") ) type GenerateLinkParams struct { @@ -301,7 +304,6 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - var err error if err := validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { return err } @@ -314,20 +316,20 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) now := time.Now() - err = a.sendEmail(r, tx, u, mail.SignupVerification, otp, "", u.ConfirmationToken) - if err != nil { + if err = a.sendEmail(r, tx, u, mail.SignupVerification, otp, "", u.ConfirmationToken); err != nil { u.ConfirmationToken = oldToken - return errors.Wrap(err, "Error sending confirmation email") + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending confirmation email").WithInternalError(err) } u.ConfirmationSentAt = &now - err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") - if err != nil { - return errors.Wrap(err, "Database error updating user for confirmation") + if err := tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error updating user for confirmation")) } - err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) - if err != nil { - return errors.Wrap(err, "Database error creating confirmation token") + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error creating confirmation token")) } return nil @@ -345,21 +347,23 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() - err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken) - if err != nil { + if err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken); err != nil { u.ConfirmationToken = oldToken - return errors.Wrap(err, "Error sending invite email") + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending invite email").WithInternalError(err) } u.InvitedAt = &now u.ConfirmationSentAt = &now err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") if err != nil { - return errors.Wrap(err, "Database error updating user for invite") + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error updating user for invite")) } err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) if err != nil { - return errors.Wrap(err, "Database error creating confirmation token for invite") + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error creating confirmation token for invite")) } return nil @@ -367,10 +371,9 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { config := a.config - maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - var err error - if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, maxFrequency); err != nil { + + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -383,20 +386,21 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) now := time.Now() - err = a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken) - if err != nil { + if err = a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { u.RecoveryToken = oldToken - return errors.Wrap(err, "Error sending recovery email") + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending recovery email").WithInternalError(err) } u.RecoverySentAt = &now - err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") - if err != nil { - return errors.Wrap(err, "Database error updating user for recovery") + + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) } - err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken) - if err != nil { - return errors.Wrap(err, "Database error creating recovery token") + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) } return nil @@ -406,7 +410,6 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u config := a.config maxFrequency := config.SMTP.MaxFrequency otpLength := config.Mailer.OtpLength - var err error if err := validateSentWithinFrequencyLimit(u.ReauthenticationSentAt, maxFrequency); err != nil { return err @@ -420,20 +423,21 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() - err = a.sendEmail(r, tx, u, mail.ReauthenticationVerification, otp, "", u.ReauthenticationToken) - if err != nil { + + if err := a.sendEmail(r, tx, u, mail.ReauthenticationVerification, otp, "", u.ReauthenticationToken); err != nil { u.ReauthenticationToken = oldToken - return errors.Wrap(err, "Error sending reauthentication email") + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending reauthentication email").WithInternalError(err) } u.ReauthenticationSentAt = &now - err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") - if err != nil { - return errors.Wrap(err, "Database error updating user for reauthentication") + if err := tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error updating user for reauthentication")) } - err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken) - if err != nil { - return errors.Wrap(err, "Database error creating reauthentication token") + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error creating reauthentication token")) } return nil @@ -442,11 +446,10 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { config := a.config otpLength := config.Mailer.OtpLength - maxFrequency := config.SMTP.MaxFrequency - var err error + // since Magic Link is just a recovery with a different template and behaviour // around new users we will reuse the recovery db timer to prevent potential abuse - if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, maxFrequency); err != nil { + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -460,20 +463,20 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U u.RecoveryToken = addFlowPrefixToToken(token, flowType) now := time.Now() - err = a.sendEmail(r, tx, u, mail.MagicLinkVerification, otp, "", u.RecoveryToken) - if err != nil { + if err = a.sendEmail(r, tx, u, mail.MagicLinkVerification, otp, "", u.RecoveryToken); err != nil { u.RecoveryToken = oldToken - return errors.Wrap(err, "Error sending magic link email") + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending magic link email").WithInternalError(err) } u.RecoverySentAt = &now - err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") - if err != nil { - return errors.Wrap(err, "Database error updating user for recovery") + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) } - err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken) - if err != nil { - return errors.Wrap(err, "Database error creating recovery token") + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) } return nil @@ -483,7 +486,7 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models.User, email string, flowType models.FlowType) error { config := a.config otpLength := config.Mailer.OtpLength - var err error + if err := validateSentWithinFrequencyLimit(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { return err } @@ -510,36 +513,35 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models u.EmailChangeConfirmStatus = zeroConfirmation now := time.Now() - err = a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew) - if err != nil { - return err + + if err := a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } + return internalServerError("Error sending email change email").WithInternalError(err) } u.EmailChangeSentAt = &now - err = tx.UpdateOnly( + if err := tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ) - - if err != nil { - return errors.Wrap(err, "Database error updating user for email change") + ); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) } if u.EmailChangeTokenCurrent != "" { - err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent) - if err != nil { - return errors.Wrap(err, "Database error creating email change token current") + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) } } if u.EmailChangeTokenNew != "" { - err = models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew) - if err != nil { - return errors.Wrap(err, "Database error creating email change token new") + if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) } } @@ -561,7 +563,7 @@ func validateEmail(email string) (string, error) { func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { - return MaxFrequencyLimitError + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) } return nil } @@ -572,6 +574,19 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, config := a.config referrerURL := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) + + // apply rate limiting before the email is sent out + if limiter := getLimiter(ctx); limiter != nil { + if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + } + if config.Hook.SendEmail.Enabled { emailData := mail.EmailData{ Token: otp, diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 972ac5ab3..aa2c3e9ff 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -15,8 +15,6 @@ import ( "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/security" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" "github.com/didip/tollbooth/v5" "github.com/didip/tollbooth/v5/limiter" @@ -99,35 +97,11 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitEmail || shouldRateLimitPhone { if req.Method == "PUT" || req.Method == "POST" { - var requestBody struct { - Email string `json:"email"` - Phone string `json:"phone"` - } - - if err := retrieveRequestParams(req, &requestBody); err != nil { - return c, err - } - - if shouldRateLimitEmail { - if requestBody.Email != "" { - if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil { - emailRateLimitCounter.Add( - req.Context(), - 1, - metric.WithAttributeSet(attribute.NewSet(attribute.String("path", req.URL.Path))), - ) - return c, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "Email rate limit exceeded") - } - } - } - - if shouldRateLimitPhone { - if requestBody.Phone != "" { - if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") - } - } - } + // store rate limiter in request context + c = withLimiter(c, &SharedLimiter{ + EmailLimiter: emailLimiter, + PhoneLimiter: phoneLimiter, + }) } } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index eb8c5da3b..2d7a32493 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -221,15 +221,12 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - for i := 0; i < 5; i++ { - _, err := limiter(w, req) - require.NoError(ts.T(), err) - } + ctx, err := limiter(w, req) + require.NoError(ts.T(), err) - // should exceed rate limit on 5th try - _, err := limiter(w, req) - require.Error(ts.T(), err) - require.Equal(ts.T(), c.expectedErrorMsg, err.Error()) + // check that shared limiter is set in the request context + sharedLimiter := getLimiter(ctx) + require.NotNil(ts.T(), sharedLimiter) }) } } @@ -406,6 +403,34 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { } okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limiter := getLimiter(r.Context()) + if limiter != nil { + var requestBody struct { + Email string `json:"email"` + Phone string `json:"phone"` + } + err := retrieveRequestParams(r, &requestBody) + require.NoError(ts.T(), err) + + if requestBody.Email != "" { + if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { + sendJSON(w, http.StatusTooManyRequests, HTTPError{ + HTTPStatus: http.StatusTooManyRequests, + ErrorCode: ErrorCodeOverEmailSendRateLimit, + Message: "Email rate limit exceeded", + }) + } + } + if requestBody.Phone != "" { + if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil { + sendJSON(w, http.StatusTooManyRequests, HTTPError{ + HTTPStatus: http.StatusTooManyRequests, + ErrorCode: ErrorCodeOverSMSSendRateLimit, + Message: "SMS rate limit exceeded", + }) + } + } + } w.WriteHeader(http.StatusOK) }) diff --git a/internal/api/phone.go b/internal/api/phone.go index 8e7d39e63..ce11c5a3f 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -8,6 +8,7 @@ import ( "text/template" "time" + "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" "github.com/pkg/errors" @@ -44,6 +45,7 @@ func formatPhoneNumber(phone string) string { // sendPhoneConfirmation sends an otp to the user's phone number func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { + ctx := r.Context() config := a.config var token *string @@ -84,7 +86,15 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use messageID = "test-otp" } - if otp == "" { // not using test OTPs + // not using test OTPs + if otp == "" { + // apply rate limiting before the sms is sent out + limiter := getLimiter(ctx) + if limiter != nil { + if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") + } + } otp, err = crypto.GenerateOtp(config.Sms.OtpLength) if err != nil { return "", internalServerError("error generating otp").WithInternalError(err) diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index df46bad03..5146ae409 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -1,7 +1,6 @@ package api import ( - "errors" "net/http" "github.com/supabase/auth/internal/api/sms_provider" @@ -53,14 +52,6 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { return nil }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - reason := ErrorCodeOverEmailSendRateLimit - if phone != "" { - reason = ErrorCodeOverSMSSendRateLimit - } - - return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") - } return err } diff --git a/internal/api/recover.go b/internal/api/recover.go index 0fa9760ae..cbcff81d8 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -1,7 +1,6 @@ package api import ( - "errors" "net/http" "github.com/supabase/auth/internal/models" @@ -67,10 +66,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { return a.sendPasswordRecovery(r, tx, user, flowType) }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") - } - return internalServerError("Unable to process request").WithInternalError(err) + return err } return sendJSON(w, http.StatusOK, map[string]string{}) diff --git a/internal/api/resend.go b/internal/api/resend.go index 1dfc47762..a1f4246c8 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -1,9 +1,7 @@ package api import ( - "errors" "net/http" - "time" "github.com/supabase/auth/internal/api/sms_provider" "github.com/supabase/auth/internal/conf" @@ -144,16 +142,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { return nil }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - reason := ErrorCodeOverEmailSendRateLimit - if params.Type == smsVerification || params.Type == phoneChangeVerification { - reason = ErrorCodeOverSMSSendRateLimit - } - - until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) - } - return internalServerError("Unable to process request").WithInternalError(err) + return err } ret := map[string]any{} diff --git a/internal/api/signup.go b/internal/api/signup.go index d7d946c8c..0a1b8c6c4 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -244,10 +244,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } } if terr = a.sendConfirmation(r, tx, user, flowType); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(user.ConfirmationSentAt, config.SMTP.MaxFrequency)) - } - return internalServerError("Error sending confirmation mail").WithInternalError(terr) + return terr } } } else if params.Provider == "phone" && !user.IsPhoneConfirmed() { @@ -277,14 +274,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - reason := ErrorCodeOverEmailSendRateLimit - if params.Provider == "phone" { - reason = ErrorCodeOverSMSSendRateLimit - } - - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError(reason, "For security purposes, you can only request this once every minute") - } else if errors.Is(err, UserExistsError) { + if errors.Is(err, UserExistsError) { err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ "provider": params.Provider, diff --git a/internal/api/user.go b/internal/api/user.go index c89c5226d..616ad23c7 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -2,7 +2,6 @@ package api import ( "context" - "errors" "net/http" "time" @@ -226,10 +225,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } if terr = a.sendEmailChange(r, tx, user, params.Email, flowType); terr != nil { - if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(user.EmailChangeSentAt, config.SMTP.MaxFrequency)) - } - return internalServerError("Error sending change email").WithInternalError(terr) + return terr } } } diff --git a/internal/mailer/mailer.go b/internal/mailer/mailer.go index ff19239d8..a05e6e27a 100644 --- a/internal/mailer/mailer.go +++ b/internal/mailer/mailer.go @@ -15,7 +15,6 @@ import ( // Mailer defines the interface a mailer must implement. type Mailer interface { - Send(user *models.User, subject, body string, data map[string]interface{}) error InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error