From af890c23c1b1b3338e29f46b822094330845398e Mon Sep 17 00:00:00 2001 From: Stanley Phu Date: Mon, 12 Aug 2024 15:26:12 -0700 Subject: [PATCH] Add retryable http client and use in FGA module (#356) * Create retryable http client This retryable client lives in a new `retryablehttp` package and implements a `Do` method that includes built-in retry logic. It implements exponential backoff and retries on request errors as well as 50X response codes. * Use retryable http client in FGA module * Check request context's Done() channel to handle cancellations * Update minimum delay duration to 250ms The minimum delay time should be 250ms to account for the initial delay (500ms) with max negative jitter (250ms). --- pkg/fga/client.go | 5 +- pkg/fga/client_test.go | 27 +++---- pkg/fga/fga_test.go | 27 +++---- pkg/retryablehttp/client.go | 110 +++++++++++++++++++++++++++++ pkg/retryablehttp/client_test.go | 116 +++++++++++++++++++++++++++++++ 5 files changed, 257 insertions(+), 28 deletions(-) create mode 100644 pkg/retryablehttp/client.go create mode 100644 pkg/retryablehttp/client_test.go diff --git a/pkg/fga/client.go b/pkg/fga/client.go index 8fe92442..43b29f85 100644 --- a/pkg/fga/client.go +++ b/pkg/fga/client.go @@ -13,6 +13,7 @@ import ( "github.com/google/go-querystring/query" "github.com/workos/workos-go/v4/internal/workos" "github.com/workos/workos-go/v4/pkg/common" + "github.com/workos/workos-go/v4/pkg/retryablehttp" "github.com/workos/workos-go/v4/pkg/workos_errors" ) @@ -42,7 +43,7 @@ type Client struct { // The http.Client that is used to get FGA records from WorkOS. // Defaults to http.Client. - HTTPClient *http.Client + HTTPClient *retryablehttp.HttpClient // The endpoint to WorkOS API. Defaults to https://api.workos.com. Endpoint string @@ -55,7 +56,7 @@ type Client struct { func (c *Client) init() { if c.HTTPClient == nil { - c.HTTPClient = &http.Client{Timeout: 10 * time.Second} + c.HTTPClient = &retryablehttp.HttpClient{Client: http.Client{Timeout: 10 * time.Second}} } if c.Endpoint == "" { diff --git a/pkg/fga/client_test.go b/pkg/fga/client_test.go index c3713aef..9f81f285 100644 --- a/pkg/fga/client_test.go +++ b/pkg/fga/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/workos/workos-go/v4/pkg/common" + "github.com/workos/workos-go/v4/pkg/retryablehttp" ) func TestGetResource(t *testing.T) { @@ -49,7 +50,7 @@ func TestGetResource(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resource, err := client.GetResource(context.Background(), test.options) if test.err { @@ -130,7 +131,7 @@ func TestListResources(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resources, err := client.ListResources(context.Background(), test.options) if test.err { @@ -240,7 +241,7 @@ func TestListResourceTypes(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resourceTypes, err := client.ListResourceTypes(context.Background(), test.options) if test.err { @@ -367,7 +368,7 @@ func TestBatchUpdateResourceTypes(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resourceTypes, err := client.BatchUpdateResourceTypes(context.Background(), test.options) if test.err { @@ -488,7 +489,7 @@ func TestCreateResource(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resource, err := client.CreateResource(context.Background(), test.options) if test.err { @@ -578,7 +579,7 @@ func TestUpdateResource(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resource, err := client.UpdateResource(context.Background(), test.options) if test.err { @@ -665,7 +666,7 @@ func TestDeleteResource(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} err := client.DeleteResource(context.Background(), test.options) if test.err { @@ -765,7 +766,7 @@ func TestListWarrants(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} resources, err := client.ListWarrants(context.Background(), test.options) if test.err { @@ -907,7 +908,7 @@ func TestWriteWarrant(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} warrantResponse, err := client.WriteWarrant(context.Background(), test.options) if test.err { @@ -973,7 +974,7 @@ func TestBatchWriteWarrants(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} warrantResponse, err := client.BatchWriteWarrants(context.Background(), test.options) if test.err { @@ -1057,7 +1058,7 @@ func TestCheck(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} checkResult, err := client.Check(context.Background(), test.options) if test.err { @@ -1157,7 +1158,7 @@ func TestCheckBatch(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} checkResults, err := client.CheckBatch(context.Background(), test.options) if test.err { @@ -1256,7 +1257,7 @@ func TestQuery(t *testing.T) { client := test.client client.Endpoint = server.URL - client.HTTPClient = server.Client() + client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()} queryResults, err := client.Query(context.Background(), test.options) if test.err { diff --git a/pkg/fga/fga_test.go b/pkg/fga/fga_test.go index 3ef0ef6b..f2d9929b 100644 --- a/pkg/fga/fga_test.go +++ b/pkg/fga/fga_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/workos/workos-go/v4/pkg/common" + "github.com/workos/workos-go/v4/pkg/retryablehttp" ) func TestFGAGetResource(t *testing.T) { @@ -15,7 +16,7 @@ func TestFGAGetResource(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -38,7 +39,7 @@ func TestFGAListResources(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -72,7 +73,7 @@ func TestFGACreateResource(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -95,7 +96,7 @@ func TestFGAUpdateResource(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -124,7 +125,7 @@ func TestFGADeleteResource(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -142,7 +143,7 @@ func TestFGAListResourceTypes(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -184,7 +185,7 @@ func TestFGABatchUpdateResourceTypes(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -235,7 +236,7 @@ func TestFGAListWarrants(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -279,7 +280,7 @@ func TestFGAWriteWarrant(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -307,7 +308,7 @@ func TestFGABatchWriteWarrants(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -347,7 +348,7 @@ func TestFGACheck(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -375,7 +376,7 @@ func TestFGACheckBatch(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") @@ -405,7 +406,7 @@ func TestFGAQuery(t *testing.T) { defer server.Close() DefaultClient = &Client{ - HTTPClient: server.Client(), + HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()}, Endpoint: server.URL, } SetAPIKey("test") diff --git a/pkg/retryablehttp/client.go b/pkg/retryablehttp/client.go new file mode 100644 index 00000000..2b4f72dc --- /dev/null +++ b/pkg/retryablehttp/client.go @@ -0,0 +1,110 @@ +package retryablehttp + +import ( + "io" + "math" + "math/rand" + "net/http" + "time" +) + +const MaxRetryAttempts = 3 +const MinimumDelay = 500 +const MinimumDelayDuration = 250 * time.Millisecond +const MaximumDelayDuration = 5 * time.Second +const RandomizationFactor = 0.5 +const BackoffMultiplier = 1.5 + +type HttpClient struct { + http.Client +} + +func (client *HttpClient) Do(req *http.Request) (*http.Response, error) { + var res *http.Response + var err error + for retry := 0; ; { + // Reset the request body for each retry + if req.Body != nil { + body, err := req.GetBody() + if err != nil { + client.CloseIdleConnections() + return res, err + } + if c, ok := body.(io.ReadCloser); ok { + req.Body = c + } else { + req.Body = io.NopCloser(body) + } + } + + res, err = client.Client.Do(req) + if err != nil { + break + } + + shouldRetry := client.shouldRetry(req, res, err, retry) + + if !shouldRetry { + break + } + + sleepTime := client.sleepTime(retry) + retry++ + + timer := time.NewTimer(sleepTime) + select { + case <-req.Context().Done(): + timer.Stop() + client.CloseIdleConnections() + return nil, req.Context().Err() + case <-timer.C: + } + } + + if err != nil { + return nil, err + } + + return res, nil +} + +func (client *HttpClient) shouldRetry(req *http.Request, resp *http.Response, err error, retryAttempt int) bool { + if retryAttempt >= MaxRetryAttempts { + return false + } + + if err != nil { + return true + } + + if resp.StatusCode >= http.StatusInternalServerError { + return true + } + + return false +} + +// Calculates backoff time using exponential backoff with 50% jitter. +// +// Backoff times +// Retry Attempt | Sleep Time +// 1 | 500ms +/- 250ms +// 2 | 750ms +/- 375ms +// 3 | 1.125s +/- 562ms +func (client *HttpClient) sleepTime(retryAttempt int) time.Duration { + sleepTime := time.Duration(MinimumDelay*int64(math.Pow(BackoffMultiplier, float64(retryAttempt)))) * time.Millisecond + + delta := RandomizationFactor * float64(sleepTime) + minSleep := float64(sleepTime) - delta + maxSleep := float64(sleepTime) + delta + + sleepTime = time.Duration(minSleep + (rand.Float64() * (maxSleep - minSleep + 1))) + + if sleepTime < MinimumDelayDuration { + sleepTime = MinimumDelayDuration + } else if sleepTime > MaximumDelayDuration { + sleepTime = MaximumDelayDuration + } + + return sleepTime +} diff --git a/pkg/retryablehttp/client_test.go b/pkg/retryablehttp/client_test.go new file mode 100644 index 00000000..725f3a59 --- /dev/null +++ b/pkg/retryablehttp/client_test.go @@ -0,0 +1,116 @@ +package retryablehttp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +type testServerResponse struct { + http.Response + Message string `json:"message"` +} + +func TestDo(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := testServerResponse{ + Message: "Success", + } + + responseBody, err := json.Marshal(response) + require.NoError(t, err) + _, err = w.Write(responseBody) + require.NoError(t, err) + })) + defer testServer.Close() + + client := HttpClient{} + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + + var resBody testServerResponse + dec := json.NewDecoder(resp.Body) + err = dec.Decode(&resBody) + require.NoError(t, err) + + require.Equal(t, "Success", resBody.Message) +} + +func TestDo_Retry(t *testing.T) { + requests := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch requests { + case 0: + w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte("Internal Server Error - retry request")) + require.NoError(t, err) + case 1: + response := testServerResponse{ + Message: "Success", + } + + responseBody, err := json.Marshal(response) + require.NoError(t, err) + _, err = w.Write(responseBody) + require.NoError(t, err) + default: + require.Fail(t, "unexpected number of requests") + } + + requests++ + })) + defer testServer.Close() + + client := HttpClient{} + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + + var resBody testServerResponse + dec := json.NewDecoder(resp.Body) + err = dec.Decode(&resBody) + require.NoError(t, err) + + require.Equal(t, "Success", resBody.Message) + require.Equal(t, 2, requests) +} + +func TestShouldRetry(t *testing.T) { + client := HttpClient{} + + t.Run("Max retry attempts reached", func(t *testing.T) { + shouldRetry := client.shouldRetry(&http.Request{Method: http.MethodGet}, &http.Response{StatusCode: http.StatusInternalServerError}, nil, MaxRetryAttempts) + require.False(t, shouldRetry) + }) + + t.Run("Request context error", func(t *testing.T) { + ctxWithCancel, cancel := context.WithCancel(context.Background()) + cancel() + req, err := http.NewRequestWithContext(ctxWithCancel, http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + shouldRetry := client.shouldRetry(req, &http.Response{StatusCode: http.StatusOK}, nil, 0) + require.False(t, shouldRetry) + }) + + t.Run("Retry on request errors", func(t *testing.T) { + shouldRetry := client.shouldRetry(&http.Request{Method: http.MethodGet}, nil, http.ErrHandlerTimeout, 0) + require.True(t, shouldRetry) + }) + + t.Run("Retry on 50X response codes", func(t *testing.T) { + shouldRetry := client.shouldRetry(&http.Request{Method: http.MethodGet}, &http.Response{StatusCode: http.StatusInternalServerError}, nil, 0) + require.True(t, shouldRetry) + + shouldRetry = client.shouldRetry(&http.Request{Method: http.MethodGet}, &http.Response{StatusCode: http.StatusBadGateway}, nil, 0) + require.True(t, shouldRetry) + }) +}