Skip to content

Commit

Permalink
feat: rework the OTP code submit count mechanism
Browse files Browse the repository at this point in the history
Unlike what the previous comment suggested, incrementing and checking the submit count inside the
database transaction is not actually optimal peformance- or security-wise.

We now check atomically increment and check the submit count as the first part of the operation,
and abort as early as possible if we detect brute-forcing. This prevents a situation where the
check works only on certain transaction isolation levels.
  • Loading branch information
alnr committed Dec 19, 2024
1 parent a893cd8 commit 299c94e
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 37 deletions.
86 changes: 56 additions & 30 deletions persistence/sql/persister_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ package sql
import (
"context"
"crypto/subtle"
"database/sql"
"fmt"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/x/otelx"
Expand Down Expand Up @@ -41,7 +44,7 @@ func useOneTimeCode[P any, U interface {
*P
oneTimeCodeProvider
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption,
) (_ U, err error) {
) (target U, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode")
defer otelx.End(span, &err)

Expand All @@ -50,33 +53,21 @@ func useOneTimeCode[P any, U interface {
opt(o)
}

var target U
nid := p.NetworkID(ctx)
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
//#nosec G201 -- TableName is static
if err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec(); err != nil {
return err
}

var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
}
return err
}
// Before we do anything else, increment the submit count and check if we're
// being brute-forced. This is a separate statement/transaction to the rest
// of the operations so that it is correct for all transaction isolation
// levels.
submitCount, err := incrementOTPCodeSubmitCount(ctx, p, flowID, flowTableName)
if err != nil {
return nil, err
}
if submitCount > 5 {
return nil, errors.WithStack(code.ErrCodeSubmittedTooOften)
}

// This check prevents parallel brute force attacks by checking the submit count inside this database
// transaction. If the flow has been submitted more than 5 times, the transaction is aborted (regardless of
// whether the code was correct or not) and we thus give no indication whether the supplied code was correct or
// not. For more explanation see [this comment](https://github.com/ory/kratos/pull/2645#discussion_r984732899).
if submitCount > 5 {
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}
nid := p.NetworkID(ctx)

if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
var codes []U
codesQuery := tx.Where(fmt.Sprintf("nid = ? AND %s = ?", foreignKeyName), nid, flowID)
if o.IdentityID != nil {
Expand All @@ -85,10 +76,8 @@ func useOneTimeCode[P any, U interface {

if err := sqlcon.HandleError(codesQuery.All(&codes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction and reset the submit count.
return nil
return code.ErrCodeNotFound
}

return err
}

Expand All @@ -107,7 +96,7 @@ func useOneTimeCode[P any, U interface {
}

if target.Validate() != nil {
// Return no error, as that would roll back the transaction
// Return no error, as that would roll back the transaction. We re-validate the code after the transaction.
return nil
}

Expand All @@ -123,3 +112,40 @@ func useOneTimeCode[P any, U interface {

return target, nil
}

func incrementOTPCodeSubmitCount(ctx context.Context, p *Persister, flowID uuid.UUID, flowTableName string) (submitCount int, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.incrementOTPCodeSubmitCount",
trace.WithAttributes(attribute.Stringer("flow_id", flowID), attribute.String("flow_table_name", flowTableName)))
defer otelx.End(span, &err)
defer func() {
span.SetAttributes(attribute.Int("submit_count", submitCount))
}()

nid := p.NetworkID(ctx)

if p.c.Dialect.Name() == "mysql" { // no RETURNING support
//#nosec G201 -- TableName is static
qUpdate := fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName)
//#nosec G201 -- TableName is static
qSelect := fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName)
err = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
if err := tx.RawQuery(qUpdate, flowID, nid).Exec(); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return code.ErrCodeNotFound
}
return sqlcon.HandleError(err)
}
return sqlcon.HandleError(tx.RawQuery(qSelect, flowID, nid).First(&submitCount))
})
} else {
//#nosec G201 -- TableName is static
q := fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ? RETURNING submit_count", flowTableName)
err = p.Connection(ctx).RawQuery(q, flowID, nid).First(&submitCount)
if errors.Is(err, sql.ErrNoRows) {
return 0, code.ErrCodeNotFound
}
err = sqlcon.HandleError(err)
}

return submitCount, err
}
2 changes: 1 addition & 1 deletion persistence/sql/persister_login_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identity
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
defer otelx.End(span, &err)

codeRow, err := useOneTimeCode[code.LoginCode, *code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id", withCheckIdentityID(identityID))
codeRow, err := useOneTimeCode[code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id", withCheckIdentityID(identityID))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_recovery_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, flowID uuid.UUID, userP
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode")
defer otelx.End(span, &err)

codeRow, err := useOneTimeCode[code.RecoveryCode, *code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id")
codeRow, err := useOneTimeCode[code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_registration_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode")
defer otelx.End(span, &err)

codeRow, err := useOneTimeCode[code.RegistrationCode, *code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id")
codeRow, err := useOneTimeCode[code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/persister_verification_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (p *Persister) UseVerificationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode")
defer otelx.End(span, &err)

codeRow, err := useOneTimeCode[code.VerificationCode, *code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id")
codeRow, err := useOneTimeCode[code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id")
if err != nil {
return nil, err
}
Expand Down
26 changes: 23 additions & 3 deletions selfservice/strategy/code/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ package code

import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -113,10 +116,27 @@ func TestPersister(ctx context.Context, p interface {
_, err := p.CreateRecoveryCode(ctx, dto)
require.NoError(t, err)

for i := 1; i <= 5; i++ {
_, err = p.UseRecoveryCode(ctx, f.ID, "i-do-not-exist")
require.Error(t, err)
var tooOften, wrongCode int32
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
_, err := p.UseRecoveryCode(ctx, f.ID, "i-do-not-exist")
if err == nil {
t.Error("should have rejected incorrect code")
return
}
if errors.Is(err, code.ErrCodeSubmittedTooOften) {
atomic.AddInt32(&tooOften, 1)
} else {
atomic.AddInt32(&wrongCode, 1)
}
}()
}
wg.Wait()
require.EqualValues(t, 5, wrongCode, "should reject 5 times with wrong code")
require.EqualValues(t, 45, tooOften, "should reject 45 times with too often")

_, err = p.UseRecoveryCode(ctx, f.ID, "i-do-not-exist")
require.ErrorIs(t, err, code.ErrCodeSubmittedTooOften)
Expand Down

0 comments on commit 299c94e

Please sign in to comment.