Skip to content

Commit

Permalink
feat: do not retry on 429 responses from external verification servic…
Browse files Browse the repository at this point in the history
…e (PS-405)
  • Loading branch information
maoanran committed Sep 10, 2024
1 parent 602d716 commit f13a642
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 2 deletions.
21 changes: 20 additions & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,26 @@ func (m *RegistryDefault) HTTPClient(_ context.Context, opts ...httpx.ResilientO
httpx.ResilientClientAllowInternalIPRequestsTo(m.Config().ClientHTTPPrivateIPExceptionURLs(contextx.RootContext)...),
)
}
return httpx.NewResilientClient(opts...)
client := httpx.NewResilientClient(opts...)
client.CheckRetry = NoRetryOnRateLimitPolicy
return client
}

func NoRetryOnRateLimitPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
// If there's no response (network error), retry
if resp == nil {
return true, nil
}

// Do not retry on 4xx errors, except 408 (Request Timeout)
if resp.StatusCode == http.StatusRequestTimeout {
return true, nil
} else if resp.StatusCode >= 400 && resp.StatusCode < 500 {
return false, nil
}

// Default retry policy will retry on 5xx errors or network errors
return retryablehttp.DefaultRetryPolicy(ctx, resp, err)
}

func (m *RegistryDefault) WithContextualizer(ctxer contextx.Contextualizer) Registry {
Expand Down
28 changes: 28 additions & 0 deletions driver/registry_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"testing"

Expand Down Expand Up @@ -975,3 +976,30 @@ func TestGetActiveVerificationStrategy(t *testing.T) {
}
})
}

func TestNoRetryOnRateLimitPolicy(t *testing.T) {
t.Run("case=does not retry on 4xx errors except 408", func(t *testing.T) {
resp := &http.Response{StatusCode: 429}
retry, err := driver.NoRetryOnRateLimitPolicy(context.Background(), resp, nil)
assert.False(t, retry)
assert.NoError(t, err)

resp.StatusCode = 408
retry, err = driver.NoRetryOnRateLimitPolicy(context.Background(), resp, nil)
assert.True(t, retry)
assert.NoError(t, err)
})

t.Run("case=retries on 5xx errors", func(t *testing.T) {
resp := &http.Response{StatusCode: 500}
retry, err := driver.NoRetryOnRateLimitPolicy(context.Background(), resp, nil)
assert.True(t, retry)
assert.NoError(t, err)
})

t.Run("case=retries on network errors", func(t *testing.T) {
retry, err := driver.NoRetryOnRateLimitPolicy(context.Background(), nil, assert.AnError)
assert.True(t, retry)
assert.NoError(t, err)
})
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae // indirect
github.com/nyaruka/phonenumbers v1.3.6 // indirect
github.com/nyaruka/phonenumbers v1.3.6
github.com/ogier/pflag v0.0.1 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
Expand Down Expand Up @@ -338,5 +338,6 @@ require (
github.com/jackc/puddle/v2 v2.1.2 // indirect
github.com/lestrrat-go/httprc v1.0.4 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
go.uber.org/atomic v1.10.0 // indirect
)
210 changes: 210 additions & 0 deletions selfservice/strategy/code/code_external_verifier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package code_test

import (
"bytes"
"context"
"github.com/hashicorp/go-retryablehttp"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/x/configx"
"github.com/ory/x/httpx"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/logrusx"
"github.com/ory/x/otelx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"io"
"net/http"
"os"
"testing"
)

type MockDependencies struct {
mock.Mock
t *testing.T
mockHTTPClient *retryablehttp.Client
}

func NewMockDependencies(t *testing.T, mockHTTPClient *retryablehttp.Client) *MockDependencies {
return &MockDependencies{t: t, mockHTTPClient: mockHTTPClient}
}

func (m *MockDependencies) Config() *config.Config {
return config.MustNew(
m.t,
logrusx.New("kratos", "test"),
os.Stderr,
configx.WithConfigFiles("../../../test/e2e/profiles/code/.kratos.yml"),
configx.SkipValidation(),
)
}

func (m *MockDependencies) Logger() *logrusx.Logger {
return logrusx.New("kratos", "test")
}

func (m *MockDependencies) Audit() *logrusx.Logger {
return logrusx.New("kratos", "test")
}

func (m *MockDependencies) Tracer(ctx context.Context) *otelx.Tracer {
return otelx.NewNoop(nil, nil)
}

func (m *MockDependencies) HTTPClient(ctx context.Context, options ...httpx.ResilientOptions) *retryablehttp.Client {
return m.mockHTTPClient
}

func (m *MockDependencies) JsonnetVM(ctx context.Context) (jsonnetsecure.VM, error) {
return jsonnetsecure.NewTestProvider(m.t).JsonnetVM(ctx)
}

type MockSMSTemplate struct {
mock.Mock
marshalJson string
}

func (m *MockSMSTemplate) MarshalJSON() ([]byte, error) {
return []byte(m.marshalJson), nil
}

func (m *MockSMSTemplate) SMSBody(ctx context.Context) (string, error) {
return "sms body", nil
}

func (m *MockSMSTemplate) TemplateType() template.TemplateType {
return "sms"
}

func (m *MockSMSTemplate) PhoneNumber() (string, error) {
return "1234567890", nil
}

func TestVerificationStart(t *testing.T) {
ctx := context.Background()

mockHTTPClient := new(retryablehttp.Client)
mockHTTPClient.CheckRetry = driver.NoRetryOnRateLimitPolicy
mockDeps := NewMockDependencies(t, mockHTTPClient)

mockSMSTemplate := new(MockSMSTemplate)
mockSMSTemplate.marshalJson = `{"To":"12345678"}`

externalVerifier := code.NewExternalVerifier(mockDeps)

t.Run("method=VerificationStart", func(t *testing.T) {
t.Run("case=returns no error for 2xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte("OK"))),
},
},
}

err := externalVerifier.VerificationStart(ctx, mockSMSTemplate)
require.NoError(t, err)
})

