From eef9598ea110abb7b4f37215fb396b0b32bdb2d8 Mon Sep 17 00:00:00 2001 From: Dan Kortschak <90160302+efd6@users.noreply.github.com> Date: Thu, 31 Aug 2023 06:03:48 +0930 Subject: [PATCH] x-pack/filebeat/input/httpjson: reorganise code to improve code locality (#36439) * chain.go => config_chain.go * simplify encoding registry * simplify transform registry This removes the registration code and makes the use of the registry effectively local rather than a mutated global. It also simplifies the structure of the registry type. * improve code locality of chain helpers * improve code locality of request Use conventional ordering: type, constructor, methods and caller before callee in source order. * improve code locality of response --- .../httpjson/{chain.go => config_chain.go} | 0 .../filebeat/input/httpjson/config_request.go | 2 +- .../input/httpjson/config_response.go | 6 +- x-pack/filebeat/input/httpjson/encoding.go | 109 +-- x-pack/filebeat/input/httpjson/input.go | 37 + .../filebeat/input/httpjson/input_manager.go | 5 - x-pack/filebeat/input/httpjson/input_test.go | 33 - .../filebeat/input/httpjson/metrics_test.go | 3 - x-pack/filebeat/input/httpjson/pagination.go | 10 +- x-pack/filebeat/input/httpjson/request.go | 774 +++++++++++------- .../input/httpjson/request_chain_helper.go | 216 ----- .../httpjson/request_chain_helper_test.go | 164 ---- .../filebeat/input/httpjson/request_test.go | 153 +++- x-pack/filebeat/input/httpjson/response.go | 102 ++- x-pack/filebeat/input/httpjson/split.go | 2 +- x-pack/filebeat/input/httpjson/split_test.go | 2 - x-pack/filebeat/input/httpjson/transform.go | 8 +- .../input/httpjson/transform_registry.go | 73 +- .../filebeat/input/httpjson/transform_test.go | 24 +- 19 files changed, 783 insertions(+), 940 deletions(-) rename x-pack/filebeat/input/httpjson/{chain.go => config_chain.go} (100%) delete mode 100644 x-pack/filebeat/input/httpjson/request_chain_helper.go delete mode 100644 x-pack/filebeat/input/httpjson/request_chain_helper_test.go diff --git a/x-pack/filebeat/input/httpjson/chain.go b/x-pack/filebeat/input/httpjson/config_chain.go similarity index 100% rename from x-pack/filebeat/input/httpjson/chain.go rename to x-pack/filebeat/input/httpjson/config_chain.go diff --git a/x-pack/filebeat/input/httpjson/config_request.go b/x-pack/filebeat/input/httpjson/config_request.go index 86ec0a68a581..e12388986514 100644 --- a/x-pack/filebeat/input/httpjson/config_request.go +++ b/x-pack/filebeat/input/httpjson/config_request.go @@ -151,7 +151,7 @@ func (c *requestConfig) Validate() error { return fmt.Errorf("unsupported method %q", c.Method) } - if _, err := newBasicTransformsFromConfig(c.Transforms, requestNamespace, nil); err != nil { + if _, err := newBasicTransformsFromConfig(registeredTransforms, c.Transforms, requestNamespace, nil); err != nil { return err } diff --git a/x-pack/filebeat/input/httpjson/config_response.go b/x-pack/filebeat/input/httpjson/config_response.go index 2f0da5685f3e..0e9f2f26d891 100644 --- a/x-pack/filebeat/input/httpjson/config_response.go +++ b/x-pack/filebeat/input/httpjson/config_response.go @@ -37,10 +37,10 @@ type splitConfig struct { } func (c *responseConfig) Validate() error { - if _, err := newBasicTransformsFromConfig(c.Transforms, responseNamespace, nil); err != nil { + if _, err := newBasicTransformsFromConfig(registeredTransforms, c.Transforms, responseNamespace, nil); err != nil { return err } - if _, err := newBasicTransformsFromConfig(c.Pagination, paginationNamespace, nil); err != nil { + if _, err := 
newBasicTransformsFromConfig(registeredTransforms, c.Pagination, paginationNamespace, nil); err != nil { return err } if c.DecodeAs != "" { @@ -52,7 +52,7 @@ func (c *responseConfig) Validate() error { } func (c *splitConfig) Validate() error { - if _, err := newBasicTransformsFromConfig(c.Transforms, responseNamespace, nil); err != nil { + if _, err := newBasicTransformsFromConfig(registeredTransforms, c.Transforms, responseNamespace, nil); err != nil { return err } diff --git a/x-pack/filebeat/input/httpjson/encoding.go b/x-pack/filebeat/input/httpjson/encoding.go index 611894b05433..7bf851617898 100644 --- a/x-pack/filebeat/input/httpjson/encoding.go +++ b/x-pack/filebeat/input/httpjson/encoding.go @@ -14,57 +14,8 @@ import ( "net/http" "github.com/elastic/mito/lib/xml" - - "github.com/elastic/elastic-agent-libs/logp" -) - -type encoderFunc func(trReq transformable) ([]byte, error) - -type decoderFunc func(p []byte, dst *response) error - -var ( - registeredEncoders = map[string]encoderFunc{} - registeredDecoders = map[string]decoderFunc{} - defaultEncoder encoderFunc = encodeAsJSON - defaultDecoder decoderFunc = decodeAsJSON ) -func registerEncoder(contentType string, enc encoderFunc) error { - if contentType == "" { - return errors.New("content-type can't be empty") - } - - if enc == nil { - return errors.New("encoder can't be nil") - } - - if _, found := registeredEncoders[contentType]; found { - return errors.New("already registered") - } - - registeredEncoders[contentType] = enc - - return nil -} - -func registerDecoder(contentType string, dec decoderFunc) error { - if contentType == "" { - return errors.New("content-type can't be empty") - } - - if dec == nil { - return errors.New("decoder can't be nil") - } - - if _, found := registeredDecoders[contentType]; found { - return errors.New("already registered") - } - - registeredDecoders[contentType] = dec - - return nil -} - func encode(contentType string, trReq transformable) ([]byte, error) { enc, found := registeredEncoders[contentType] if !found { @@ -81,35 +32,34 @@ func decode(contentType string, p []byte, dst *response) error { return dec(p, dst) } -func registerEncoders() { - log := logp.L().Named(logName) - log.Debugf("registering encoder 'application/json': returned error: %#v", - registerEncoder("application/json", encodeAsJSON)) - - log.Debugf("registering encoder 'application/x-www-form-urlencoded': returned error: %#v", - registerEncoder("application/x-www-form-urlencoded", encodeAsForm)) -} - -func registerDecoders() { - log := logp.L().Named(logName) - log.Debugf("registering decoder 'application/json': returned error: %#v", - registerDecoder("application/json", decodeAsJSON)) - - log.Debugf("registering decoder 'application/x-ndjson': returned error: %#v", - registerDecoder("application/x-ndjson", decodeAsNdjson)) - - log.Debugf("registering decoder 'text/csv': returned error: %#v", - registerDecoder("text/csv", decodeAsCSV)) - - log.Debugf("registering decoder 'application/zip': returned error: %#v", - registerDecoder("application/zip", decodeAsZip)) +var ( + // registeredEncoders is the set of available encoders. + registeredEncoders = map[string]encoderFunc{ + "application/json": encodeAsJSON, + "application/x-www-form-urlencoded": encodeAsForm, + } + // defaultEncoder is the encoder used when no registered + // encoder is available. + defaultEncoder = encodeAsJSON + + // registeredDecoders is the set of available decoders. 
+ registeredDecoders = map[string]decoderFunc{ + "application/json": decodeAsJSON, + "application/x-ndjson": decodeAsNdjson, + "text/csv": decodeAsCSV, + "application/zip": decodeAsZip, + "application/xml": decodeAsXML, + "text/xml; charset=utf-8": decodeAsXML, + } + // defaultDecoder is the decoder used when no registered + // decoder is available. + defaultDecoder = decodeAsJSON +) - log.Debugf("registering decoder 'application/xml': returned error: %#v", - registerDecoder("application/xml", decodeAsXML)) - log.Debugf("registering decoder 'text/xml': returned error: %#v", - registerDecoder("text/xml; charset=utf-8", decodeAsXML)) -} +type encoderFunc func(trReq transformable) ([]byte, error) +type decoderFunc func(p []byte, dst *response) error +// encodeAsJSON encodes trReq as a JSON message. func encodeAsJSON(trReq transformable) ([]byte, error) { if len(trReq.body()) == 0 { return nil, nil } @@ -120,10 +70,12 @@ func encodeAsJSON(trReq transformable) ([]byte, error) { return json.Marshal(trReq.body()) } +// decodeAsJSON decodes the JSON message in p into dst. func decodeAsJSON(p []byte, dst *response) error { return json.Unmarshal(p, &dst.body) } +// encodeAsForm encodes trReq as a URL encoded form. func encodeAsForm(trReq transformable) ([]byte, error) { url := trReq.url() body := []byte(url.RawQuery) @@ -135,6 +87,8 @@ func encodeAsForm(trReq transformable) ([]byte, error) { return body, nil } +// decodeAsNdjson decodes the message in p as a JSON object stream. +// It is more relaxed than NDJSON. func decodeAsNdjson(p []byte, dst *response) error { var results []interface{} dec := json.NewDecoder(bytes.NewReader(p)) @@ -149,6 +103,7 @@ func decodeAsNdjson(p []byte, dst *response) error { return nil } +// decodeAsCSV decodes p as a headed CSV document into dst. func decodeAsCSV(p []byte, dst *response) error { var results []interface{} @@ -189,6 +144,7 @@ func decodeAsCSV(p []byte, dst *response) error { return nil } +// decodeAsZip decodes p as a ZIP archive into dst. func decodeAsZip(p []byte, dst *response) error { var results []interface{} r, err := zip.NewReader(bytes.NewReader(p), int64(len(p))) @@ -225,6 +181,7 @@ func decodeAsZip(p []byte, dst *response) error { return nil } +// decodeAsXML decodes p as an XML document into dst. 
func decodeAsXML(p []byte, dst *response) error { cdata, body, err := xml.Unmarshal(bytes.NewReader(p), dst.xmlDetails) if err != nil { diff --git a/x-pack/filebeat/input/httpjson/input.go b/x-pack/filebeat/input/httpjson/input.go index 749e1c73bcfc..eea7bd39e281 100644 --- a/x-pack/filebeat/input/httpjson/input.go +++ b/x-pack/filebeat/input/httpjson/input.go @@ -265,6 +265,43 @@ func newNetHTTPClient(ctx context.Context, cfg *requestConfig, log *logp.Logger, return netHTTPClient, nil } +func newChainHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *requestConfig, log *logp.Logger, reg *monitoring.Registry, p ...*Policy) (*httpClient, error) { + // Make retryable HTTP client + netHTTPClient, err := newNetHTTPClient(ctx, requestCfg, log, reg) + if err != nil { + return nil, err + } + + var retryPolicyFunc retryablehttp.CheckRetry + if len(p) != 0 { + retryPolicyFunc = p[0].CustomRetryPolicy + } else { + retryPolicyFunc = retryablehttp.DefaultRetryPolicy + } + + client := &retryablehttp.Client{ + HTTPClient: netHTTPClient, + Logger: newRetryLogger(log), + RetryWaitMin: requestCfg.Retry.getWaitMin(), + RetryWaitMax: requestCfg.Retry.getWaitMax(), + RetryMax: requestCfg.Retry.getMaxAttempts(), + CheckRetry: retryPolicyFunc, + Backoff: retryablehttp.DefaultBackoff, + } + + limiter := newRateLimiterFromConfig(requestCfg.RateLimit, log) + + if authCfg != nil && authCfg.OAuth2.isEnabled() { + authClient, err := authCfg.OAuth2.client(ctx, client.StandardClient()) + if err != nil { + return nil, err + } + return &httpClient{client: authClient, limiter: limiter}, nil + } + + return &httpClient{client: client.StandardClient(), limiter: limiter}, nil +} + // clientOption returns constructed client configuration options, including // setting up http+unix and http+npipe transports if requested. func clientOptions(u *url.URL, keepalive httpcommon.WithKeepaliveSettings) []httpcommon.TransportOption { diff --git a/x-pack/filebeat/input/httpjson/input_manager.go b/x-pack/filebeat/input/httpjson/input_manager.go index 93e76a6eb57a..7eb2d628aaf7 100644 --- a/x-pack/filebeat/input/httpjson/input_manager.go +++ b/x-pack/filebeat/input/httpjson/input_manager.go @@ -41,11 +41,6 @@ func NewInputManager(log *logp.Logger, store inputcursor.StateStore) InputManage // Init initializes both wrapped input managers. 
func (m InputManager) Init(grp unison.Group, mode v2.Mode) error { - registerRequestTransforms() - registerResponseTransforms() - registerPaginationTransforms() - registerEncoders() - registerDecoders() return multierr.Append( m.stateless.Init(grp, mode), m.cursor.Init(grp, mode), diff --git a/x-pack/filebeat/input/httpjson/input_test.go b/x-pack/filebeat/input/httpjson/input_test.go index bf12335ac3df..e88a0a28d30d 100644 --- a/x-pack/filebeat/input/httpjson/input_test.go +++ b/x-pack/filebeat/input/httpjson/input_test.go @@ -259,8 +259,6 @@ var testCases = []struct { { name: "date_cursor", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerRequestTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) // mock timeNow func to return a fixed value timeNow = func() time.Time { t, _ := time.Parse(time.RFC3339, "2002-10-02T15:00:00Z") @@ -300,8 +298,6 @@ var testCases = []struct { { name: "tracer_filename_sanitization", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerRequestTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) // mock timeNow func to return a fixed value timeNow = func() time.Time { t, _ := time.Parse(time.RFC3339, "2002-10-02T15:00:00Z") @@ -343,9 +339,6 @@ var testCases = []struct { { name: "pagination", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerPaginationTransforms() - registerResponseTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL t.Cleanup(server.Close) @@ -385,9 +378,6 @@ var testCases = []struct { name: "first_event", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerPaginationTransforms() - registerResponseTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL t.Cleanup(server.Close) @@ -428,8 +418,6 @@ var testCases = []struct { { name: "pagination_with_array_response", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerPaginationTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL t.Cleanup(server.Close) @@ -473,8 +461,6 @@ var testCases = []struct { { name: "request_transforms_can_access_state_from_previous_transforms", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerRequestTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL + "/test-path" t.Cleanup(server.Close) @@ -509,9 +495,6 @@ var testCases = []struct { { name: "response_transforms_can't_access_request_state_from_previous_transforms", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerRequestTransforms() - registerResponseTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL t.Cleanup(server.Close) @@ -600,8 +583,6 @@ var testCases = []struct { { name: "date_cursor_while_using_chain", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerRequestTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) // mock timeNow func to return a fixed value timeNow = func() 
time.Time { t, _ := time.Parse(time.RFC3339, "2002-10-02T15:00:00Z") @@ -954,8 +935,6 @@ var testCases = []struct { name: "global_transform_context_separation_with_parent_last_response_object", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { var serverURL string - registerPaginationTransforms() - registerRequestTransforms() r := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -973,7 +952,6 @@ var testCases = []struct { } }) server := httptest.NewServer(r) - t.Cleanup(func() { registeredTransforms = newRegistry() }) config["request.url"] = server.URL serverURL = server.URL config["chain.0.step.request.url"] = server.URL + "/$.exportId/$.files[:].id" @@ -1021,8 +999,6 @@ var testCases = []struct { name: "cursor_value_is_updated_for_root_response_with_chaining_&_pagination", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { var serverURL string - registerPaginationTransforms() - registerRequestTransforms() r := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -1041,7 +1017,6 @@ var testCases = []struct { } }) server := httptest.NewServer(r) - t.Cleanup(func() { registeredTransforms = newRegistry() }) config["request.url"] = server.URL serverURL = server.URL config["chain.0.step.request.url"] = server.URL + "/$.exportId/$.files[:].id" @@ -1100,8 +1075,6 @@ var testCases = []struct { name: "cursor_value_is_updated_for_root_response_with_chaining_&_pagination_along_with_split_operator", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { var serverURL string - registerPaginationTransforms() - registerRequestTransforms() r := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/": @@ -1120,7 +1093,6 @@ var testCases = []struct { } }) server := httptest.NewServer(r) - t.Cleanup(func() { registeredTransforms = newRegistry() }) config["request.url"] = server.URL serverURL = server.URL config["chain.0.step.request.url"] = server.URL + "/$.exportId/$.files[:].id" @@ -1183,8 +1155,6 @@ var testCases = []struct { { name: "Test simple XML decode", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerDecoders() - registerRequestTransforms() r := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { const text = ` @@ -1210,7 +1180,6 @@ var testCases = []struct { w.Write([]byte(text)) }) server := httptest.NewServer(r) - t.Cleanup(func() { registeredTransforms = newRegistry() }) config["request.url"] = server.URL t.Cleanup(server.Close) }, @@ -1454,7 +1423,6 @@ func newChainPaginationTestServer( newServer func(http.Handler) *httptest.Server, ) func(testing.TB, http.HandlerFunc, map[string]interface{}) { return func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { - registerPaginationTransforms() var serverURL string r := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -1474,7 +1442,6 @@ func newChainPaginationTestServer( config["request.url"] = server.URL serverURL = server.URL config["chain.0.step.request.url"] = server.URL + "/$.records[:].id" - t.Cleanup(func() { registeredTransforms = newRegistry() }) } } diff --git a/x-pack/filebeat/input/httpjson/metrics_test.go b/x-pack/filebeat/input/httpjson/metrics_test.go index 1b57bc5edce2..653523ec5a2c 100644 --- a/x-pack/filebeat/input/httpjson/metrics_test.go +++ b/x-pack/filebeat/input/httpjson/metrics_test.go @@ 
-32,9 +32,6 @@ func TestMetrics(t *testing.T) { { name: "Test pagination metrics", setupServer: func(t *testing.T, h http.HandlerFunc, config map[string]interface{}) { - registerPaginationTransforms() - registerResponseTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) server := httptest.NewServer(h) config["request.url"] = server.URL t.Cleanup(server.Close) diff --git a/x-pack/filebeat/input/httpjson/pagination.go b/x-pack/filebeat/input/httpjson/pagination.go index ddaf5e8c6703..1b8cfc6d240f 100644 --- a/x-pack/filebeat/input/httpjson/pagination.go +++ b/x-pack/filebeat/input/httpjson/pagination.go @@ -19,12 +19,6 @@ import ( const paginationNamespace = "pagination" -func registerPaginationTransforms() { - registerTransform(paginationNamespace, appendName, newAppendPagination) - registerTransform(paginationNamespace, deleteName, newDeletePagination) - registerTransform(paginationNamespace, setName, newSetRequestPagination) -} - type pagination struct { log *logp.Logger httpClient *httpClient @@ -44,8 +38,8 @@ func newPagination(config config, httpClient *httpClient, log *logp.Logger) *pag return pagination } - rts, _ := newBasicTransformsFromConfig(config.Request.Transforms, requestNamespace, log) - pts, _ := newBasicTransformsFromConfig(config.Response.Pagination, paginationNamespace, log) + rts, _ := newBasicTransformsFromConfig(registeredTransforms, config.Request.Transforms, requestNamespace, log) + pts, _ := newBasicTransformsFromConfig(registeredTransforms, config.Response.Pagination, paginationNamespace, log) body := func() *mapstr.M { if config.Response.RequestBodyOnPagination { diff --git a/x-pack/filebeat/input/httpjson/request.go b/x-pack/filebeat/input/httpjson/request.go index 34532e1d5500..f92d2944c70b 100644 --- a/x-pack/filebeat/input/httpjson/request.go +++ b/x-pack/filebeat/input/httpjson/request.go @@ -8,10 +8,13 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "reflect" + "strconv" "strings" "github.com/PaesslerAG/jsonpath" @@ -25,17 +28,210 @@ import ( const requestNamespace = "request" -func registerRequestTransforms() { - registerTransform(requestNamespace, appendName, newAppendRequest) - registerTransform(requestNamespace, deleteName, newDeleteRequest) - registerTransform(requestNamespace, setName, newSetRequestPagination) -} - type httpClient struct { client *http.Client limiter *rateLimiter } +func (r *requester) doRequest(stdCtx context.Context, trCtx *transformContext, publisher inputcursor.Publisher) error { + var ( + n int + ids []string + err error + urlCopy url.URL + urlString string + httpResp *http.Response + initialResponse []*http.Response + intermediateResps []*http.Response + finalResps []*http.Response + isChainWithPageExpected bool + chainIndex int + ) + + //nolint:bodyclose // response body is closed through drainBody method + for i, rf := range r.requestFactories { + finalResps = nil + intermediateResps = nil + // iterate over collected ids from last response + if i == 0 { + // perform and store regular call responses + httpResp, err = rf.collectResponse(stdCtx, trCtx, r) + if err != nil { + return fmt.Errorf("failed to execute rf.collectResponse: %w", err) + } + + if rf.saveFirstResponse { + // store first response in transform context + var bodyMap map[string]interface{} + body, err := io.ReadAll(httpResp.Body) + if err != nil { + return fmt.Errorf("failed to read http response body: %w", err) + } + httpResp.Body = io.NopCloser(bytes.NewReader(body)) + err = json.Unmarshal(body, 
&bodyMap) + if err != nil { + r.log.Errorf("unable to unmarshal first_response.body: %v", err) + } + firstResponse := response{ + url: *httpResp.Request.URL, + header: httpResp.Header.Clone(), + body: bodyMap, + } + trCtx.updateFirstResponse(firstResponse) + } + + if len(r.requestFactories) == 1 { + finalResps = append(finalResps, httpResp) + events := r.responseProcessors[i].startProcessing(stdCtx, trCtx, finalResps, true) + n = processAndPublishEvents(trCtx, events, publisher, true, r.log) + continue + } + + // if flow of control reaches here, that means there are more than 1 request factories + // if a chain step exists, only then we will initialize flags & variables here which are required for chaining + if r.requestFactories[i+1].isChain { + chainIndex = i + 1 + resp, err := cloneResponse(httpResp) + if err != nil { + return err + } + // the response is cloned and added to finalResps here, since the response of the 1st page (whether pagination exists or not), will + // be sent for further processing to check if any response processors can be applied or not and at the same time update the last_response, + // first_event & last_event cursor values. + finalResps = append(finalResps, resp) + + // if a pagination request factory exists at the root level along with a chain step, only then we will initialize flags & variables here + // which are required for chaining with root level pagination + if r.responseProcessors[i].pagination.requestFactory != nil { + isChainWithPageExpected = true + resp, err := cloneResponse(httpResp) + if err != nil { + return err + } + initialResponse = append(initialResponse, resp) + } + } + + intermediateResps = append(intermediateResps, httpResp) + ids, err = r.getIdsFromResponses(intermediateResps, r.requestFactories[i+1].replace) + if err != nil { + return err + } + // we avoid unnecessary pagination here since chaining is present, thus avoiding any unexpected updates to cursor values + events := r.responseProcessors[i].startProcessing(stdCtx, trCtx, finalResps, false) + n = processAndPublishEvents(trCtx, events, publisher, false, r.log) + } else { + if len(ids) == 0 { + n = 0 + continue + } + urlCopy = rf.url + urlString = rf.url.String() + + // new transform context for every chain step, derived from parent transform context + var chainTrCtx *transformContext + if rf.isChain { + chainTrCtx = trCtx.clone() + } + + var val string + var doReplaceWith bool + var replaceArr []string + if rf.replaceWith != "" { + replaceArr = strings.Split(rf.replaceWith, ",") + val, doReplaceWith, err = fetchValueFromContext(chainTrCtx, strings.TrimSpace(replaceArr[1])) + if err != nil { + return err + } + } + + // perform request over collected ids + for _, id := range ids { + // reformat urls of requestFactory using ids + rf.url, err = generateNewUrl(rf.replace, urlString, id) + if err != nil { + return fmt.Errorf("failed to generate new URL: %w", err) + } + + // reformat url accordingly if replaceWith clause exists + if doReplaceWith { + rf.url, err = generateNewUrl(strings.TrimSpace(replaceArr[0]), rf.url.String(), val) + if err != nil { + return fmt.Errorf("failed to generate new URL: %w", err) + } + } + // collect data from new urls + httpResp, err = rf.collectResponse(stdCtx, chainTrCtx, r) + if err != nil { + return fmt.Errorf("failed to execute rf.collectResponse: %w", err) + } + // store data according to response type + if i == len(r.requestFactories)-1 && len(ids) != 0 { + finalResps = append(finalResps, httpResp) + } else { + intermediateResps = 
append(intermediateResps, httpResp) + } + } + rf.url = urlCopy + + var resps []*http.Response + if i == len(r.requestFactories)-1 { + resps = finalResps + } else { + // The if condition (i < len(r.requestFactories)) ensures this branch never runs to the last element + // of r.requestFactories, therefore r.requestFactories[i+1] will never be out of bounds. + ids, err = r.getIdsFromResponses(intermediateResps, r.requestFactories[i+1].replace) + if err != nil { + return err + } + resps = intermediateResps + } + + var events <-chan maybeMsg + if rf.isChain { + events = rf.chainResponseProcessor.startProcessing(stdCtx, chainTrCtx, resps, true) + } else { + events = r.responseProcessors[i].startProcessing(stdCtx, trCtx, resps, true) + } + n += processAndPublishEvents(chainTrCtx, events, publisher, i < len(r.requestFactories), r.log) + } + } + + defer httpResp.Body.Close() + // if pagination exists for the parent request along with chaining, then for each page response the chain is processed + if isChainWithPageExpected { + n += r.processRemainingChainEvents(stdCtx, trCtx, publisher, initialResponse, chainIndex) + } + r.log.Infof("request finished: %d events published", n) + + return nil +} + +// collectResponse returns response from provided request +func (rf *requestFactory) collectResponse(stdCtx context.Context, trCtx *transformContext, r *requester) (*http.Response, error) { + var err error + var httpResp *http.Response + + req, err := rf.newHTTPRequest(stdCtx, trCtx) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %w", err) + } + + if rf.isChain && rf.chainHTTPClient != nil { + httpResp, err = rf.chainHTTPClient.do(stdCtx, req) + if err != nil { + return nil, fmt.Errorf("failed to execute chain http client.Do: %w", err) + } + } else { + httpResp, err = r.client.do(stdCtx, req) + if err != nil { + return nil, fmt.Errorf("failed to execute http client.Do: %w", err) + } + } + + return httpResp, nil +} + func (c *httpClient) do(stdCtx context.Context, req *http.Request) (*http.Response, error) { resp, err := c.limiter.execute(stdCtx, func() (*http.Response, error) { return c.client.Do(req) @@ -59,40 +255,6 @@ func (c *httpClient) do(stdCtx context.Context, req *http.Respon return resp, nil } -func (rf *requestFactory) newRequest(ctx *transformContext) (transformable, error) { - req := transformable{} - req.setURL(rf.url) - - if rf.body != nil && len(*rf.body) > 0 { - req.setBody(rf.body.Clone()) - } - - header := http.Header{} - header.Set("Accept", "application/json") - header.Set("User-Agent", userAgent) - req.setHeader(header) - - var err error - for _, t := range rf.transforms { - req, err = t.run(ctx, req) - if err != nil { - return transformable{}, err - } - } - - if rf.method == http.MethodPost { - header = req.header() - if header.Get("Content-Type") == "" { - header.Set("Content-Type", "application/json") - req.setHeader(header) - } - } - - rf.log.Debugf("new request: %#v", req) - - return req, nil -} - type requestFactory struct { url url.URL method string @@ -114,7 +276,7 @@ type requestFactory struct { func newRequestFactory(ctx context.Context, config config, log *logp.Logger, metrics *inputMetrics, reg *monitoring.Registry) ([]*requestFactory, error) { // config validation already checked for errors here rfs := make([]*requestFactory, 0, len(config.Chain)+1) - ts, _ := newBasicTransformsFromConfig(config.Request.Transforms, requestNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, config.Request.Transforms, 
requestNamespace, log) // regular call requestFactory object rf := &requestFactory{ url: *config.Request.URL.URL, @@ -143,7 +305,7 @@ func newRequestFactory(ctx context.Context, config config, log *logp.Logger, met var rf *requestFactory // chain calls requestFactory object if ch.Step != nil { - ts, _ := newBasicTransformsFromConfig(ch.Step.Request.Transforms, requestNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, ch.Step.Request.Transforms, requestNamespace, log) ch.Step.Auth = tryAssignAuth(config.Auth, ch.Step.Auth) httpClient, err := newChainHTTPClient(ctx, ch.Step.Auth, ch.Step.Request, log, reg) if err != nil { @@ -170,7 +332,7 @@ func newRequestFactory(ctx context.Context, config config, log *logp.Logger, met chainResponseProcessor: responseProcessor, } } else if ch.While != nil { - ts, _ := newBasicTransformsFromConfig(ch.While.Request.Transforms, requestNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, ch.While.Request.Transforms, requestNamespace, log) policy := newHTTPPolicy(evaluateResponse, ch.While.Until, log) ch.While.Auth = tryAssignAuth(config.Auth, ch.While.Auth) httpClient, err := newChainHTTPClient(ctx, ch.While.Auth, ch.While.Request, log, reg, policy) @@ -203,268 +365,128 @@ func newRequestFactory(ctx context.Context, config config, log *logp.Logger, met return rfs, nil } -func (rf *requestFactory) newHTTPRequest(stdCtx context.Context, trCtx *transformContext) (*http.Request, error) { - trReq, err := rf.newRequest(trCtx) +func evaluateResponse(expression *valueTpl, data []byte, log *logp.Logger) (bool, error) { + var dataMap mapstr.M + + err := json.Unmarshal(data, &dataMap) if err != nil { - return nil, err + return false, fmt.Errorf("error while unmarshalling data : %w", err) } - - var body []byte - if rf.method == http.MethodPost { - if rf.encoder != nil { - body, err = rf.encoder(trReq) - } else { - body, err = encode(trReq.header().Get("Content-Type"), trReq) - } - if err != nil { - return nil, err - } + tr := transformable{} + paramCtx := &transformContext{ + firstEvent: &mapstr.M{}, + lastEvent: &mapstr.M{}, + firstResponse: &response{}, + lastResponse: &response{body: dataMap}, } - url := trReq.url() - req, err := http.NewRequest(rf.method, url.String(), bytes.NewBuffer(body)) + val, err := expression.Execute(paramCtx, tr, "", nil, log) if err != nil { - return nil, err + return false, fmt.Errorf("error while evaluating expression : %w", err) } - - req = req.WithContext(stdCtx) - - req.Header = trReq.header().Clone() - - if rf.user != "" || rf.password != "" { - req.SetBasicAuth(rf.user, rf.password) + result, err := strconv.ParseBool(val) + if err != nil { + return false, fmt.Errorf("error while parsing boolean value of string : %w", err) } - return req, nil + return result, nil } -type requester struct { - log *logp.Logger - client *httpClient - requestFactories []*requestFactory - responseProcessors []*responseProcessor -} - -func newRequester( - client *httpClient, - requestFactory []*requestFactory, - responseProcessor []*responseProcessor, - log *logp.Logger, -) *requester { - return &requester{ - log: log, - client: client, - requestFactories: requestFactory, - responseProcessors: responseProcessor, +func tryAssignAuth(parentConfig *authConfig, childConfig *authConfig) *authConfig { + if parentConfig != nil && childConfig == nil { + return parentConfig } + return childConfig } -// collectResponse returns response from provided request -func (rf *requestFactory) collectResponse(stdCtx 
context.Context, trCtx *transformContext, r *requester) (*http.Response, error) { - var err error - var httpResp *http.Response - - req, err := rf.newHTTPRequest(stdCtx, trCtx) - if err != nil { - return nil, fmt.Errorf("failed to create http request: %w", err) - } - - if rf.isChain && rf.chainHTTPClient != nil { - httpResp, err = rf.chainHTTPClient.do(stdCtx, req) - if err != nil { - return nil, fmt.Errorf("failed to execute chain http client.Do: %w", err) - } - } else { - httpResp, err = r.client.do(stdCtx, req) - if err != nil { - return nil, fmt.Errorf("failed to execute http client.Do: %w", err) - } - } - - return httpResp, nil -} - -// generateNewUrl returns new url value using replacement from oldUrl with ids -func generateNewUrl(replacement, oldUrl, id string) (url.URL, error) { - newUrl, err := url.Parse(strings.Replace(oldUrl, replacement, id, 1)) +func (rf *requestFactory) newHTTPRequest(stdCtx context.Context, trCtx *transformContext) (*http.Request, error) { + trReq, err := rf.newRequest(trCtx) if err != nil { - return url.URL{}, fmt.Errorf("failed to replace value in url: %w", err) + return nil, err } - return *newUrl, nil -} - -func (r *requester) doRequest(stdCtx context.Context, trCtx *transformContext, publisher inputcursor.Publisher) error { - var ( - n int - ids []string - err error - urlCopy url.URL - urlString string - httpResp *http.Response - initialResponse []*http.Response - intermediateResps []*http.Response - finalResps []*http.Response - isChainWithPageExpected bool - chainIndex int - ) - - //nolint:bodyclose // response body is closed through drainBody method - for i, rf := range r.requestFactories { - finalResps = nil - intermediateResps = nil - // iterate over collected ids from last response - if i == 0 { - // perform and store regular call responses - httpResp, err = rf.collectResponse(stdCtx, trCtx, r) - if err != nil { - return fmt.Errorf("failed to execute rf.collectResponse: %w", err) - } - - if rf.saveFirstResponse { - // store first response in transform context - var bodyMap map[string]interface{} - body, err := io.ReadAll(httpResp.Body) - if err != nil { - return fmt.Errorf("failed to read http response body: %w", err) - } - httpResp.Body = io.NopCloser(bytes.NewReader(body)) - err = json.Unmarshal(body, &bodyMap) - if err != nil { - r.log.Errorf("unable to unmarshal first_response.body: %v", err) - } - firstResponse := response{ - url: *httpResp.Request.URL, - header: httpResp.Header.Clone(), - body: bodyMap, - } - trCtx.updateFirstResponse(firstResponse) - } - - if len(r.requestFactories) == 1 { - finalResps = append(finalResps, httpResp) - events := r.responseProcessors[i].startProcessing(stdCtx, trCtx, finalResps, true) - n = processAndPublishEvents(trCtx, events, publisher, true, r.log) - continue - } - - // if flow of control reaches here, that means there are more than 1 request factories - // if a chain step exists, only then we will initialize flags & variables here which are required for chaining - if r.requestFactories[i+1].isChain { - chainIndex = i + 1 - resp, err := cloneResponse(httpResp) - if err != nil { - return err - } - // the response is cloned and added to finalResps here, since the response of the 1st page (whether pagination exists or not), will - // be sent for further processing to check if any response processors can be applied or not and at the same time update the last_response, - // first_event & last_event cursor values. 
- finalResps = append(finalResps, resp) - - // if a pagination request factory exists at the root level along with a chain step, only then we will initialize flags & variables here - // which are required for chaining with root level pagination - if r.responseProcessors[i].pagination.requestFactory != nil { - isChainWithPageExpected = true - resp, err := cloneResponse(httpResp) - if err != nil { - return err - } - initialResponse = append(initialResponse, resp) - } - } - - intermediateResps = append(intermediateResps, httpResp) - ids, err = r.getIdsFromResponses(intermediateResps, r.requestFactories[i+1].replace) - if err != nil { - return err - } - // we avoid unnecessary pagination here since chaining is present, thus avoiding any unexpected updates to cursor values - events := r.responseProcessors[i].startProcessing(stdCtx, trCtx, finalResps, false) - n = processAndPublishEvents(trCtx, events, publisher, false, r.log) + + var body []byte + if rf.method == http.MethodPost { + if rf.encoder != nil { + body, err = rf.encoder(trReq) } else { - if len(ids) == 0 { - n = 0 - continue - } - urlCopy = rf.url - urlString = rf.url.String() + body, err = encode(trReq.header().Get("Content-Type"), trReq) + } + if err != nil { + return nil, err + } + } - // new transform context for every chain step, derived from parent transform context - var chainTrCtx *transformContext - if rf.isChain { - chainTrCtx = trCtx.clone() - } + url := trReq.url() + req, err := http.NewRequest(rf.method, url.String(), bytes.NewBuffer(body)) + if err != nil { + return nil, err + } - var val string - var doReplaceWith bool - var replaceArr []string - if rf.replaceWith != "" { - replaceArr = strings.Split(rf.replaceWith, ",") - val, doReplaceWith, err = fetchValueFromContext(chainTrCtx, strings.TrimSpace(replaceArr[1])) - if err != nil { - return err - } - } + req = req.WithContext(stdCtx) - // perform request over collected ids - for _, id := range ids { - // reformat urls of requestFactory using ids - rf.url, err = generateNewUrl(rf.replace, urlString, id) - if err != nil { - return fmt.Errorf("failed to generate new URL: %w", err) - } + req.Header = trReq.header().Clone() - // reformat url accordingly if replaceWith clause exists - if doReplaceWith { - rf.url, err = generateNewUrl(strings.TrimSpace(replaceArr[0]), rf.url.String(), val) - if err != nil { - return fmt.Errorf("failed to generate new URL: %w", err) - } - } - // collect data from new urls - httpResp, err = rf.collectResponse(stdCtx, chainTrCtx, r) - if err != nil { - return fmt.Errorf("failed to execute rf.collectResponse: %w", err) - } - // store data according to response type - if i == len(r.requestFactories)-1 && len(ids) != 0 { - finalResps = append(finalResps, httpResp) - } else { - intermediateResps = append(intermediateResps, httpResp) - } - } - rf.url = urlCopy + if rf.user != "" || rf.password != "" { + req.SetBasicAuth(rf.user, rf.password) + } - var resps []*http.Response - if i == len(r.requestFactories)-1 { - resps = finalResps - } else { - // The if comdition (i < len(r.requestFactories)) ensures this branch never runs to the last element - // of r.requestFactories, therefore r.requestFactories[i+1] will never be out of bounds. 
- ids, err = r.getIdsFromResponses(intermediateResps, r.requestFactories[i+1].replace) - if err != nil { - return err - } - resps = intermediateResps - } + return req, nil +} - var events <-chan maybeMsg - if rf.isChain { - events = rf.chainResponseProcessor.startProcessing(stdCtx, chainTrCtx, resps, true) - } else { - events = r.responseProcessors[i].startProcessing(stdCtx, trCtx, resps, true) - } - n += processAndPublishEvents(chainTrCtx, events, publisher, i < len(r.requestFactories), r.log) +func (rf *requestFactory) newRequest(ctx *transformContext) (transformable, error) { + req := transformable{} + req.setURL(rf.url) + + if rf.body != nil && len(*rf.body) > 0 { + req.setBody(rf.body.Clone()) + } + + header := http.Header{} + header.Set("Accept", "application/json") + header.Set("User-Agent", userAgent) + req.setHeader(header) + + var err error + for _, t := range rf.transforms { + req, err = t.run(ctx, req) + if err != nil { + return transformable{}, err } } - defer httpResp.Body.Close() - // if pagination exists for the parent request along with chaining, then for each page response the chain is processed - if isChainWithPageExpected { - n += r.processRemainingChainEvents(stdCtx, trCtx, publisher, initialResponse, chainIndex) + if rf.method == http.MethodPost { + header = req.header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "application/json") + req.setHeader(header) + } } - r.log.Infof("request finished: %d events published", n) - return nil + rf.log.Debugf("new request: %#v", req) + + return req, nil +} + +type requester struct { + log *logp.Logger + client *httpClient + requestFactories []*requestFactory + responseProcessors []*responseProcessor +} + +func newRequester( + client *httpClient, + requestFactory []*requestFactory, + responseProcessor []*responseProcessor, + log *logp.Logger, +) *requester { + return &requester{ + log: log, + client: client, + requestFactories: requestFactory, + responseProcessors: responseProcessor, + } } // getIdsFromResponses returns ids from responses @@ -516,38 +538,6 @@ func (r *requester) getIdsFromResponses(intermediateResps []*http.Response, repl return ids, nil } -// processAndPublishEvents process and publish events based on event type -func processAndPublishEvents(trCtx *transformContext, events <-chan maybeMsg, publisher inputcursor.Publisher, publish bool, log *logp.Logger) int { - var n int - for maybeMsg := range events { - if maybeMsg.failed() { - log.Errorf("error processing response: %v", maybeMsg) - continue - } - - if publish { - event, err := makeEvent(maybeMsg.msg) - if err != nil { - log.Errorf("error creating event: %v", maybeMsg) - continue - } - - if err := publisher.Publish(event, trCtx.cursorMap()); err != nil { - log.Errorf("error publishing event: %v", err) - continue - } - } - if len(*trCtx.firstEventClone()) == 0 { - trCtx.updateFirstEvent(maybeMsg.msg) - } - trCtx.updateLastEvent(maybeMsg.msg) - trCtx.updateCursor() - - n++ - } - return n -} - // processRemainingChainEvents, processes the remaining pagination events for chain blocks func (r *requester) processRemainingChainEvents(stdCtx context.Context, trCtx *transformContext, publisher inputcursor.Publisher, initialResp []*http.Response, chainIndex int) int { // we start from 0, and skip the 1st event since we have already processed it @@ -698,6 +688,170 @@ func (r *requester) processChainPaginationEvents(stdCtx context.Context, trCtx * return n, nil } +// generateNewUrl returns new url value using replacement from oldUrl with ids +func 
generateNewUrl(replacement, oldUrl, id string) (url.URL, error) { + newUrl, err := url.Parse(strings.Replace(oldUrl, replacement, id, 1)) + if err != nil { + return url.URL{}, fmt.Errorf("failed to replace value in url: %w", err) + } + return *newUrl, nil +} + +// processAndPublishEvents process and publish events based on event type +func processAndPublishEvents(trCtx *transformContext, events <-chan maybeMsg, publisher inputcursor.Publisher, publish bool, log *logp.Logger) int { + var n int + for maybeMsg := range events { + if maybeMsg.failed() { + log.Errorf("error processing response: %v", maybeMsg) + continue + } + + if publish { + event, err := makeEvent(maybeMsg.msg) + if err != nil { + log.Errorf("error creating event: %v", maybeMsg) + continue + } + + if err := publisher.Publish(event, trCtx.cursorMap()); err != nil { + log.Errorf("error publishing event: %v", err) + continue + } + } + if len(*trCtx.firstEventClone()) == 0 { + trCtx.updateFirstEvent(maybeMsg.msg) + } + trCtx.updateLastEvent(maybeMsg.msg) + trCtx.updateCursor() + + n++ + } + return n +} + +const ( + // This is generally updated with chain responses, if present, as they continue to occur + // Otherwise this is always the last response of the root request w.r.t pagination + lastResponse = "last_response" + // This is always the first root response + firstResponse = "first_response" + // This is always the last response of the parent (root) request w.r.t pagination + // This is only set if chaining is used + parentLastResponse = "parent_last_response" +) + +func fetchValueFromContext(trCtx *transformContext, expression string) (string, bool, error) { + var val interface{} + + switch keys := processExpression(expression); keys[0] { + case lastResponse: + respMap, err := responseToMap(trCtx.lastResponse) + if err != nil { + return "", false, err + } + val, err = iterateRecursive(respMap, keys[1:], 0) + if err != nil { + return "", false, err + } + case parentLastResponse: + respMap, err := responseToMap(trCtx.parentTrCtx.lastResponse) + if err != nil { + return "", false, err + } + val, err = iterateRecursive(respMap, keys[1:], 0) + if err != nil { + return "", false, err + } + case firstResponse: + // since first response body is already a map, we do not need to transform it + respMap, err := responseToMap(trCtx.firstResponse) + if err != nil { + return "", false, err + } + val, err = iterateRecursive(respMap, keys[1:], 0) + if err != nil { + return "", false, err + } + // In this scenario we treat the expression as a hardcoded value, with which we will replace the fixed-pattern + case expression: + return expression, true, nil + default: + return "", false, fmt.Errorf("context value not supported for key: %q in expression %q", keys[0], expression) + } + + return fmt.Sprint(val), true, nil +} + +// processExpression, splits the expression string based on the separator and looks for +// supported keywords. If present, returns an expression array containing separated elements. +// If no keywords are present, the expression is treated as a hardcoded value and returned +// as a merged string which is the only array element. 
+func processExpression(expression string) []string { + if !strings.HasPrefix(expression, ".") { + return []string{expression} + } + switch { + case strings.HasPrefix(expression, "."+firstResponse+"."), + strings.HasPrefix(expression, "."+lastResponse+"."), + strings.HasPrefix(expression, "."+parentLastResponse+"."): + return strings.Split(expression, ".")[1:] + default: + return []string{expression} + } +} + +func responseToMap(r *response) (mapstr.M, error) { + if r.body == nil { + return nil, fmt.Errorf("response body is empty for request url: %s", &r.url) + } + respMap := map[string]interface{}{ + "header": make(mapstr.M), + "body": make(mapstr.M), + } + + for key, value := range r.header { + respMap["header"] = mapstr.M{ + key: value, + } + } + respMap["body"] = r.body + + return respMap, nil +} + +func iterateRecursive(m mapstr.M, keys []string, depth int) (interface{}, error) { + val := m[keys[depth]] + + if val == nil { + return nil, fmt.Errorf("value of expression could not be determined for key %s", strings.Join(keys[:depth+1], ".")) + } + + switch v := reflect.ValueOf(val); v.Kind() { + case reflect.Bool: + return v.Bool(), nil + case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: + return v.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64: + return v.Uint(), nil + case reflect.Float32, reflect.Float64: + return v.Float(), nil + case reflect.String: + return v.String(), nil + case reflect.Map: + nextMap, ok := v.Interface().(map[string]interface{}) + if !ok { + return nil, errors.New("unable to parse the value of the given expression") + } + depth++ + if depth >= len(keys) { + return nil, errors.New("value of expression could not be determined") + } + return iterateRecursive(nextMap, keys, depth) + default: + return nil, fmt.Errorf("unable to parse the value of the expression %s: type %T is not handled", strings.Join(keys[:depth+1], "."), val) + } +} + // cloneResponse clones required http response attributes func cloneResponse(source *http.Response) (*http.Response, error) { var resp http.Response diff --git a/x-pack/filebeat/input/httpjson/request_chain_helper.go b/x-pack/filebeat/input/httpjson/request_chain_helper.go deleted file mode 100644 index c75c573a17f7..000000000000 --- a/x-pack/filebeat/input/httpjson/request_chain_helper.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. 
- -package httpjson - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "reflect" - "strconv" - "strings" - - retryablehttp "github.com/hashicorp/go-retryablehttp" - - "github.com/elastic/elastic-agent-libs/logp" - "github.com/elastic/elastic-agent-libs/mapstr" - "github.com/elastic/elastic-agent-libs/monitoring" -) - -const ( - // This is generally updated with chain responses, if present, as they continue to occur - // Otherwise this is always the last response of the root request w.r.t pagination - lastResponse = "last_response" - // This is always the first root response - firstResponse = "first_response" - // This is always the last response of the parent (root) request w.r.t pagination - // This is only set if chaining is used - parentLastResponse = "parent_last_response" -) - -func newChainHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *requestConfig, log *logp.Logger, reg *monitoring.Registry, p ...*Policy) (*httpClient, error) { - // Make retryable HTTP client - netHTTPClient, err := newNetHTTPClient(ctx, requestCfg, log, reg) - if err != nil { - return nil, err - } - - var retryPolicyFunc retryablehttp.CheckRetry - if len(p) != 0 { - retryPolicyFunc = p[0].CustomRetryPolicy - } else { - retryPolicyFunc = retryablehttp.DefaultRetryPolicy - } - - client := &retryablehttp.Client{ - HTTPClient: netHTTPClient, - Logger: newRetryLogger(log), - RetryWaitMin: requestCfg.Retry.getWaitMin(), - RetryWaitMax: requestCfg.Retry.getWaitMax(), - RetryMax: requestCfg.Retry.getMaxAttempts(), - CheckRetry: retryPolicyFunc, - Backoff: retryablehttp.DefaultBackoff, - } - - limiter := newRateLimiterFromConfig(requestCfg.RateLimit, log) - - if authCfg != nil && authCfg.OAuth2.isEnabled() { - authClient, err := authCfg.OAuth2.client(ctx, client.StandardClient()) - if err != nil { - return nil, err - } - return &httpClient{client: authClient, limiter: limiter}, nil - } - - return &httpClient{client: client.StandardClient(), limiter: limiter}, nil -} - -func evaluateResponse(expression *valueTpl, data []byte, log *logp.Logger) (bool, error) { - var dataMap mapstr.M - - err := json.Unmarshal(data, &dataMap) - if err != nil { - return false, fmt.Errorf("error while unmarshalling data : %w", err) - } - tr := transformable{} - paramCtx := &transformContext{ - firstEvent: &mapstr.M{}, - lastEvent: &mapstr.M{}, - firstResponse: &response{}, - lastResponse: &response{body: dataMap}, - } - - val, err := expression.Execute(paramCtx, tr, "", nil, log) - if err != nil { - return false, fmt.Errorf("error while evaluating expression : %w", err) - } - result, err := strconv.ParseBool(val) - if err != nil { - return false, fmt.Errorf("error while parsing boolean value of string : %w", err) - } - - return result, nil -} - -// fetchValueFromContext evaluates a given expression and returns the appropriate value from context variables if present -func fetchValueFromContext(trCtx *transformContext, expression string) (string, bool, error) { - var val interface{} - - switch keys := processExpression(expression); keys[0] { - case lastResponse: - respMap, err := responseToMap(trCtx.lastResponse) - if err != nil { - return "", false, err - } - val, err = iterateRecursive(respMap, keys[1:], 0) - if err != nil { - return "", false, err - } - case parentLastResponse: - respMap, err := responseToMap(trCtx.parentTrCtx.lastResponse) - if err != nil { - return "", false, err - } - val, err = iterateRecursive(respMap, keys[1:], 0) - if err != nil { - return "", false, err - } - case firstResponse: - // since 
first response body is already a map, we do not need to transform it - respMap, err := responseToMap(trCtx.firstResponse) - if err != nil { - return "", false, err - } - val, err = iterateRecursive(respMap, keys[1:], 0) - if err != nil { - return "", false, err - } - // In this scenario we treat the expression as a hardcoded value, with which we will replace the fixed-pattern - case expression: - return expression, true, nil - default: - return "", false, fmt.Errorf("context value not supported for key: %q in expression %q", keys[0], expression) - } - - return fmt.Sprint(val), true, nil -} - -func responseToMap(r *response) (mapstr.M, error) { - if r.body == nil { - return nil, fmt.Errorf("response body is empty for request url: %s", &r.url) - } - respMap := map[string]interface{}{ - "header": make(mapstr.M), - "body": make(mapstr.M), - } - - for key, value := range r.header { - respMap["header"] = mapstr.M{ - key: value, - } - } - respMap["body"] = r.body - - return respMap, nil -} - -func iterateRecursive(m mapstr.M, keys []string, depth int) (interface{}, error) { - val := m[keys[depth]] - - if val == nil { - return nil, fmt.Errorf("value of expression could not be determined for key %s", strings.Join(keys[:depth+1], ".")) - } - - switch v := reflect.ValueOf(val); v.Kind() { - case reflect.Bool: - return v.Bool(), nil - case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: - return v.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64: - return v.Uint(), nil - case reflect.Float32, reflect.Float64: - return v.Float(), nil - case reflect.String: - return v.String(), nil - case reflect.Map: - nextMap, ok := v.Interface().(map[string]interface{}) - if !ok { - return nil, errors.New("unable to parse the value of the given expression") - } - depth++ - if depth >= len(keys) { - return nil, errors.New("value of expression could not be determined") - } - return iterateRecursive(nextMap, keys, depth) - default: - return nil, fmt.Errorf("unable to parse the value of the expression %s: type %T is not handled", strings.Join(keys[:depth+1], "."), val) - } -} - -// processExpression, splits the expression string based on the separator and looks for -// supported keywords. If present, returns an expression array containing separated elements. -// If no keywords are present, the expression is treated as a hardcoded value and returned -// as a merged string which is the only array element. -func processExpression(expression string) []string { - if !strings.HasPrefix(expression, ".") { - return []string{expression} - } - switch { - case strings.HasPrefix(expression, "."+firstResponse+"."), - strings.HasPrefix(expression, "."+lastResponse+"."), - strings.HasPrefix(expression, "."+parentLastResponse+"."): - return strings.Split(expression, ".")[1:] - default: - return []string{expression} - } -} - -func tryAssignAuth(parentConfig *authConfig, childConfig *authConfig) *authConfig { - if parentConfig != nil && childConfig == nil { - return parentConfig - } - return childConfig -} diff --git a/x-pack/filebeat/input/httpjson/request_chain_helper_test.go b/x-pack/filebeat/input/httpjson/request_chain_helper_test.go deleted file mode 100644 index 0c2bf3840956..000000000000 --- a/x-pack/filebeat/input/httpjson/request_chain_helper_test.go +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. 
Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. - -package httpjson - -import ( - "bytes" - "context" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/elastic/elastic-agent-libs/logp" -) - -func Test_newChainHTTPClient(t *testing.T) { - cfg := defaultChainConfig() - cfg.Request.URL = &urlConfig{URL: &url.URL{}} - ctx := context.Background() - log := logp.NewLogger("newChainClientTestLogger") - - type args struct { - ctx context.Context - authCfg *authConfig - requestCfg *requestConfig - log *logp.Logger - p []*Policy - } - tests := []struct { - name string - args args - }{ - { - name: "newChainClientTest", - args: args{ - ctx: ctx, - authCfg: cfg.Auth, - requestCfg: cfg.Request, - log: log, - p: nil, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newChainHTTPClient(tt.args.ctx, tt.args.authCfg, tt.args.requestCfg, tt.args.log, nil, tt.args.p...) - assert.NoError(t, err) - assert.NotNil(t, got) - }) - } -} - -func Test_evaluateResponse(t *testing.T) { - log := logp.NewLogger("newEvaluateResponseTestLogger") - responseTrue := bytes.NewBufferString(`{"status": "completed"}`).Bytes() - responseFalse := bytes.NewBufferString(`{"status": "initiated"}`).Bytes() - - type args struct { - expression string - data []byte - log *logp.Logger - } - tests := []struct { - name string - args args - expectedError string - want bool - }{ - { - name: "newEvaluateResponse_resultIsTrue", - args: args{ - expression: `[[ eq .last_response.body.status "completed" ]]`, - data: responseTrue, - log: log, - }, - want: true, - expectedError: "", - }, - { - name: "newEvaluateResponse_resultIsFalse", - args: args{ - expression: `[[ eq .last_response.body.status "completed" ]]`, - data: responseFalse, - log: log, - }, - want: false, - expectedError: "", - }, - { - name: "newEvaluateResponse_invalidExpressionError", - args: args{ - expression: `eq .last_response.body.status "completed" ]]`, - data: responseFalse, - log: log, - }, - want: false, - expectedError: "error while parsing boolean value of string : strconv.ParseBool: parsing \"eq .last_response.body.status \\\"completed\\\" ]]\": invalid syntax", - }, - { - name: "newEvaluateResponse_emptyExpressionError", - args: args{ - expression: "", - data: responseFalse, - log: log, - }, - want: false, - expectedError: "error while evaluating expression : the template result is empty", - }, - { - name: "newEvaluateResponse_incompleteExpressionError", - args: args{ - expression: `[[.last_response.body.status]]`, - data: responseFalse, - log: log, - }, - want: false, - expectedError: "error while parsing boolean value of string : strconv.ParseBool: parsing \"initiated\": invalid syntax", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expression := &valueTpl{} - err := expression.Unpack(tt.args.expression) - assert.NoError(t, err) - - got, err := evaluateResponse(expression, tt.args.data, tt.args.log) - if err != nil { - assert.EqualError(t, err, tt.expectedError) - } else { - assert.Equal(t, tt.want, got) - } - }) - } -} - -func TestProcessExpression(t *testing.T) { - tests := []struct { - in string - want []string - }{ - // Cursor values. 
- {in: ".first_response.foo", want: []string{"first_response", "foo"}}, - {in: ".first_response.", want: []string{"first_response", ""}}, - {in: ".last_response.foo", want: []string{"last_response", "foo"}}, - {in: ".last_response.", want: []string{"last_response", ""}}, - {in: ".parent_last_response.foo", want: []string{"parent_last_response", "foo"}}, - {in: ".parent_last_response.", want: []string{"parent_last_response", ""}}, - - // Literal values. - {in: ".literal_foo", want: []string{".literal_foo"}}, - {in: ".literal_foo.bar", want: []string{".literal_foo.bar"}}, - {in: "literal.foo.bar", want: []string{"literal.foo.bar"}}, - {in: "first_response.foo", want: []string{"first_response.foo"}}, - {in: ".first_response", want: []string{".first_response"}}, - {in: ".last_response", want: []string{".last_response"}}, - {in: ".parent_last_response", want: []string{".parent_last_response"}}, - } - for _, test := range tests { - got := processExpression(test.in) - assert.Equal(t, test.want, got) - } -} diff --git a/x-pack/filebeat/input/httpjson/request_test.go b/x-pack/filebeat/input/httpjson/request_test.go index 981089ff3ac5..315a3d4864bb 100644 --- a/x-pack/filebeat/input/httpjson/request_test.go +++ b/x-pack/filebeat/input/httpjson/request_test.go @@ -5,9 +5,11 @@ package httpjson import ( + "bytes" "context" "fmt" "net/http/httptest" + "net/url" "testing" "time" @@ -20,9 +22,6 @@ import ( ) func TestCtxAfterDoRequest(t *testing.T) { - registerRequestTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) - // mock timeNow func to return a fixed value timeNow = func() time.Time { t, _ := time.Parse(time.RFC3339, "2002-10-02T15:00:00Z") @@ -135,3 +134,151 @@ func TestCtxAfterDoRequest(t *testing.T) { lastResp, ) } + +func Test_newChainHTTPClient(t *testing.T) { + cfg := defaultChainConfig() + cfg.Request.URL = &urlConfig{URL: &url.URL{}} + ctx := context.Background() + log := logp.NewLogger("newChainClientTestLogger") + + type args struct { + ctx context.Context + authCfg *authConfig + requestCfg *requestConfig + log *logp.Logger + p []*Policy + } + tests := []struct { + name string + args args + }{ + { + name: "newChainClientTest", + args: args{ + ctx: ctx, + authCfg: cfg.Auth, + requestCfg: cfg.Request, + log: log, + p: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newChainHTTPClient(tt.args.ctx, tt.args.authCfg, tt.args.requestCfg, tt.args.log, nil, tt.args.p...) 
+			assert.NoError(t, err)
+			assert.NotNil(t, got)
+		})
+	}
+}
+
+func Test_evaluateResponse(t *testing.T) {
+	log := logp.NewLogger("newEvaluateResponseTestLogger")
+	responseTrue := bytes.NewBufferString(`{"status": "completed"}`).Bytes()
+	responseFalse := bytes.NewBufferString(`{"status": "initiated"}`).Bytes()
+
+	type args struct {
+		expression string
+		data       []byte
+		log        *logp.Logger
+	}
+	tests := []struct {
+		name          string
+		args          args
+		expectedError string
+		want          bool
+	}{
+		{
+			name: "newEvaluateResponse_resultIsTrue",
+			args: args{
+				expression: `[[ eq .last_response.body.status "completed" ]]`,
+				data:       responseTrue,
+				log:        log,
+			},
+			want:          true,
+			expectedError: "",
+		},
+		{
+			name: "newEvaluateResponse_resultIsFalse",
+			args: args{
+				expression: `[[ eq .last_response.body.status "completed" ]]`,
+				data:       responseFalse,
+				log:        log,
+			},
+			want:          false,
+			expectedError: "",
+		},
+		{
+			name: "newEvaluateResponse_invalidExpressionError",
+			args: args{
+				expression: `eq .last_response.body.status "completed" ]]`,
+				data:       responseFalse,
+				log:        log,
+			},
+			want:          false,
+			expectedError: "error while parsing boolean value of string : strconv.ParseBool: parsing \"eq .last_response.body.status \\\"completed\\\" ]]\": invalid syntax",
+		},
+		{
+			name: "newEvaluateResponse_emptyExpressionError",
+			args: args{
+				expression: "",
+				data:       responseFalse,
+				log:        log,
+			},
+			want:          false,
+			expectedError: "error while evaluating expression : the template result is empty",
+		},
+		{
+			name: "newEvaluateResponse_incompleteExpressionError",
+			args: args{
+				expression: `[[.last_response.body.status]]`,
+				data:       responseFalse,
+				log:        log,
+			},
+			want:          false,
+			expectedError: "error while parsing boolean value of string : strconv.ParseBool: parsing \"initiated\": invalid syntax",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			expression := &valueTpl{}
+			err := expression.Unpack(tt.args.expression)
+			assert.NoError(t, err)
+
+			got, err := evaluateResponse(expression, tt.args.data, tt.args.log)
+			if err != nil {
+				assert.EqualError(t, err, tt.expectedError)
+			} else {
+				assert.Equal(t, tt.want, got)
+			}
+		})
+	}
+}
+
+func TestProcessExpression(t *testing.T) {
+	tests := []struct {
+		in   string
+		want []string
+	}{
+		// Cursor values.
+		{in: ".first_response.foo", want: []string{"first_response", "foo"}},
+		{in: ".first_response.", want: []string{"first_response", ""}},
+		{in: ".last_response.foo", want: []string{"last_response", "foo"}},
+		{in: ".last_response.", want: []string{"last_response", ""}},
+		{in: ".parent_last_response.foo", want: []string{"parent_last_response", "foo"}},
+		{in: ".parent_last_response.", want: []string{"parent_last_response", ""}},
+
+		// Literal values.
+ {in: ".literal_foo", want: []string{".literal_foo"}}, + {in: ".literal_foo.bar", want: []string{".literal_foo.bar"}}, + {in: "literal.foo.bar", want: []string{"literal.foo.bar"}}, + {in: "first_response.foo", want: []string{"first_response.foo"}}, + {in: ".first_response", want: []string{".first_response"}}, + {in: ".last_response", want: []string{".last_response"}}, + {in: ".parent_last_response", want: []string{".parent_last_response"}}, + } + for _, test := range tests { + got := processExpression(test.in) + assert.Equal(t, test.want, got) + } +} diff --git a/x-pack/filebeat/input/httpjson/response.go b/x-pack/filebeat/input/httpjson/response.go index 2b23d1fdb03c..7adfd956fa20 100644 --- a/x-pack/filebeat/input/httpjson/response.go +++ b/x-pack/filebeat/input/httpjson/response.go @@ -18,12 +18,6 @@ import ( const responseNamespace = "response" -func registerResponseTransforms() { - registerTransform(responseNamespace, appendName, newAppendResponse) - registerTransform(responseNamespace, deleteName, newDeleteResponse) - registerTransform(responseNamespace, setName, newSetResponse) -} - type response struct { page int64 url url.URL @@ -53,6 +47,51 @@ func (resp *response) clone() *response { return clone } +func (resp *response) asTransformables(log *logp.Logger) []transformable { + var ts []transformable + + convertAndAppend := func(m map[string]interface{}) { + tr := transformable{} + tr.setHeader(resp.header.Clone()) + tr.setURL(resp.url) + tr.setBody(mapstr.M(m).Clone()) + ts = append(ts, tr) + } + + switch tresp := resp.body.(type) { + case []interface{}: + for _, v := range tresp { + m, ok := v.(map[string]interface{}) + if !ok { + log.Debugf("events must be JSON objects, but got %T: skipping", v) + continue + } + convertAndAppend(m) + } + case map[string]interface{}: + convertAndAppend(tresp) + default: + log.Debugf("response is not a valid JSON") + } + + return ts +} + +func (resp *response) templateValues() mapstr.M { + if resp == nil { + return mapstr.M{} + } + return mapstr.M{ + "header": resp.header.Clone(), + "page": resp.page, + "url": mapstr.M{ + "value": resp.url.String(), + "params": resp.url.Query(), + }, + "body": resp.body, + } +} + type responseProcessor struct { metrics *inputMetrics log *logp.Logger @@ -75,7 +114,7 @@ func newResponseProcessor(config config, pagination *pagination, xmlDetails map[ rps = append(rps, rp) return rps } - ts, _ := newBasicTransformsFromConfig(config.Response.Transforms, responseNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, config.Response.Transforms, responseNamespace, log) rp.transforms = ts split, _ := newSplitResponse(config.Response.Split, log) @@ -119,7 +158,7 @@ func newChainResponseProcessor(config chainConfig, httpClient *httpClient, xmlDe return rp } - ts, _ := newBasicTransformsFromConfig(config.Step.Response.Transforms, responseNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, config.Step.Response.Transforms, responseNamespace, log) rp.transforms = ts split, _ := newSplitResponse(config.Step.Response.Split, log) @@ -130,7 +169,7 @@ func newChainResponseProcessor(config chainConfig, httpClient *httpClient, xmlDe return rp } - ts, _ := newBasicTransformsFromConfig(config.While.Response.Transforms, responseNamespace, log) + ts, _ := newBasicTransformsFromConfig(registeredTransforms, config.While.Response.Transforms, responseNamespace, log) rp.transforms = ts split, _ := newSplitResponse(config.While.Response.Split, log) @@ -221,48 +260,3 @@ func (rp 
*responseProcessor) startProcessing(stdCtx context.Context, trCtx *tran return ch } - -func (resp *response) asTransformables(log *logp.Logger) []transformable { - var ts []transformable - - convertAndAppend := func(m map[string]interface{}) { - tr := transformable{} - tr.setHeader(resp.header.Clone()) - tr.setURL(resp.url) - tr.setBody(mapstr.M(m).Clone()) - ts = append(ts, tr) - } - - switch tresp := resp.body.(type) { - case []interface{}: - for _, v := range tresp { - m, ok := v.(map[string]interface{}) - if !ok { - log.Debugf("events must be JSON objects, but got %T: skipping", v) - continue - } - convertAndAppend(m) - } - case map[string]interface{}: - convertAndAppend(tresp) - default: - log.Debugf("response is not a valid JSON") - } - - return ts -} - -func (resp *response) templateValues() mapstr.M { - if resp == nil { - return mapstr.M{} - } - return mapstr.M{ - "header": resp.header.Clone(), - "page": resp.page, - "url": mapstr.M{ - "value": resp.url.String(), - "params": resp.url.Query(), - }, - "body": resp.body, - } -} diff --git a/x-pack/filebeat/input/httpjson/split.go b/x-pack/filebeat/input/httpjson/split.go index fd01f0775fc6..8fa892a6ce05 100644 --- a/x-pack/filebeat/input/httpjson/split.go +++ b/x-pack/filebeat/input/httpjson/split.go @@ -66,7 +66,7 @@ func newSplit(c *splitConfig, log *logp.Logger) (*split, error) { return nil, fmt.Errorf("invalid target type: %s", ti.Type) } - ts, err := newBasicTransformsFromConfig(c.Transforms, responseNamespace, log) + ts, err := newBasicTransformsFromConfig(registeredTransforms, c.Transforms, responseNamespace, log) if err != nil { return nil, err } diff --git a/x-pack/filebeat/input/httpjson/split_test.go b/x-pack/filebeat/input/httpjson/split_test.go index c8c800769701..2c4553e9df64 100644 --- a/x-pack/filebeat/input/httpjson/split_test.go +++ b/x-pack/filebeat/input/httpjson/split_test.go @@ -15,8 +15,6 @@ import ( ) func TestSplit(t *testing.T) { - registerResponseTransforms() - t.Cleanup(func() { registeredTransforms = newRegistry() }) cases := []struct { name string config *splitConfig diff --git a/x-pack/filebeat/input/httpjson/transform.go b/x-pack/filebeat/input/httpjson/transform.go index 53db9b9d45d2..d4055889bf04 100644 --- a/x-pack/filebeat/input/httpjson/transform.go +++ b/x-pack/filebeat/input/httpjson/transform.go @@ -228,7 +228,7 @@ func (e maybeMsg) failed() bool { return e.err != nil } func (e maybeMsg) Error() string { return e.err.Error() } // newTransformsFromConfig creates a list of transforms from a list of free user configurations. -func newTransformsFromConfig(config transformsConfig, namespace string, log *logp.Logger) (transforms, error) { +func newTransformsFromConfig(registeredTransforms registry, config transformsConfig, namespace string, log *logp.Logger) (transforms, error) { var trans transforms for _, tfConfig := range config { if len(tfConfig.GetFields()) != 1 { @@ -246,7 +246,7 @@ func newTransformsFromConfig(config transformsConfig, namespace string, log *log constructor, found := registeredTransforms.get(namespace, actionName) if !found { - return nil, fmt.Errorf("the transform %s does not exist. Valid transforms: %s", actionName, registeredTransforms.String()) + return nil, fmt.Errorf("the transform %s does not exist. 
Valid transforms: %s", actionName, registeredTransforms) } common.PrintConfigDebugf(cfg, "Configure transform '%v' with:", actionName) @@ -261,8 +261,8 @@ func newTransformsFromConfig(config transformsConfig, namespace string, log *log return trans, nil } -func newBasicTransformsFromConfig(config transformsConfig, namespace string, log *logp.Logger) ([]basicTransform, error) { - ts, err := newTransformsFromConfig(config, namespace, log) +func newBasicTransformsFromConfig(registeredTransforms registry, config transformsConfig, namespace string, log *logp.Logger) ([]basicTransform, error) { + ts, err := newTransformsFromConfig(registeredTransforms, config, namespace, log) if err != nil { return nil, err } diff --git a/x-pack/filebeat/input/httpjson/transform_registry.go b/x-pack/filebeat/input/httpjson/transform_registry.go index 26a739494db1..c9936c987b02 100644 --- a/x-pack/filebeat/input/httpjson/transform_registry.go +++ b/x-pack/filebeat/input/httpjson/transform_registry.go @@ -5,7 +5,6 @@ package httpjson import ( - "errors" "fmt" "strings" @@ -13,45 +12,43 @@ import ( "github.com/elastic/elastic-agent-libs/logp" ) -type constructor func(config *conf.C, log *logp.Logger) (transform, error) - -var registeredTransforms = newRegistry() +// registry is a collection of namespaced transform constructors. +// The registry is keyed on the namespace major and then on the +// transforms name. +type registry map[string]map[string]constructor -type registry struct { - namespaces map[string]map[string]constructor -} +type constructor func(config *conf.C, log *logp.Logger) (transform, error) -func newRegistry() *registry { - return ®istry{namespaces: make(map[string]map[string]constructor)} +var registeredTransforms = registry{ + requestNamespace: { + appendName: newAppendRequest, + deleteName: newDeleteRequest, + setName: newSetRequestPagination, + }, + responseNamespace: { + appendName: newAppendResponse, + deleteName: newDeleteResponse, + setName: newSetResponse, + }, + paginationNamespace: { + appendName: newAppendPagination, + deleteName: newDeletePagination, + setName: newSetRequestPagination, + }, } -func (reg *registry) register(namespace, transform string, cons constructor) error { - if cons == nil { - return errors.New("constructor can't be nil") - } - - m, found := reg.namespaces[namespace] - if !found { - reg.namespaces[namespace] = make(map[string]constructor) - m = reg.namespaces[namespace] - } - - if _, found := m[transform]; found { - return errors.New("already registered") - } - - m[transform] = cons - - return nil +func (reg registry) get(namespace, transform string) (_ constructor, ok bool) { + c, ok := reg[namespace][transform] + return c, ok } func (reg registry) String() string { - if len(reg.namespaces) == 0 { + if len(reg) == 0 { return "(empty registry)" } var str string - for namespace, m := range reg.namespaces { + for namespace, m := range reg { names := make([]string, 0, len(m)) for k := range m { names = append(names, k) @@ -61,21 +58,3 @@ func (reg registry) String() string { return str } - -func (reg registry) get(namespace, transform string) (constructor, bool) { - m, found := reg.namespaces[namespace] - if !found { - return nil, false - } - c, found := m[transform] - return c, found -} - -func registerTransform(namespace, transform string, constructor constructor) { - logp.L().Named(logName).Debugf("Register transform %s:%s", namespace, transform) - - err := registeredTransforms.register(namespace, transform, constructor) - if err != nil { - panic(err) - } -} diff 
--git a/x-pack/filebeat/input/httpjson/transform_test.go b/x-pack/filebeat/input/httpjson/transform_test.go index 57a337f65814..9b1535898b53 100644 --- a/x-pack/filebeat/input/httpjson/transform_test.go +++ b/x-pack/filebeat/input/httpjson/transform_test.go @@ -47,8 +47,11 @@ func TestTransformableClone(t *testing.T) { } func TestNewTransformsFromConfig(t *testing.T) { - registerTransform("test", setName, newSetRequestPagination) - t.Cleanup(func() { registeredTransforms = newRegistry() }) + registeredTransforms := registry{ + "test": { + setName: newSetRequestPagination, + }, + } cases := []struct { name string @@ -104,7 +107,7 @@ func TestNewTransformsFromConfig(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { cfg := conf.MustNewConfigFrom(tc.paramCfg) - gotTransforms, gotErr := newTransformsFromConfig(transformsConfig{cfg}, tc.paramNamespace, nil) + gotTransforms, gotErr := newTransformsFromConfig(registeredTransforms, transformsConfig{cfg}, tc.paramNamespace, nil) if tc.expectedErr == "" { assert.NoError(t, gotErr) tr := gotTransforms[0].(*set) @@ -123,14 +126,15 @@ type fakeTransform struct{} func (fakeTransform) transformName() string { return "fake" } func TestNewBasicTransformsFromConfig(t *testing.T) { - fakeConstr := func(*conf.C, *logp.Logger) (transform, error) { - return fakeTransform{}, nil + registeredTransforms := registry{ + "test": { + setName: newSetRequestPagination, + "fake": func(*conf.C, *logp.Logger) (transform, error) { + return fakeTransform{}, nil + }, + }, } - registerTransform("test", setName, newSetRequestPagination) - registerTransform("test", "fake", fakeConstr) - t.Cleanup(func() { registeredTransforms = newRegistry() }) - cases := []struct { name string paramCfg map[string]interface{} @@ -160,7 +164,7 @@ func TestNewBasicTransformsFromConfig(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { cfg := conf.MustNewConfigFrom(tc.paramCfg) - _, gotErr := newBasicTransformsFromConfig(transformsConfig{cfg}, tc.paramNamespace, nil) + _, gotErr := newBasicTransformsFromConfig(registeredTransforms, transformsConfig{cfg}, tc.paramNamespace, nil) if tc.expectedErr == "" { assert.NoError(t, gotErr) } else {