Skip to content

Commit

Permalink
fix: refactor mfa validation
Browse files Browse the repository at this point in the history
  • Loading branch information
J0 committed Sep 30, 2024
1 parent 937372b commit 0245a3b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 24 deletions.
30 changes: 11 additions & 19 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ const (
QRCodeGenerationErrorMessage = "Error generating QR Code"
)

func validateFactors(user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error {
func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error {
factorCount := len(user.Factors)
numVerifiedFactors := 0

if err := db.Load(user, "Factors"); err != nil {
return err
}

for _, factor := range user.Factors {
if factor.FriendlyName == newFactorName {
return unprocessableEntityError(
Expand Down Expand Up @@ -115,16 +119,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *
if err != nil {
return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
}
factors := user.Factors

factorCount := len(factors)
numVerifiedFactors := 0
if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil {
return err
}
if err := validateFactors(user, params.FriendlyName, a.config, session); err != nil {
return err
}

var factorsToDelete []models.Factor
for _, factor := range user.Factors {
if factor.IsPhoneFactor() && factor.Phone.String() == phone {
Expand All @@ -143,17 +141,10 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *
return internalServerError("Database error deleting unverified phone factors").WithInternalError(err)
}

if factorCount >= int(config.MFA.MaxEnrolledFactors) {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
}

if numVerifiedFactors >= config.MFA.MaxVerifiedFactors {
return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue")
if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil {
return err
}

if numVerifiedFactors > 0 && !session.IsAAL2() {
return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor")
}
factor := models.NewPhoneFactor(user, phone, params.FriendlyName)
err = db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(factor); terr != nil {
Expand Down Expand Up @@ -198,7 +189,8 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E
if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil {
return err
}
if err := validateFactors(user, params.FriendlyName, config, session); err != nil {

if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil {
return err
}

Expand Down Expand Up @@ -230,8 +222,8 @@ func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *E
err = db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(factor); terr != nil {
return terr

}

if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{
"factor_id": factor.ID,
}); terr != nil {
Expand Down
8 changes: 3 additions & 5 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() {
err := json.NewDecoder(response.Body).Decode(&errorResponse)
require.NoError(ts.T(), err)

// Convert the response body to a string and check for the expected error message
expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName)
require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage)
require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict)
}

func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() {
Expand Down Expand Up @@ -369,7 +367,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
var w *httptest.ResponseRecorder
token := accessTokenResp.Token
for i := 0; i < numFactors; i++ {
w = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK)
w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK)
}

enrollResp := EnrollFactorResponse{}
Expand All @@ -379,7 +377,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
_ = performChallengeFlow(ts, enrollResp.ID, token)

// Enroll another Factor (Factor 3)
_ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", "", http.StatusOK)
_ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK)
require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID))
require.Equal(ts.T(), 3, len(ts.TestUser.Factors))
}
Expand Down

0 comments on commit 0245a3b

Please sign in to comment.