From 9909a38dbe0b66db7a7d3a41ac141f71372cb550 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Wed, 11 Oct 2023 18:57:52 +0200 Subject: [PATCH] feat: add WithMaxHTTPMaxBytes option to fetcher to limit HTTP response body size --- fetcher/fetcher.go | 35 ++++++++++++++++++++++++++++------- fetcher/fetcher_test.go | 19 +++++++++++++++---- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/fetcher/fetcher.go b/fetcher/fetcher.go index b58f0f51..1d6181ed 100644 --- a/fetcher/fetcher.go +++ b/fetcher/fetcher.go @@ -22,22 +22,32 @@ import ( // Fetcher is able to load file contents from http, https, file, and base64 locations. type Fetcher struct { - hc *retryablehttp.Client + hc *retryablehttp.Client + limit int64 } type opts struct { - hc *retryablehttp.Client + hc *retryablehttp.Client + limit int64 } var ErrUnknownScheme = stderrors.New("unknown scheme") // WithClient sets the http.Client the fetcher uses. -func WithClient(hc *retryablehttp.Client) func(*opts) { +func WithClient(hc *retryablehttp.Client) Modifier { return func(o *opts) { o.hc = hc } } +// WithMaxHTTPMaxBytes reads at most limit bytes from the HTTP response body, +// returning bytes.ErrTooLarge if the limit would be exceeded. +func WithMaxHTTPMaxBytes(limit int64) Modifier { + return func(o *opts) { + o.limit = limit + } +} + func newOpts() *opts { return &opts{ hc: httpx.NewResilientClient(), @@ -52,7 +62,7 @@ func NewFetcher(opts ...Modifier) *Fetcher { for _, f := range opts { f(o) } - return &Fetcher{hc: o.hc} + return &Fetcher{hc: o.hc, limit: o.limit} } // Fetch fetches the file contents from the source. 
@@ -94,7 +104,18 @@ func (f *Fetcher) fetchRemote(ctx context.Context, source string) (*bytes.Buffer return nil, errors.Errorf("expected http response status code 200 but got %d when fetching: %s", res.StatusCode, source) } - return f.decode(res.Body) + if f.limit > 0 { + var buf bytes.Buffer + n, err := io.Copy(&buf, io.LimitReader(res.Body, f.limit+1)) + if n > f.limit { + return nil, bytes.ErrTooLarge + } + if err != nil { + return nil, err + } + return &buf, nil + } + return f.toBuffer(res.Body) } func (f *Fetcher) fetchFile(source string) (*bytes.Buffer, error) { @@ -106,10 +127,10 @@ func (f *Fetcher) fetchFile(source string) (*bytes.Buffer, error) { _ = fp.Close() }() - return f.decode(fp) + return f.toBuffer(fp) } -func (f *Fetcher) decode(r io.Reader) (*bytes.Buffer, error) { +func (f *Fetcher) toBuffer(r io.Reader) (*bytes.Buffer, error) { var b bytes.Buffer if _, err := io.Copy(&b, r); err != nil { return nil, err diff --git a/fetcher/fetcher_test.go b/fetcher/fetcher_test.go index 28c04154..32f24fab 100644 --- a/fetcher/fetcher_test.go +++ b/fetcher/fetcher_test.go @@ -4,6 +4,7 @@ package fetcher import ( + "bytes" "context" "encoding/base64" "fmt" @@ -16,7 +17,6 @@ import ( "github.com/gobuffalo/httptest" "github.com/julienschmidt/httprouter" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -67,9 +67,8 @@ func TestFetcher(t *testing.T) { t.Run("case=returns proper error on unknown scheme", func(t *testing.T) { _, err := NewFetcher().Fetch("unknown-scheme://foo") - require.NotNil(t, err) - assert.True(t, errors.Is(err, ErrUnknownScheme)) + assert.ErrorIs(t, err, ErrUnknownScheme) assert.Contains(t, err.Error(), "unknown-scheme") }) @@ -77,8 +76,20 @@ func TestFetcher(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() _, err := NewFetcher().FetchContext(ctx, "https://config.invalid") - require.NotNil(t, err) assert.ErrorIs(t, err, 
context.DeadlineExceeded) }) + + t.Run("case=with-limit", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(bytes.Repeat([]byte("test"), 1000)) + })) + t.Cleanup(srv.Close) + + _, err := NewFetcher(WithMaxHTTPMaxBytes(3999)).Fetch(srv.URL) + assert.ErrorIs(t, err, bytes.ErrTooLarge) + + _, err = NewFetcher(WithMaxHTTPMaxBytes(4000)).Fetch(srv.URL) + assert.NoError(t, err) + }) }