t.Run("case=returns error for 4xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 400,
Body: io.NopCloser(bytes.NewReader([]byte("Bad Request"))),
},
},
}

err := externalVerifier.VerificationStart(ctx, mockSMSTemplate)
require.Error(t, err)
assert.Contains(t, err.Error(), "upstream server replied with status code 400 and body Bad Request")
})

t.Run("case=returns error for 5xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 500,
Body: io.NopCloser(bytes.NewReader([]byte("Internal Server Error"))),
},
},
}

err := externalVerifier.VerificationStart(ctx, mockSMSTemplate)
require.Error(t, err)
assert.Contains(t, err.Error(), "giving up after 1 attempt(s)")
})
})
}

func TestVerificationCheck(t *testing.T) {
ctx := context.Background()

mockHTTPClient := new(retryablehttp.Client)
mockHTTPClient.CheckRetry = driver.NoRetryOnRateLimitPolicy
mockDeps := NewMockDependencies(t, mockHTTPClient)

mockSMSTemplate := new(MockSMSTemplate)
mockSMSTemplate.marshalJson = `{"To":"12345678", "Code":"1234"}`

externalVerifier := code.NewExternalVerifier(mockDeps)

t.Run("method=VerificationCheck", func(t *testing.T) {
t.Run("case=returns no error for 2xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte("OK"))),
},
},
}

err := externalVerifier.VerificationCheck(ctx, mockSMSTemplate)
require.NoError(t, err)
})

t.Run("case=returns error for 4xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 400,
Body: io.NopCloser(bytes.NewReader([]byte("Bad Request"))),
},
},
}

err := externalVerifier.VerificationCheck(ctx, mockSMSTemplate)
require.Error(t, err)
assert.Contains(t, err.Error(), "The requested resource could not be found")
})

t.Run("case=returns error for 5xx response", func(t *testing.T) {
mockHTTPClient.HTTPClient = &http.Client{
Transport: &mockTransport{
response: &http.Response{
StatusCode: 500,
Body: io.NopCloser(bytes.NewReader([]byte("Internal Server Error"))),
},
},
}

err := externalVerifier.VerificationCheck(ctx, mockSMSTemplate)
require.Error(t, err)
assert.Contains(t, err.Error(), "giving up after 1 attempt(s)")
})
})
}

type mockTransport struct {
response *http.Response
}

func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return m.response, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
function(ctx) {
to: ctx.To,
code: if std.objectHas(ctx, "VerificationCode") then ctx.VerificationCode else ctx.Code,
}
22 changes: 22 additions & 0 deletions selfservice/strategy/code/stub/code.verification.start.jsonnet
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
local getFlowId(ctx) =
if std.objectHas(ctx, "VerificationURL") then
local start = std.findSubstr("flow=", ctx.VerificationURL);
std.substr(ctx.VerificationURL, start[0]+5, 36)
else
"error_getting_flow_id";

local getOperator(ctx) =
if std.objectHas(ctx, "TransientPayload") && std.objectHas(ctx.TransientPayload, "application") then ctx.TransientPayload.application
else "monta";

function(ctx) {
to: ctx.To,
[if std.objectHas(ctx, "TransientPayload") && std.objectHas(ctx.TransientPayload, "language") then 'language']: ctx.TransientPayload.language,
[if std.objectHas(ctx, "TransientPayload") && std.objectHas(ctx.TransientPayload, "application") then 'application']: ctx.TransientPayload.application,
template: if std.objectHas(ctx, "VerificationCode") then 'sms_localisation.verification_code_with_link' else 'sms_localisation.account_activation_passcode',
templateParameters: {
"host": "portal.monta.app",
"flow_id": getFlowId(ctx),
"operator": getOperator(ctx),
},
}
22 changes: 22 additions & 0 deletions test/e2e/profiles/code/.kratos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ selfservice:
enabled: true
config:
lifespan: 1h
external_sms_verify:
enabled: true
verification_start_request:
url: http://notification:8080/api/verifications
method: POST
body: file://stub/code.verification.start.jsonnet
auth:
type: api_key
config:
name: Authorization
value: ...
in: header
verification_check_request:
url: http://notification:8080/api/verifications/check
method: POST
body: file://stub/code.verification.check.jsonnet
auth:
type: api_key
config:
name: Authorization
value: ...
in: header

identity:
schemas:
Expand Down

0 comments on commit f13a642

Please sign in to comment.