diff --git a/net/http/retrier.go b/net/http/retrier.go index 39c4fbb..47da0d0 100644 --- a/net/http/retrier.go +++ b/net/http/retrier.go @@ -19,19 +19,17 @@ import ( // ErrStatusCode can be used to treat/indicate a status code as an error and ability to indicate if it is retryable. type ErrStatusCode struct { - // StatusCode is the HTTP response status code that was encountered. - StatusCode int + // Response contains the full HTTP response associated with the request. + // It is the responsibility of the caller to handle closing the body. + Response *http.Response // IsRetryableStatusCode indicates if the status code is considered retryable. IsRetryableStatusCode bool - - // Headers contains the headers from the HTTP response. - Headers http.Header } // Error returns the error message for the status code. func (e ErrStatusCode) Error() string { - return "status code encountered: " + strconv.Itoa(e.StatusCode) + return "status code encountered: " + strconv.Itoa(e.Response.StatusCode) } // IsRetryable returns if the provided status code is considered retryable. @@ -60,6 +58,7 @@ type Retryer struct { maxMemory bytesext.Bytes mode errorsext.MaxAttemptsMode maxAttempts uint8 + extractStatusCodeBody bool } // NewRetryer returns a new `Retryer` with sane default values. @@ -81,10 +80,11 @@ type Retryer struct { // however every attempt will be made to maintain backwards compatibility or made additive-only if possible. func NewRetryer() Retryer { return Retryer{ - client: http.DefaultClient, - maxMemory: 2 * bytesext.MiB, - mode: errorsext.MaxAttemptsNonRetryableReset, - maxAttempts: 5, + client: http.DefaultClient, + maxMemory: 2 * bytesext.MiB, + mode: errorsext.MaxAttemptsNonRetryableReset, + maxAttempts: 5, + extractStatusCodeBody: true, isRetryableFn: func(ctx context.Context, err error) (isRetryable bool) { _, isRetryable = errorsext.IsRetryableHTTP(err) return @@ -93,7 +93,7 @@ func NewRetryer() Retryer { isEarlyReturnFn: func(_ context.Context, err error) bool { var sce ErrStatusCode if errors.As(err, &sce) { - return IsNonRetryableStatusCode(sce.StatusCode) + return IsNonRetryableStatusCode(sce.Response.StatusCode) } return false }, @@ -109,8 +109,10 @@ func NewRetryer() Retryer { wait := time.Millisecond * 200 var sce ErrStatusCode - if (sce.StatusCode == http.StatusTooManyRequests || sce.StatusCode == http.StatusServiceUnavailable) && errors.As(err, &sce) && sce.Headers != nil { - if ra := HasRetryAfter(sce.Headers); ra.IsSome() { + if errors.As(err, &sce) && (sce.Response.StatusCode == http.StatusTooManyRequests || sce.Response.StatusCode == http.StatusServiceUnavailable) && sce.Response.Header != nil { + defer sce.Response.Body.Close() + + if ra := HasRetryAfter(sce.Response.Header); ra.IsSome() { wait = ra.Unwrap() } } @@ -191,6 +193,13 @@ func (r Retryer) Timeout(timeout time.Duration) Retryer { return r } +// DecodeBadStatusCodeResponseBody sets if the response body will be read, in its entirety up to the configured +// `MaxMemory`, if the status code is not in the expected list and added to the `ErrStatusCode.Body` error. +func (r Retryer) DecodeBadStatusCodeResponseBody(b bool) Retryer { + r.extractStatusCodeBody = b + return r +} + // DoResponse will execute the provided functions code and automatically retry before returning the *http.Response // based on HTTP status code, if defined, and can be used when processing of the response body may not be necessary // or something custom is required. @@ -208,17 +217,22 @@ func (r Retryer) DoResponse(ctx context.Context, fn BuildRequestFn2, expectedRes if req.IsErr() { return Err[*http.Response, error](req.Err()) } + resp, err := r.client.Do(req.Unwrap()) if err != nil { return Err[*http.Response, error](err) } + if len(expectedResponseCodes) > 0 { for _, code := range expectedResponseCodes { if resp.StatusCode == code { goto RETURN } } - return Err[*http.Response, error](ErrStatusCode{StatusCode: resp.StatusCode, IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), Headers: resp.Header}) + return Err[*http.Response, error](ErrStatusCode{ + Response: resp, + IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), + }) } RETURN: @@ -236,15 +250,22 @@ func (r Retryer) Do(ctx context.Context, fn BuildRequestFn2, v any, expectedResp Timeout(r.timeout). IsEarlyReturnFn(r.isEarlyReturnFn). Do(ctx, func(ctx context.Context) Result[typesext.Nothing, error] { + closeBody := true + req := fn(ctx) if req.IsErr() { return Err[typesext.Nothing, error](req.Err()) } + resp, err := r.client.Do(req.Unwrap()) if err != nil { return Err[typesext.Nothing, error](err) } - defer resp.Body.Close() + defer func() { + if closeBody { + _ = resp.Body.Close() + } + }() if len(expectedResponseCodes) > 0 { for _, code := range expectedResponseCodes { @@ -252,7 +273,11 @@ func (r Retryer) Do(ctx context.Context, fn BuildRequestFn2, v any, expectedResp goto DECODE } } - return Err[typesext.Nothing, error](ErrStatusCode{StatusCode: resp.StatusCode, IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), Headers: resp.Header}) + closeBody = false + return Err[typesext.Nothing, error](ErrStatusCode{ + Response: resp, + IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode), + }) } DECODE: