diff --git a/metricproviders/newrelic/mock_test.go b/metricproviders/newrelic/mock_test.go index 7e7be08def..a584d09e2a 100644 --- a/metricproviders/newrelic/mock_test.go +++ b/metricproviders/newrelic/mock_test.go @@ -1,6 +1,8 @@ package newrelic import ( + "reflect" + "github.com/argoproj/argo-rollouts/pkg/apis/rollouts/v1alpha1" "github.com/newrelic/newrelic-client-go/v2/pkg/nrdb" ) @@ -16,3 +18,45 @@ func (m *mockAPI) Query(metric v1alpha1.Metric) ([]nrdb.NRDBResult, error) { } return m.response, nil } + +type mockNerdGraphClient struct { + response []nrdb.NRDBResult + lastArgs map[string]interface{} + err error +} + +func (m *mockNerdGraphClient) QueryWithResponse(query string, variables map[string]interface{}, respBody interface{}) error { + m.lastArgs = variables + + if m.err != nil { + return m.err + } + + r := gqlNrglQueryResponse{ + Actor{Account{NRQL: nrdb.NRDBResultContainer{ + Results: m.response, + }}}, + } + rVal := reflect.ValueOf(r) + reflect.ValueOf(respBody).Elem().Set(rVal) + + return nil +} + +func (m *mockNerdGraphClient) Response(response []nrdb.NRDBResult) { + m.response = response +} + +func (m *mockNerdGraphClient) LastArgs() map[string]any { + return m.lastArgs +} + +func (m *mockNerdGraphClient) Err(err error) { + m.err = err +} + +func (m *mockNerdGraphClient) Clear() { + m.err = nil + m.response = nil + m.lastArgs = nil +} diff --git a/metricproviders/newrelic/newrelic.go b/metricproviders/newrelic/newrelic.go index 2ec1ed9073..4222c72983 100644 --- a/metricproviders/newrelic/newrelic.go +++ b/metricproviders/newrelic/newrelic.go @@ -30,12 +30,20 @@ const ( defaultNrqlTimeout = 5 ) +var ( + ErrNegativeTimeout = errors.New("timeout value needs to be a positive value") +) + +type Account struct { + NRQL nrdb.NRDBResultContainer +} + +type Actor struct { + Account +} + type gqlNrglQueryResponse struct { - Actor struct { - Account struct { - NRQL nrdb.NRDBResultContainer - } - } + Actor } const gqlNrqlQuery = `query ( @@ -60,9 +68,13 @@ type NewRelicClientAPI interface { Query(metric v1alpha1.Metric) ([]nrdb.NRDBResult, error) } +type nerdGraphClient interface { + QueryWithResponse(query string, variables map[string]interface{}, respBody interface{}) error +} + type NewRelicClient struct { - *newrelic.NewRelic - AccountID int + NerdGraphClient nerdGraphClient + AccountID int } // Query executes a NRQL query against the given New Relic account @@ -75,7 +87,7 @@ func (n *NewRelicClient) Query(metric v1alpha1.Metric) ([]nrdb.NRDBResult, error } if timeout < 0 { - return nil, fmt.Errorf("timeout value needs to be a positive value") + return nil, ErrNegativeTimeout } args := map[string]any{ @@ -84,7 +96,7 @@ func (n *NewRelicClient) Query(metric v1alpha1.Metric) ([]nrdb.NRDBResult, error "timeout": timeout, } - if err := n.NerdGraph.QueryWithResponse(gqlNrqlQuery, args, &respBody); err != nil { + if err := n.NerdGraphClient.QueryWithResponse(gqlNrqlQuery, args, &respBody); err != nil { return nil, err } @@ -233,7 +245,7 @@ func NewNewRelicAPIClient(metric v1alpha1.Metric, kubeclientset kubernetes.Inter if err != nil { return nil, fmt.Errorf("could not parse account ID: %w", err) } - return &NewRelicClient{NewRelic: nrClient, AccountID: accID}, nil + return &NewRelicClient{NerdGraphClient: &nrClient.NerdGraph, AccountID: accID}, nil } else { return nil, errors.New("account ID or personal API key not found") } diff --git a/metricproviders/newrelic/newrelic_test.go b/metricproviders/newrelic/newrelic_test.go index e3fdf8016a..9b1cf31156 100644 --- a/metricproviders/newrelic/newrelic_test.go +++ b/metricproviders/newrelic/newrelic_test.go @@ -407,3 +407,73 @@ func TestNewNewRelicAPIClient(t *testing.T) { assert.NotNil(t, err) }) } + +func TestNewRelicClient_Query(t *testing.T) { + accountId := 1234567 + sevenTo := int64(7) + negativeTo := int64(-1) + defaultTo := int64(defaultNrqlTimeout) + theQuery := "FROM K8sContainerSample SELECT percentile(`cpuCoresUtilization`, 95)" + + mockNGC := &mockNerdGraphClient{} + nrc := &NewRelicClient{NerdGraphClient: mockNGC, AccountID: accountId} + + tests := map[string]struct { + timeoutProvided *int64 + timeoutUsed *int64 + query string + want []nrdb.NRDBResult + errMsg string + gqlErr error + }{ + `returns results`: { + timeoutUsed: &defaultTo, + query: theQuery, + want: []nrdb.NRDBResult{map[string]any{"count": 10}}, + }, + `uses default timeout when one is not provided`: { + timeoutUsed: &defaultTo, + query: theQuery, + }, + `uses provided timeout`: { + timeoutUsed: &sevenTo, + timeoutProvided: &sevenTo, + query: theQuery, + }, + `errors when timeout is negative`: { + timeoutProvided: &negativeTo, + query: theQuery, + errMsg: ErrNegativeTimeout.Error(), + }, + `errors when nerdgraph returns error`: { + timeoutUsed: &defaultTo, + query: theQuery, + errMsg: "boom", + gqlErr: errors.New("boom"), + }, + } + for testName, tc := range tests { + t.Run(testName, func(t *testing.T) { + defer mockNGC.Clear() + mockNGC.Err(tc.gqlErr) + mockNGC.Response(tc.want) + metric := v1alpha1.Metric{ + Provider: v1alpha1.MetricProvider{ + NewRelic: &v1alpha1.NewRelicMetric{ + Timeout: tc.timeoutProvided, + Query: tc.query, + }, + }, + } + results, err := nrc.Query(metric) + if len(tc.errMsg) > 0 { + assert.EqualError(t, err, tc.errMsg) + return + } + assert.Equal(t, *tc.timeoutUsed, mockNGC.LastArgs()["timeout"]) + assert.Equal(t, tc.query, mockNGC.LastArgs()["query"]) + assert.Equal(t, accountId, mockNGC.LastArgs()["accountId"]) + assert.Equal(t, tc.want, results) + }) + } +}