Skip to content

Commit

Permalink
Add retryable http client and use in FGA module (#356)
Browse files Browse the repository at this point in the history
* Create retryable http client

This retryable client lives in a new `retryablehttp` package and implements a `Do` method that includes built-in retry logic. It implements exponential backoff and retries on request errors as well as 50X response codes.

* Use retryable http client in FGA module

* Check request context's Done() channel to handle cancellations

* Update minimum delay duration to 250ms

The minimum delay time should be 250ms to account for the initial delay (500ms) with max negative jitter (250ms).
  • Loading branch information
stanleyphu authored Aug 12, 2024
1 parent 217f774 commit af890c2
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 28 deletions.
5 changes: 3 additions & 2 deletions pkg/fga/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/google/go-querystring/query"
"github.com/workos/workos-go/v4/internal/workos"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
"github.com/workos/workos-go/v4/pkg/workos_errors"
)

Expand Down Expand Up @@ -42,7 +43,7 @@ type Client struct {

// The http.Client that is used to get FGA records from WorkOS.
// Defaults to http.Client.
HTTPClient *http.Client
HTTPClient *retryablehttp.HttpClient

// The endpoint to WorkOS API. Defaults to https://api.workos.com.
Endpoint string
Expand All @@ -55,7 +56,7 @@ type Client struct {

func (c *Client) init() {
if c.HTTPClient == nil {
c.HTTPClient = &http.Client{Timeout: 10 * time.Second}
c.HTTPClient = &retryablehttp.HttpClient{Client: http.Client{Timeout: 10 * time.Second}}
}

if c.Endpoint == "" {
Expand Down
27 changes: 14 additions & 13 deletions pkg/fga/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/stretchr/testify/require"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
)

func TestGetResource(t *testing.T) {
Expand Down Expand Up @@ -49,7 +50,7 @@ func TestGetResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.GetResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -130,7 +131,7 @@ func TestListResources(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resources, err := client.ListResources(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestListResourceTypes(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resourceTypes, err := client.ListResourceTypes(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -367,7 +368,7 @@ func TestBatchUpdateResourceTypes(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resourceTypes, err := client.BatchUpdateResourceTypes(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -488,7 +489,7 @@ func TestCreateResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.CreateResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -578,7 +579,7 @@ func TestUpdateResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.UpdateResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -665,7 +666,7 @@ func TestDeleteResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

err := client.DeleteResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -765,7 +766,7 @@ func TestListWarrants(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resources, err := client.ListWarrants(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -907,7 +908,7 @@ func TestWriteWarrant(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

warrantResponse, err := client.WriteWarrant(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -973,7 +974,7 @@ func TestBatchWriteWarrants(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

warrantResponse, err := client.BatchWriteWarrants(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1057,7 +1058,7 @@ func TestCheck(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

checkResult, err := client.Check(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1157,7 +1158,7 @@ func TestCheckBatch(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

checkResults, err := client.CheckBatch(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1256,7 +1257,7 @@ func TestQuery(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

queryResults, err := client.Query(context.Background(), test.options)
if test.err {
Expand Down
27 changes: 14 additions & 13 deletions pkg/fga/fga_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (

"github.com/stretchr/testify/require"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
)

func TestFGAGetResource(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(getResourceTestHandler))
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -38,7 +39,7 @@ func TestFGAListResources(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestFGACreateResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -95,7 +96,7 @@ func TestFGAUpdateResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -124,7 +125,7 @@ func TestFGADeleteResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -142,7 +143,7 @@ func TestFGAListResourceTypes(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -184,7 +185,7 @@ func TestFGABatchUpdateResourceTypes(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -235,7 +236,7 @@ func TestFGAListWarrants(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -279,7 +280,7 @@ func TestFGAWriteWarrant(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestFGABatchWriteWarrants(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -347,7 +348,7 @@ func TestFGACheck(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -375,7 +376,7 @@ func TestFGACheckBatch(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -405,7 +406,7 @@ func TestFGAQuery(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down
110 changes: 110 additions & 0 deletions pkg/retryablehttp/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package retryablehttp

import (
"io"
"math"
"math/rand"
"net/http"
"time"
)

const MaxRetryAttempts = 3
const MinimumDelay = 500
const MinimumDelayDuration = 250 * time.Millisecond
const MaximumDelayDuration = 5 * time.Second
const RandomizationFactor = 0.5
const BackoffMultiplier = 1.5

type HttpClient struct {
http.Client
}

func (client *HttpClient) Do(req *http.Request) (*http.Response, error) {
var res *http.Response
var err error
for retry := 0; ; {
// Reset the request body for each retry
if req.Body != nil {
body, err := req.GetBody()
if err != nil {
client.CloseIdleConnections()
return res, err
}
if c, ok := body.(io.ReadCloser); ok {
req.Body = c
} else {
req.Body = io.NopCloser(body)
}
}

res, err = client.Client.Do(req)
if err != nil {
break
}

shouldRetry := client.shouldRetry(req, res, err, retry)

if !shouldRetry {
break
}

sleepTime := client.sleepTime(retry)
retry++

timer := time.NewTimer(sleepTime)
select {
case <-req.Context().Done():
timer.Stop()
client.CloseIdleConnections()
return nil, req.Context().Err()
case <-timer.C:
}
}

if err != nil {
return nil, err
}

return res, nil
}

func (client *HttpClient) shouldRetry(req *http.Request, resp *http.Response, err error, retryAttempt int) bool {
if retryAttempt >= MaxRetryAttempts {
return false
}

if err != nil {
return true
}

if resp.StatusCode >= http.StatusInternalServerError {
return true
}

return false
}

// Calculates backoff time using exponential backoff with 50% jitter.
//
// Backoff times
// Retry Attempt | Sleep Time
// 1 | 500ms +/- 250ms
// 2 | 750ms +/- 375ms
// 3 | 1.125s +/- 562ms
func (client *HttpClient) sleepTime(retryAttempt int) time.Duration {
sleepTime := time.Duration(MinimumDelay*int64(math.Pow(BackoffMultiplier, float64(retryAttempt)))) * time.Millisecond

delta := RandomizationFactor * float64(sleepTime)
minSleep := float64(sleepTime) - delta
maxSleep := float64(sleepTime) + delta

sleepTime = time.Duration(minSleep + (rand.Float64() * (maxSleep - minSleep + 1)))

if sleepTime < MinimumDelayDuration {
sleepTime = MinimumDelayDuration
} else if sleepTime > MaximumDelayDuration {
sleepTime = MaximumDelayDuration
}

return sleepTime
}
Loading

0 comments on commit af890c2

Please sign in to comment.