diff --git a/go.mod b/go.mod index e4bbc9eb38..34e2cc6302 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/containerd/stargz-snapshotter/estargz v0.15.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect github.com/cyphar/filepath-securejoin v0.2.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.8.1 // indirect @@ -84,6 +85,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc3 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index e27d700e60..cd68c0b33a 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -156,21 +156,6 @@ func chunkBy[T any](items []T, chunkSize int) [][]T { return append(chunks, items) } -// checkResponseError checks if the response has an error. -func checkResponseError(resp *http.Response) error { - if resp.StatusCode == http.StatusOK { - return nil - } - - respBuf, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read error response from server: %w", err) - } - defer resp.Body.Close() - - return fmt.Errorf("server response error: %s", string(respBuf)) -} - // MakeRequest sends a batched query to osv.dev func MakeRequest(request BatchedQuery) (*BatchedResponse, error) { return MakeRequestWithClient(request, http.DefaultClient) @@ -306,10 +291,9 @@ func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBat return &hydrated, nil } -// makeRetryRequest will return an error on both network errors, and if the response is not 200 +// makeRetryRequest executes HTTP requests with exponential backoff retry logic func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) { - var resp *http.Response - var err error + var lastErr error for i := range maxRetryAttempts { // rand is initialized with a random number (since go1.20), and is also safe to use concurrently @@ -318,17 +302,31 @@ func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, er jitterAmount := (rand.Float64() * float64(jitterMultiplier) * float64(i)) time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond) - resp, err = action() - if err == nil { - // Check the response for HTTP errors - err = checkResponseError(resp) - if err == nil { - break - } + resp, err := action() + if err != nil { + lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err) + continue + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err) + continue + } + + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body) + } + + lastErr = fmt.Errorf("server error: status=%d body=%s", resp.StatusCode, body) } - return resp, err + return nil, fmt.Errorf("max retries exceeded: %w", lastErr) } func MakeDetermineVersionRequest(name string, hashes []DetermineVersionHash) (*DetermineVersionResponse, error) { diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go new file mode 100644 index 0000000000..68d8e51688 --- /dev/null +++ b/pkg/osv/osv_test.go @@ -0,0 +1,107 @@ +package osv + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestMakeRetryRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCodes []int + expectedError string + wantAttempts int + }{ + { + name: "success on first attempt", + statusCodes: []int{http.StatusOK}, + wantAttempts: 1, + }, + { + name: "client error no retry", + statusCodes: []int{http.StatusBadRequest}, + expectedError: "client error: status=400", + wantAttempts: 1, + }, + { + name: "server error then success", + statusCodes: []int{http.StatusInternalServerError, http.StatusOK}, + wantAttempts: 2, + }, + { + name: "max retries on server error", + statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError}, + expectedError: "max retries exceeded", + wantAttempts: 4, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + attempts := 0 + idx := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + status := tt.statusCodes[idx] + if idx < len(tt.statusCodes)-1 { + idx++ + } + + w.WriteHeader(status) + message := fmt.Sprintf("response-%d", attempts) + w.Write([]byte(message)) + })) + defer server.Close() + + client := &http.Client{Timeout: time.Second} + + resp, err := makeRetryRequest(func() (*http.Response, error) { + return client.Get(server.URL) + }) + + if attempts != tt.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts) + } + + if tt.expectedError != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.expectedError) + } + if !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %q", tt.expectedError, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp == nil { + t.Fatal("expected non-nil response") + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + expectedBody := fmt.Sprintf("response-%d", attempts) + if string(body) != expectedBody { + t.Errorf("got body %q, want %q", string(body), expectedBody) + } + }) + } +}