diff --git a/internal/elasticsearch/client_test.go b/internal/elasticsearch/client_test.go index 4126c879031..904b2b61236 100644 --- a/internal/elasticsearch/client_test.go +++ b/internal/elasticsearch/client_test.go @@ -18,6 +18,7 @@ package elasticsearch import ( + "bytes" "context" "fmt" "net/http" @@ -30,6 +31,7 @@ import ( apmVersion "github.com/elastic/apm-server/internal/version" esv8 "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/esapi" ) func TestClient(t *testing.T) { @@ -86,3 +88,114 @@ func TestClientCustomUserAgent(t *testing.T) { t.Fatal("timed out while waiting for request") } } + +type esMock struct { + responder http.HandlerFunc + ClusterUUID string + UUID string +} + +func (h *esMock) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Elastic-Product", "Elasticsearch") + + switch { + case r.Method == http.MethodGet && r.URL.Path == "/": + root := fmt.Sprintf("{\"name\" : \"mock\", \"cluster_uuid\" : \"%s\", \"version\" : { \"number\" : \"%s\", \"build_flavor\" : \"default\"}}", h.ClusterUUID, "runningtest") + w.Header().Set(http.CanonicalHeaderKey("Content-Type"), "application/json") + w.Write([]byte(root)) + return + case r.Method == http.MethodPost && r.URL.Path == "/_bulk": + h.responder(w, r) + return + case r.Method == http.MethodGet && r.URL.Path == "/_license": + license := fmt.Sprintf("{\"license\" : {\"status\" : \"active\", \"uid\" : \"%s\", \"type\" : \"trial\", \"expiry_date_in_millis\" : %d}}", h.UUID, time.Now().Add(1*time.Hour).Unix()) + w.Header().Set(http.CanonicalHeaderKey("Content-Type"), "application/json") + w.Write([]byte(license)) + return + default: + http.Error(w, "unsupported request", 419) // Signal unexpected error + return + } +} + +func TestClientRetryableStatuses(t *testing.T) { + bc := Config{ + Username: "test", + Password: "foobar", + Backoff: BackoffConfig{ + Init: 0, + Max: 0, + }, + MaxRetries: 2, + } + + tests := []struct { + name string + responseStatusCode int + handler http.HandlerFunc + + expectedStatusCode int + expectedRequestCount int + }{ + { + name: "retry 429 Too Many Requests", + + responseStatusCode: http.StatusTooManyRequests, + expectedStatusCode: http.StatusOK, + expectedRequestCount: 2, + }, + { + name: "retry 502 Bad Gateway", + + responseStatusCode: http.StatusBadGateway, + expectedStatusCode: http.StatusBadGateway, + expectedRequestCount: 1, + }, + { + name: "retry 503 Service Not Available", + + responseStatusCode: http.StatusServiceUnavailable, + expectedStatusCode: http.StatusServiceUnavailable, + expectedRequestCount: 1, + }, + { + name: "retry 504 Gateway Timeout", + responseStatusCode: http.StatusGatewayTimeout, + expectedStatusCode: http.StatusGatewayTimeout, + expectedRequestCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + count := 0 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count < bc.MaxRetries { + count += 1 + http.Error(w, "", tt.responseStatusCode) + return + } + + w.WriteHeader(http.StatusOK) + }) + + es := esMock{responder: handler, ClusterUUID: "8f8d1c95-dde0-4c11-bf27-063e2a819a4c", UUID: "b97b91b3-16e0-49e2-a635-000c97059b46"} + srv := httptest.NewServer(&es) + defer srv.Close() + c := bc + c.Hosts = []string{srv.URL} + + client, err := NewClient(&c) + require.NoError(t, err) + + buf := bytes.Buffer{} + + var res *esapi.Response + res, err = client.Bulk(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + assert.Equal(t, tt.expectedStatusCode, res.StatusCode) + assert.Equal(t, tt.expectedRequestCount, count) + + }) + } +}