From 0245a3bbbdbc33d5c2d1955813eca3f8053a4afd Mon Sep 17 00:00:00 2001 From: joel Date: Sat, 28 Sep 2024 11:01:14 +0200 Subject: [PATCH] fix: refactor mfa validation --- internal/api/mfa.go | 30 +++++++++++------------------- internal/api/mfa_test.go | 8 +++----- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 26df5a2b0..3b7383706 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -70,10 +70,14 @@ const ( QRCodeGenerationErrorMessage = "Error generating QR Code" ) -func validateFactors(user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { +func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { factorCount := len(user.Factors) numVerifiedFactors := 0 + if err := db.Load(user, "Factors"); err != nil { + return err + } + for _, factor := range user.Factors { if factor.FriendlyName == newFactorName { return unprocessableEntityError( @@ -115,16 +119,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * if err != nil { return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } - factors := user.Factors - - factorCount := len(factors) - numVerifiedFactors := 0 if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { return err } - if err := validateFactors(user, params.FriendlyName, a.config, session); err != nil { - return err - } + var factorsToDelete []models.Factor for _, factor := range user.Factors { if factor.IsPhoneFactor() && factor.Phone.String() == phone { @@ -143,17 +141,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) } - if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") - } - - if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err } - if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") - } factor := models.NewPhoneFactor(user, phone, params.FriendlyName) err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { @@ -198,7 +189,8 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { return err } - if err := validateFactors(user, params.FriendlyName, config, session); err != nil { + + if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil { return err } @@ -230,8 +222,8 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { return terr - } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, }); terr != nil { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 557b0ab13..87767f085 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -290,9 +290,7 @@ func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { err := json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) - // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) - require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) + require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict) } func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() { @@ -369,7 +367,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { var w *httptest.ResponseRecorder token := accessTokenResp.Token for i := 0; i < numFactors; i++ { - w = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK) } enrollResp := EnrollFactorResponse{} @@ -379,7 +377,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { _ = performChallengeFlow(ts, enrollResp.ID, token) // Enroll another Factor (Factor 3) - _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK) + _ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK) require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) require.Equal(ts.T(), 3, len(ts.TestUser.Factors)) }