Skip to content

Commit

Permalink
change ErrStatusCode
Browse files Browse the repository at this point in the history
  • Loading branch information
deankarn committed Mar 25, 2024
1 parent 02faf3a commit 533fa00
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions net/http/retrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,17 @@ import (

// ErrStatusCode can be used to treat/indicate a status code as an error and ability to indicate if it is retryable.
type ErrStatusCode struct {
// StatusCode is the HTTP response status code that was encountered.
StatusCode int
// Response contains the full HTTP response associated with the request.
// It is the responsibility of the caller to handle closing the body.
Response *http.Response

// IsRetryableStatusCode indicates if the status code is considered retryable.
IsRetryableStatusCode bool

// Headers contains the headers from the HTTP response.
Headers http.Header
}

// Error returns the error message for the status code.
func (e ErrStatusCode) Error() string {
return "status code encountered: " + strconv.Itoa(e.StatusCode)
return "status code encountered: " + strconv.Itoa(e.Response.StatusCode)
}

// IsRetryable returns if the provided status code is considered retryable.
Expand Down Expand Up @@ -60,6 +58,7 @@ type Retryer struct {
maxMemory bytesext.Bytes
mode errorsext.MaxAttemptsMode
maxAttempts uint8
extractStatusCodeBody bool
}

// NewRetryer returns a new `Retryer` with sane default values.
Expand All @@ -81,10 +80,11 @@ type Retryer struct {
// however every attempt will be made to maintain backwards compatibility or made additive-only if possible.
func NewRetryer() Retryer {
return Retryer{
client: http.DefaultClient,
maxMemory: 2 * bytesext.MiB,
mode: errorsext.MaxAttemptsNonRetryableReset,
maxAttempts: 5,
client: http.DefaultClient,
maxMemory: 2 * bytesext.MiB,
mode: errorsext.MaxAttemptsNonRetryableReset,
maxAttempts: 5,
extractStatusCodeBody: true,
isRetryableFn: func(ctx context.Context, err error) (isRetryable bool) {
_, isRetryable = errorsext.IsRetryableHTTP(err)
return
Expand All @@ -93,7 +93,7 @@ func NewRetryer() Retryer {
isEarlyReturnFn: func(_ context.Context, err error) bool {
var sce ErrStatusCode
if errors.As(err, &sce) {
return IsNonRetryableStatusCode(sce.StatusCode)
return IsNonRetryableStatusCode(sce.Response.StatusCode)
}
return false
},
Expand All @@ -109,8 +109,10 @@ func NewRetryer() Retryer {
wait := time.Millisecond * 200

var sce ErrStatusCode
if (sce.StatusCode == http.StatusTooManyRequests || sce.StatusCode == http.StatusServiceUnavailable) && errors.As(err, &sce) && sce.Headers != nil {
if ra := HasRetryAfter(sce.Headers); ra.IsSome() {
if errors.As(err, &sce) && (sce.Response.StatusCode == http.StatusTooManyRequests || sce.Response.StatusCode == http.StatusServiceUnavailable) && sce.Response.Header != nil {
defer sce.Response.Body.Close()

if ra := HasRetryAfter(sce.Response.Header); ra.IsSome() {
wait = ra.Unwrap()
}
}
Expand Down Expand Up @@ -191,6 +193,13 @@ func (r Retryer) Timeout(timeout time.Duration) Retryer {
return r
}

// DecodeBadStatusCodeResponseBody sets if the response body will be read, in its entirety up to the configured
// `MaxMemory`, if the status code is not in the expected list and added to the `ErrStatusCode.Body` error.
func (r Retryer) DecodeBadStatusCodeResponseBody(b bool) Retryer {
r.extractStatusCodeBody = b
return r
}

// DoResponse will execute the provided functions code and automatically retry before returning the *http.Response
// based on HTTP status code, if defined, and can be used when processing of the response body may not be necessary
// or something custom is required.
Expand All @@ -208,17 +217,22 @@ func (r Retryer) DoResponse(ctx context.Context, fn BuildRequestFn2, expectedRes
if req.IsErr() {
return Err[*http.Response, error](req.Err())
}

resp, err := r.client.Do(req.Unwrap())
if err != nil {
return Err[*http.Response, error](err)
}

if len(expectedResponseCodes) > 0 {
for _, code := range expectedResponseCodes {
if resp.StatusCode == code {
goto RETURN
}
}
return Err[*http.Response, error](ErrStatusCode{StatusCode: resp.StatusCode, IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), Headers: resp.Header})
return Err[*http.Response, error](ErrStatusCode{
Response: resp,
IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode),
})
}

RETURN:
Expand All @@ -236,23 +250,34 @@ func (r Retryer) Do(ctx context.Context, fn BuildRequestFn2, v any, expectedResp
Timeout(r.timeout).
IsEarlyReturnFn(r.isEarlyReturnFn).
Do(ctx, func(ctx context.Context) Result[typesext.Nothing, error] {
closeBody := true

req := fn(ctx)
if req.IsErr() {
return Err[typesext.Nothing, error](req.Err())
}

resp, err := r.client.Do(req.Unwrap())
if err != nil {
return Err[typesext.Nothing, error](err)
}
defer resp.Body.Close()
defer func() {
if closeBody {
_ = resp.Body.Close()
}
}()

if len(expectedResponseCodes) > 0 {
for _, code := range expectedResponseCodes {
if resp.StatusCode == code {
goto DECODE
}
}
return Err[typesext.Nothing, error](ErrStatusCode{StatusCode: resp.StatusCode, IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), Headers: resp.Header})
closeBody = false
return Err[typesext.Nothing, error](ErrStatusCode{
Response: resp,
IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode),
})
}

DECODE:
Expand Down

0 comments on commit 533fa00

Please sign in to comment.