diff --git a/clients/graphql/batch.go b/clients/graphql/batch.go new file mode 100644 index 00000000..5ae8e9b5 --- /dev/null +++ b/clients/graphql/batch.go @@ -0,0 +1,185 @@ +package graphql + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Khan/genqlient/graphql" + "github.com/forta-network/forta-core-go/protocol" + log "github.com/sirupsen/logrus" +) + +var ( + ErrResponseSizeTooBig = fmt.Errorf("response size too big") +) + +// paginateBatch processes the response received from the server and extracts the relevant data. +// It takes an array of AlertsInput objects and a graphql.Response object as input. +// It then iterates over the inputs and extracts the response item based on the alias. +// If there is an error for the input, it logs a warning and continues to the next input. +// If there are more alerts for the input, it adds the input to the list of pagination inputs. +// The function returns the pagination inputs, the extracted alert events, and any encountered error. +func paginateBatch(inputs []*AlertsInput, response *graphql.Response) ([]*AlertsInput, + []*protocol.AlertEvent, error) { + // type-checking response + if response == nil { + return nil, nil, fmt.Errorf("nil graphql response") + } + + batchAlertsResponseUnsafe, ok := response.Data.(*BatchGetAlertsResponse) + if !ok { + return nil, nil, fmt.Errorf("invalid pagination response") + } + if batchAlertsResponseUnsafe == nil { + return nil, nil, fmt.Errorf("nil pagination response") + } + + batchAlertsResponse := *batchAlertsResponseUnsafe + + var pagination []*AlertsInput + var alerts []*protocol.AlertEvent + for inputIdx := range inputs { + alias := idxToResponseAlias(inputIdx) + + logger := log.WithFields(log.Fields{ + "input": idxToInputAlias(inputIdx), + }) + + responseItem, ok := batchAlertsResponse[alias] + if !ok { + logger.Warn("no response for input") + continue + } + + // check if there is an error for the input + err := HasError(response.Errors, inputIdx) + if err != nil { + logger.WithError(err).Warn("error response for input") + continue + } + + alerts = append(alerts, responseItem.ToAlertEvents()...) + + // check if there are more alerts for the input + if !responseItem.PageInfo.HasNextPage { + continue + } + + // check if there are more alerts for the input + if responseItem.PageInfo.EndCursor == nil { + continue + } + + // add input to the list of pagination inputs + inputs[inputIdx].After = &AlertEndCursorInput{ + AlertId: responseItem.PageInfo.EndCursor.AlertId, + BlockNumber: responseItem.PageInfo.EndCursor.BlockNumber, + } + + pagination = append(pagination, inputs[inputIdx]) + } + + return pagination, alerts, nil +} + +// fetchAlertsBatch retrieves the alerts in batches from the server using the provided client, inputs, and headers. +func fetchAlertsBatch(ctx context.Context, client string, inputs []*AlertsInput, headers map[string]string) (*graphql.Response, error) { + query, variables := createGetAlertsQuery(inputs) + req := &graphql.Request{ + OpName: "getAlerts", + Query: query, + Variables: variables, + } + + respBody, err := makeRequest(ctx, client, req, headers) + if err != nil { + return nil, err + } + + resp := parseBatchResponse(respBody) + if err != nil { + return nil, err + } + + return resp, nil +} + +// makeRequest sends a GraphQL request to the specified client and returns the response body. +// It takes the context, client URL, request, and headers as input parameters. +// It marshals the request into JSON and creates an HTTP request. +// It sets the custom headers and executes the query with the default HTTP client. +// If the response status code is not OK, it returns an error with the status and response body. +// If the response is gzip compressed, it decompresses the body before parsing. +// It reads the response body and returns it along with any encountered error. +func makeRequest(ctx context.Context, client string, req *graphql.Request, headers map[string]string) ([]byte, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequest( + http.MethodPost, + client, + bytes.NewReader(body), + ) + if err != nil { + return nil, err + } + + httpReq = httpReq.WithContext(ctx) + + // set custom headers + for key, val := range headers { + httpReq.Header.Set(key, val) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept-Encoding", "gzip") + + queryTime := time.Now().Truncate(time.Minute).UnixMilli() + httpReq.Header.Set("Forta-Query-Timestamp", fmt.Sprintf("%d", queryTime)) + + // execute query + httpResp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + + if httpResp.StatusCode == http.StatusInternalServerError { + return nil, ErrResponseSizeTooBig + } + + if httpResp.StatusCode != http.StatusOK { + var respBody []byte + respBody, err = io.ReadAll(httpResp.Body) + if err != nil { + respBody = []byte(fmt.Sprintf("", err)) + } + return nil, fmt.Errorf("returned error %v: %s", httpResp.Status, respBody) + } + + // Check if the response is compressed with gzip + var respBodyReader = httpResp.Body + if strings.Contains(httpResp.Header.Get("Content-Encoding"), "gzip") { + respBodyReader, err = gzip.NewReader(httpResp.Body) + if err != nil { + return nil, err + } + defer respBodyReader.Close() + } + + // Parse response + respBody, err := io.ReadAll(respBodyReader) + if err != nil { + return nil, err + } + return respBody, err +} diff --git a/clients/graphql/batch_test.go b/clients/graphql/batch_test.go new file mode 100644 index 00000000..36c5d241 --- /dev/null +++ b/clients/graphql/batch_test.go @@ -0,0 +1,77 @@ +package graphql + +import ( + "fmt" + "testing" + + "github.com/Khan/genqlient/graphql" + "github.com/forta-network/forta-core-go/protocol" + "github.com/stretchr/testify/assert" + "github.com/vektah/gqlparser/v2/ast" + "github.com/vektah/gqlparser/v2/gqlerror" +) + +func TestPaginateBatch(t *testing.T) { + var alerts0 ast.PathName = "alerts0" + tests := []struct { + name string + inputs []*AlertsInput + response *graphql.Response + expectedPagination []*AlertsInput + expectedAlerts []*protocol.AlertEvent + expectedErr error + }{ + { + name: "Invalid Pagination Response", + response: &graphql.Response{}, + expectedErr: fmt.Errorf("invalid pagination response"), + }, + { + name: "Test with two inputs and one error in response", + inputs: []*AlertsInput{{}, {}}, + response: &graphql.Response{ + Data: &BatchGetAlertsResponse{ + "alerts0": { + PageInfo: nil, + }, + "alerts1": { + PageInfo: &PageInfo{ + HasNextPage: true, + EndCursor: &EndCursor{AlertId: "0xaaa"}, + }, + }, + }, + Errors: gqlerror.List{{ + Path: ast.Path{ + alerts0, + }, + Message: "test error", + }}, + }, + expectedPagination: []*AlertsInput{ + { + After: &AlertEndCursorInput{ + AlertId: "0xaaa", + }, + }, + }, + expectedAlerts: nil, + expectedErr: nil, + }, + // Add more test cases here for other scenarios. + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pagination, alerts, err := paginateBatch(tt.inputs, tt.response) + assert.Equal(t, tt.expectedPagination, pagination) + assert.Equal(t, tt.expectedAlerts, alerts) + if tt.expectedErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/clients/graphql/client.go b/clients/graphql/client.go index bb5a1175..babf4b9d 100644 --- a/clients/graphql/client.go +++ b/clients/graphql/client.go @@ -27,6 +27,7 @@ type client struct { type Client interface { GetAlerts(ctx context.Context, input *AlertsInput, headers map[string]string) ([]*protocol.AlertEvent, error) + GetAlertsBatch(ctx context.Context, input []*AlertsInput, headers map[string]string) ([]*protocol.AlertEvent, error) } func NewClient(url string) Client { @@ -60,7 +61,7 @@ func (ac *client) GetAlerts( return nil, fmt.Errorf("failed to fetch alerts: %v", err) } - alerts = append(alerts, response.ToAlertEvents()...) + alerts = append(alerts, response.Alerts.ToAlertEvents()...) // check if there are more alerts if !response.Alerts.PageInfo.HasNextPage { @@ -175,3 +176,67 @@ func parseResponse(responseBody []byte) (*graphql.Response, *GetAlertsResponse, return resp, &data, err } + +// GetAlertsBatch is a method that retrieves alerts in batches using pagination. +// It takes a context, a slice of AlertsInput, and a map of headers as parameters. +// It returns a slice of AlertEvent and an error. +// +// The method pre-processes the inputs by assigning default values to some fields if they are not provided. +// +// It then iterates until there are no more pagination inputs to query. In each iteration, it calls the fetchAlertsBatch function to fetch alerts based on the inputs and headers. +// If an error occurs during fetching, the method returns nil and the error. +// +// After fetching alerts, the method calls the paginateBatch function to paginate the inputs and the response. It assigns the new inputs and the alert page to variables. +// If an error occurs during pagination, the method returns nil and the error. +// +// Finally, the method appends the alert page to the alerts slice and repeats the iteration until there are no more pagination inputs to query. +// It then returns the alerts slice and nil. +func (ac *client) GetAlertsBatch(ctx context.Context, inputs []*AlertsInput, headers map[string]string) ([]*protocol.AlertEvent, error) { + // pre-process inputs + for _, input := range inputs { + if input.BlockSortDirection == "" { + input.BlockSortDirection = SortDesc + } + + // have a default of 10m + if input.CreatedSince == 0 { + input.CreatedSince = uint(DefaultLastNMinutes.Milliseconds()) + } + + if input.First == 0 { + input.First = DefaultPageSize + } + } + + var alerts []*protocol.AlertEvent + + // iterate until there are no more pagination inputs to query + for len(inputs) > 0 { + response, err := fetchAlertsBatch(ctx, ac.url, inputs, headers) + if err != nil { + return nil, err + } + + var alertPage []*protocol.AlertEvent + inputs, alertPage, err = paginateBatch(inputs, response) + if err != nil { + return nil, err + } + + alerts = append(alerts, alertPage...) + } + + return alerts, nil +} + +func parseBatchResponse(responseBody []byte) *graphql.Response { + var data BatchGetAlertsResponse + resp := &graphql.Response{Data: &data} + + err := json.Unmarshal(responseBody, resp) + if err != nil { + return nil + } + + return resp +} diff --git a/clients/graphql/client_test.go b/clients/graphql/client_test.go index 25633f03..524811fa 100644 --- a/clients/graphql/client_test.go +++ b/clients/graphql/client_test.go @@ -1,32 +1,141 @@ package graphql import ( + "context" "fmt" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/forta-network/forta-core-go/protocol" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshal(t *testing.T) { - resp, data, err := parseResponse([]byte(testResponse)) - assert.NoError(t, err) + resp := parseBatchResponse([]byte(testResponse)) + data := (*resp.Data.(*BatchGetAlertsResponse))["alerts0"] + assert.NotNilf(t, resp, "graphql response can not be nil") - assert.NotNilf(t, data, "data can not be nil") for i := 0; i < 5; i++ { - assert.Equal(t, fmt.Sprintf("0x%d", i), data.Alerts.Alerts[i].Source.SourceAlert.Hash) - assert.Equal(t, "0xbbb", data.Alerts.Alerts[i].Source.SourceAlert.BotId) - assert.Equal(t, "2023-01-01T00:00:00Z", data.Alerts.Alerts[i].Source.SourceAlert.Timestamp) - assert.Equal(t, uint64(137), data.Alerts.Alerts[i].Source.SourceAlert.ChainId) - assert.Equal(t, "Block height: 17890044", data.Alerts.Alerts[i].Description) - assert.Equal(t, uint64(i), data.Alerts.Alerts[i].Source.Block.Number) + assert.Equal(t, fmt.Sprintf("0x%d", i), data.Alerts[i].Source.SourceAlert.Hash) + assert.Equal(t, "0xbbb", data.Alerts[i].Source.SourceAlert.BotId) + assert.Equal(t, "2023-01-01T00:00:00Z", data.Alerts[i].Source.SourceAlert.Timestamp) + assert.Equal(t, uint64(137), data.Alerts[i].Source.SourceAlert.ChainId) + assert.Equal(t, "Block height: 17890044", data.Alerts[i].Description) + assert.Equal(t, uint64(i), data.Alerts[i].Source.Block.Number) + } +} + +func TestGetAlertsBatch(t *testing.T) { + batchResp := parseBatchResponse([]byte(testResponse)) + responseItem := (*batchResp.Data.(*BatchGetAlertsResponse))["alerts0"] + expectedAlerts := responseItem.ToAlertEvents() + tests := []struct { + desc string + inputs []*AlertsInput + headers map[string]string + setupMock func(mux *http.ServeMux) + wantAlerts []*protocol.AlertEvent + wantErr bool + }{ + { + desc: "Successful Request", + inputs: []*AlertsInput{ + { + BlockSortDirection: "ASC", + CreatedSince: 30, + First: 3, + }, + }, + headers: map[string]string{ + "Authorization": "Bearer: token", + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + + fmt.Fprint(w, testResponse) + }) + }, + wantAlerts: expectedAlerts, + wantErr: false, + }, + { + desc: "Failure due to server error", + headers: map[string]string{ + "Authorization": "Bearer: token", + }, + inputs: []*AlertsInput{ + { + Bots: []string{"0xabc"}, + }, + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "server error", http.StatusInternalServerError) + }) + }, + wantAlerts: nil, + wantErr: true, + }, + { + desc: "Failure due to unauthorized", + inputs: []*AlertsInput{ + { + Bots: []string{"0xabc"}, + }, + }, + headers: map[string]string{ + "Authorization": "", // No token + }, + setupMock: func(mux *http.ServeMux) { + // Here's a simple example of what your setup function might do: + mux.HandleFunc("/graphql", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + }) + }, + wantAlerts: nil, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + mux := http.NewServeMux() + ts := httptest.NewUnstartedServer(mux) + + tc.setupMock(mux) // Modify setupMock to accept *http.ServeMux + ts.Start() + defer ts.Close() + + // Prepare client + client := NewClient(fmt.Sprintf("%s/graphql", ts.URL)) + + // Get context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Invoke GetAlertsBatch + gotAlerts, gotErr := client.GetAlertsBatch(ctx, tc.inputs, tc.headers) + + if tc.wantErr { + require.Error(t, gotErr) + return + } + require.NoError(t, gotErr) + require.Equal(t, tc.wantAlerts, gotAlerts) + }) } } const testResponse = `{ "data": { - "alerts": { + "alerts0": { "pageInfo": { - "hasNextPage": true, + "hasNextPage": false, "endCursor": { "alertId": "0x0baefe6f0be064d7f3637af75a90964e7c231cb6c35266f51af2ce3539558b93", "blockNumber": 17890041 diff --git a/clients/graphql/models.go b/clients/graphql/models.go index 7eb61652..8d8bfd47 100644 --- a/clients/graphql/models.go +++ b/clients/graphql/models.go @@ -1,17 +1,24 @@ package graphql import ( + "fmt" + "strings" "time" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/forta-network/forta-core-go/protocol" + "github.com/vektah/gqlparser/v2/gqlerror" ) type GetAlertsResponse struct { - Alerts struct { - PageInfo *PageInfo `json:"pageInfo"` - Alerts []*protocol.AlertEvent_Alert `json:"alerts"` - } + Alerts GetAlertResponseItem +} + +type BatchGetAlertsResponse map[string]GetAlertResponseItem + +type GetAlertResponseItem struct { + PageInfo *PageInfo `json:"pageInfo"` + Alerts []*protocol.AlertEvent_Alert `json:"alerts"` } // AlertsInput Alert list input @@ -120,9 +127,9 @@ type __getAlertsInput struct { Input *AlertsInput `json:"input,omitempty"` } -func (g *GetAlertsResponse) ToAlertEvents() []*protocol.AlertEvent { - resp := make([]*protocol.AlertEvent, len(g.Alerts.Alerts)) - for i, alert := range g.Alerts.Alerts { +func (g *GetAlertResponseItem) ToAlertEvents() []*protocol.AlertEvent { + resp := make([]*protocol.AlertEvent, len(g.Alerts)) + for i, alert := range g.Alerts { // alert will be the source alert in consumer's perspective, // so it should be exposed as the SourceAlert tracking timestamp t, _ := time.Parse(time.RFC3339, alert.CreatedAt) @@ -139,6 +146,59 @@ func (g *GetAlertsResponse) ToAlertEvents() []*protocol.AlertEvent { return resp } +// createGetAlertsQuery creates aliased graphql queries, using alerts${index} and input${index} as aliases. +func createGetAlertsQuery(inputs []*AlertsInput) (string, map[string]interface{}) { + variables := make(map[string]interface{}) + var queryBuilder strings.Builder + + // Define the operation with necessary variables + queryBuilder.WriteString("query getAlerts(") + for i := range inputs { + input := fmt.Sprintf("$%s: AlertsInput", idxToInputAlias(i)) + queryBuilder.WriteString(input) + if i < len(inputs)-1 { + queryBuilder.WriteString(",") + } + } + queryBuilder.WriteString(") {") + + for i, input := range inputs { + // Ensure that the alias and query structure match the schema + alias := fmt.Sprintf("%s: alerts(input: $%s) {", idxToResponseAlias(i), idxToInputAlias(i)) + queryBuilder.WriteString(alias) + + // Include subfields for alerts based on the schema (example subfields) + queryBuilder.WriteString(getAlertsFields) + queryBuilder.WriteString("}") + + // Add the input to the variables map + variables[idxToInputAlias(i)] = input + } + + // End of the query + queryBuilder.WriteString("}") + + return queryBuilder.String(), variables +} + +func idxToInputAlias(idx int) string { + return fmt.Sprintf("input%d", idx) +} + +func idxToResponseAlias(idx int) string { + return fmt.Sprintf("alerts%d", idx) +} + +func HasError(errors gqlerror.List, idx int) error { + for _, e := range errors { + if e.Path.String() == idxToResponseAlias(idx) { + return fmt.Errorf(e.Error()) + } + } + + return nil +} + // The query or mutation executed by getAlerts. const getAlertsOperation = ` query getAlerts ($input: AlertsInput) { @@ -223,3 +283,84 @@ query getAlerts ($input: AlertsInput) { } } ` + +// getAlertsFields +const getAlertsFields = ` +pageInfo { + hasNextPage + endCursor { + alertId + blockNumber + } +} +alerts { + alertId + addresses + contracts { + name + projectId + } + createdAt + description + hash + metadata + name + projects { + id + } + protocol + scanNodeCount + severity + source { + transactionHash + bot { + chainIds + createdAt + description + developer + docReference + enabled + id + image + name + reference + repository + projects + scanNodes + version + } + block { + number + hash + timestamp + chainId + } + sourceAlert { + hash + botId + timestamp + chainId + } + } + alertDocumentType + findingType + relatedAlerts + chainId + labels { + label + confidence + entity + entityType + remove + metadata + uniqueKey + embedding + } + addressBloomFilter { + bitset + itemCount + k + m + } +} +` diff --git a/clients/graphql/models_test.go b/clients/graphql/models_test.go new file mode 100644 index 00000000..de6eb13a --- /dev/null +++ b/clients/graphql/models_test.go @@ -0,0 +1,98 @@ +package graphql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_createGetAlertsQuery(t *testing.T) { + testInput := AlertsInput{ + AlertId: "0xabc", + } + resp, variables := createGetAlertsQuery([]*AlertsInput{&testInput}) + assert.Equal(t, resp, mockExpectedQuery) + variable, ok := variables["input0"].(*AlertsInput) + assert.True(t, ok) + assert.Equal(t, variable.AlertId, testInput.AlertId) +} + +const mockExpectedQuery = `query getAlerts($input0: AlertsInput) {alerts0: alerts(input: $input0) { +pageInfo { + hasNextPage + endCursor { + alertId + blockNumber + } +} +alerts { + alertId + addresses + contracts { + name + projectId + } + createdAt + description + hash + metadata + name + projects { + id + } + protocol + scanNodeCount + severity + source { + transactionHash + bot { + chainIds + createdAt + description + developer + docReference + enabled + id + image + name + reference + repository + projects + scanNodes + version + } + block { + number + hash + timestamp + chainId + } + sourceAlert { + hash + botId + timestamp + chainId + } + } + alertDocumentType + findingType + relatedAlerts + chainId + labels { + label + confidence + entity + entityType + remove + metadata + uniqueKey + embedding + } + addressBloomFilter { + bitset + itemCount + k + m + } +} +}}` diff --git a/clients/mocks/mock_graphql_client.go b/clients/mocks/mock_graphql_client.go index 12284315..fab45171 100644 --- a/clients/mocks/mock_graphql_client.go +++ b/clients/mocks/mock_graphql_client.go @@ -50,3 +50,18 @@ func (mr *MockClientMockRecorder) GetAlerts(ctx, input, headers interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlerts", reflect.TypeOf((*MockClient)(nil).GetAlerts), ctx, input, headers) } + +// GetAlertsBatch mocks base method. +func (m *MockClient) GetAlertsBatch(ctx context.Context, input []*graphql.AlertsInput, headers map[string]string) ([]*protocol.AlertEvent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAlertsBatch", ctx, input, headers) + ret0, _ := ret[0].([]*protocol.AlertEvent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAlertsBatch indicates an expected call of GetAlertsBatch. +func (mr *MockClientMockRecorder) GetAlertsBatch(ctx, input, headers interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlertsBatch", reflect.TypeOf((*MockClient)(nil).GetAlertsBatch), ctx, input, headers) +} diff --git a/feeds/combiner.go b/feeds/combiner.go index f8e9bdd3..e24b3baf 100644 --- a/feeds/combiner.go +++ b/feeds/combiner.go @@ -2,6 +2,7 @@ package feeds import ( "context" + "errors" "fmt" "sync" "time" @@ -22,6 +23,10 @@ var ( ErrBadRequest = fmt.Errorf("bad public api request") ) +const ( + DefaultBatchSize = 10 +) + type cfHandler struct { Handler func(evt *domain.AlertEvent) error ErrCh chan<- error @@ -46,6 +51,7 @@ type combinerFeed struct { handlersMu sync.Mutex cfg CombinerFeedConfig maxAlertAge time.Duration + batchSize int } func (cf *combinerFeed) Subscriptions() []*domain.CombinerBotSubscription { @@ -173,19 +179,7 @@ func (cf *combinerFeed) forEachAlert(alertHandlers []cfHandler) error { upperBound := int64(0) // Query all subscriptions and process alerts - for _, subscription := range cf.Subscriptions() { - logger = logger.WithFields( - log.Fields{ - "subscriberBotId": subscription.Subscriber.BotID, - "subscribedBotId": subscription.Subscription.BotId, - }, - ) - - err := cf.fetchAlertsAndHandle(cf.ctx, alertHandlers, subscription, lowerBound.Milliseconds(), upperBound) - if err != nil { - logger.WithError(err).Warn("failed to fetch alerts and handle") - } - } + cf.handleSubscriptions(alertHandlers, cf.Subscriptions(), lowerBound, upperBound, logger) // Save alert cache to persistent file, if configured if cf.cfg.CombinerCachePath != "" { @@ -196,58 +190,75 @@ func (cf *combinerFeed) forEachAlert(alertHandlers []cfHandler) error { } } -// fetchAlertsAndHandle retrieves alerts from the public API using the given subscription details, filters them by creation date, -// and processes each alert by calling the alert handlers passed in as an argument. -// Returns an error if there was an issue fetching or processing alerts. -// This method is thread-safe, as it acquires a lock on the client and combinerCache mutexes before accessing or modifying them. -func (cf *combinerFeed) fetchAlertsAndHandle( - ctx context.Context, alertHandlers []cfHandler, subscription *domain.CombinerBotSubscription, createdSince int64, - createdBefore int64, -) error { - logger := log.WithFields( - log.Fields{ - "subscriberBotId": subscription.Subscriber.BotID, - "subscriberBotOwner": subscription.Subscriber.BotOwner, - "subscriberBotImage": subscription.Subscriber.BotImage, - "subscribedTo": subscription.Subscription.BotId, - }, - ) - - alerts, err := cf.fetchAlerts(ctx, logger, subscription, createdSince, createdBefore) - if err != nil { - return err +func (cf *combinerFeed) handleSubscriptions(alertHandlers []cfHandler, subscriptions []*domain.CombinerBotSubscription, lowerBound time.Duration, upperBound int64, logger *log.Entry) { + // create a lookup map to batch per subscriber (ie. bot) + subscriberBatchMap := make(map[domain.Subscriber][]*protocol.CombinerBotSubscription) + for _, subscription := range subscriptions { + subscriber := *subscription.Subscriber + subscriberBatchMap[subscriber] = append(subscriberBatchMap[subscriber], subscription.Subscription) } - cf.processAlerts(ctx, logger, alerts, subscription, alertHandlers) - - return nil + // handle subscriptions in batches + for subscriber, botSubscriptions := range subscriberBatchMap { + logger = logger.WithFields( + log.Fields{ + "subscriberBotId": subscriber.BotID, + "subscriberBotOwner": subscriber.BotOwner, + "subscriberBotImage": subscriber.BotImage, + }) + + var alertBatchResponse []*protocol.AlertEvent + // iterate over batches and handle + for i := 0; i < len(botSubscriptions); { + currentBatchSize := cf.batchSize + for { + // Determine the end of the current batch + end := i + currentBatchSize + if end > len(botSubscriptions) { + end = len(botSubscriptions) + } + + // Create a batch + batch := botSubscriptions[i:end] + + alerts, err := cf.fetchAlertsBatch(cf.ctx, logger, &subscriber, batch, lowerBound.Milliseconds(), upperBound) + if err != nil { + if errors.Is(err, graphql.ErrResponseSizeTooBig) && currentBatchSize > 1 { + // Reduce batch size and retry + currentBatchSize /= 2 + logger.WithError(err).Warnf("Batch too big, reducing size to %d and retrying", currentBatchSize) + continue + } else { + // Other error or batch size already at minimum + logger.WithError(err).Warn("failed to fetch alerts") + break + } + } + + alertBatchResponse = append(alertBatchResponse, alerts...) + i += currentBatchSize + break + } + } + cf.processAlerts(cf.ctx, logger, &subscriber, alertBatchResponse, alertHandlers) + } } -// fetchAlerts retrieves alerts from the GraphQL API for the given subscription and time range. The method constructs a -// graphql.AlertsInput object based on the subscription data and passes it to the GraphQL client's GetAlerts method. It uses a -// retryWithBackoff method to retry the GetAlerts call in case of errors. The method returns a slice of alerts on success and an -// error on failure. -func (cf *combinerFeed) fetchAlerts(ctx context.Context, logger *log.Entry, subscription *domain.CombinerBotSubscription, createdSince int64, createdBefore int64) ([]*protocol.AlertEvent, error) { +func (cf *combinerFeed) fetchAlertsBatch(ctx context.Context, logger *log.Entry, subscriber *domain.Subscriber, + subscriptions []*protocol.CombinerBotSubscription, createdSince int64, createdBefore int64) ([]*protocol.AlertEvent, + error) { var alerts []*protocol.AlertEvent // construct auth headers for the subscriber - authHeaders := subscriberInfoToHeaders(subscription.Subscriber) - - // construct the graphql.AlertsInput object based on the subscription data - alertsInput := &graphql.AlertsInput{ - Bots: []string{subscription.Subscription.BotId}, - CreatedSince: uint(createdSince), - CreatedBefore: uint(createdBefore), - AlertIds: subscription.Subscription.AlertIds, - AlertId: subscription.Subscription.AlertId, - ChainId: uint(subscription.Subscription.ChainId), - } + authHeaders := subscriberInfoToHeaders(subscriber) + + inputs := subscriptionsToAlertInputs(subscriptions, createdSince, createdBefore) // call the GraphQL client's GetAlerts method with retries err := cf.retryWithBackoff( ctx, func() error { var cErr error - alerts, cErr = cf.client.GetAlerts(cf.ctx, alertsInput, authHeaders) + alerts, cErr = cf.client.GetAlertsBatch(cf.ctx, inputs, authHeaders) if cErr != nil { logger.WithError(cErr).Warn("error retrieving alerts") @@ -264,27 +275,44 @@ func (cf *combinerFeed) fetchAlerts(ctx context.Context, logger *log.Entry, subs return alerts, nil } +func subscriptionsToAlertInputs(subscriptions []*protocol.CombinerBotSubscription, createdSince int64, createdBefore int64) []*graphql.AlertsInput { + inputs := make([]*graphql.AlertsInput, len(subscriptions)) + for i, subscription := range subscriptions { + // construct the graphql.AlertsInput object based on the subscription data + inputs[i] = &graphql.AlertsInput{ + Bots: []string{subscription.BotId}, + CreatedSince: uint(createdSince), + CreatedBefore: uint(createdBefore), + AlertIds: subscription.AlertIds, + AlertId: subscription.AlertId, + ChainId: uint(subscription.ChainId), + } + } + return inputs +} + // processAlerts processes a slice of alerts by filtering out those that are too old or have already been processed and then passing // the remaining alerts to the alert handlers passed in as an argument. // It uses a cache to prevent duplicate processing of alerts and creates an AlertEvent object to pass to each alert handler. // It is thread-safe as it acquires a lock on the combinerCache mutex before accessing or modifying it. -func (cf *combinerFeed) processAlerts(_ context.Context, logger *log.Entry, alerts []*protocol.AlertEvent, subscription *domain.CombinerBotSubscription, alertHandlers []cfHandler) { +func (cf *combinerFeed) processAlerts(_ context.Context, logger *log.Entry, subscriber *domain.Subscriber, + alerts []*protocol.AlertEvent, + alertHandlers []cfHandler) { for _, alert := range alerts { - // check if the alert is too old to process if tooOld, age := alertIsTooOld(alert, cf.maxAlertAge); tooOld { - logger.WithField("age", age).Warnf( + logger.WithFields(log.Fields{"age": age, "alert": alert.Alert.Hash}).Warnf( "alert is older than %v - setting current alert iterator head to now", cf.maxAlertAge, ) continue } // check if the alert has already been processed - if cf.combinerCache.Exists(subscription, alert) { + if cf.combinerCache.Exists(subscriber, alert) { continue } // add the alert to the cache to prevent duplicate processing - cf.combinerCache.Set(subscription, alert) + cf.combinerCache.Set(subscriber, alert) // create an AlertEvent object to pass to each alert handler alertCA, err := time.Parse(time.RFC3339, alert.Alert.CreatedAt) @@ -300,7 +328,7 @@ func (cf *combinerFeed) processAlerts(_ context.Context, logger *log.Entry, aler Feed: time.Now().UTC(), SourceAlert: alertCA, }, - Subscriber: subscription.Subscriber, + Subscriber: subscriber, } // call each alert handler with the AlertEvent object @@ -410,6 +438,7 @@ func NewCombinerFeedWithClient(ctx context.Context, cfg CombinerFeedConfig, clie if cfg.QueryInterval > 0 && cfg.QueryInterval < uint64(DefaultRatelimitDuration.Milliseconds()) { rateLimit = time.NewTicker(time.Millisecond * time.Duration(cfg.QueryInterval)) } + bf := &combinerFeed{ maxAlertAge: time.Minute * 20, ctx: ctx, @@ -419,6 +448,7 @@ func NewCombinerFeedWithClient(ctx context.Context, cfg CombinerFeedConfig, clie botSubscriptions: []*domain.CombinerBotSubscription{}, cfg: cfg, combinerCache: c, + batchSize: DefaultBatchSize, } return bf, nil diff --git a/feeds/combiner_cache.go b/feeds/combiner_cache.go index 340c7a7d..efc79b3d 100644 --- a/feeds/combiner_cache.go +++ b/feeds/combiner_cache.go @@ -53,13 +53,13 @@ func newCombinerCache(path string) (*combinerCache, error) { return &combinerCache{cache: alertCache, path: path}, nil } -func (c *combinerCache) Exists(subscription *domain.CombinerBotSubscription, alert *protocol.AlertEvent) bool { - _, exists := c.cache.Get(encodeAlertCacheKey(subscription.Subscriber.BotID, subscription.Subscriber.BotImage, alert.Alert.Hash)) +func (c *combinerCache) Exists(subscriber *domain.Subscriber, alert *protocol.AlertEvent) bool { + _, exists := c.cache.Get(encodeAlertCacheKey(subscriber.BotID, subscriber.BotImage, alert.Alert.Hash)) return exists } -func (c *combinerCache) Set(subscription *domain.CombinerBotSubscription, alert *protocol.AlertEvent) { - c.cache.Set(encodeAlertCacheKey(subscription.Subscriber.BotID, subscription.Subscriber.BotImage, alert.Alert.Hash), struct{}{}, cache.DefaultExpiration) +func (c *combinerCache) Set(subscriber *domain.Subscriber, alert *protocol.AlertEvent) { + c.cache.Set(encodeAlertCacheKey(subscriber.BotID, subscriber.BotImage, alert.Alert.Hash), struct{}{}, cache.DefaultExpiration) } // DumpToFile dumps the current cache into a file in JSON format, so that the cache can be used in a persistent way. diff --git a/feeds/combiner_test.go b/feeds/combiner_test.go index 294817f4..c8a42696 100644 --- a/feeds/combiner_test.go +++ b/feeds/combiner_test.go @@ -2,6 +2,7 @@ package feeds import ( "context" + "fmt" "sync" "testing" "time" @@ -12,89 +13,114 @@ import ( "github.com/forta-network/forta-core-go/protocol" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_combinerFeed_Start(t *testing.T) { - type args struct { - rate uint64 - stopAfterFirstAlert bool - expectErr error - } - subscriberBot := "0xsubscriber" subscribeeBot := "0xsubscribee" + rate := uint64(time.Second.Milliseconds()) ctrl := gomock.NewController(t) - successfulMockClient := mock_graphql.NewMockClient(ctrl) - successfulMockClient.EXPECT().GetAlerts(gomock.Any(), gomock.Any(), gomock.Any()).Return( - []*protocol.AlertEvent{ - { - Alert: &protocol.AlertEvent_Alert{ - Hash: "0xaaaaa", - CreatedAt: time.Now().Format(time.RFC3339), - Source: &protocol.AlertEvent_Alert_Source{ - Bot: &protocol.AlertEvent_Alert_Bot{Id: subscribeeBot}, - }, + successfulAlertResponse := []*protocol.AlertEvent{ + { + Alert: &protocol.AlertEvent_Alert{ + Hash: "0xaaaaa", + CreatedAt: time.Now().Format(time.RFC3339), + Source: &protocol.AlertEvent_Alert_Source{ + Bot: &protocol.AlertEvent_Alert_Bot{Id: subscribeeBot}, }, }, - }, nil, - ) + }, + } - tests := []struct { - name string - args args - client graphql.Client - }{ - { - name: "successfully feeds alerts", - args: args{ - rate: uint64(time.Second.Milliseconds()), - stopAfterFirstAlert: true, - expectErr: context.Canceled, + var subscriptions []*domain.CombinerBotSubscription + for i := 0; i < 20; i++ { + subscriptions = append(subscriptions, &domain.CombinerBotSubscription{ + Subscription: &protocol.CombinerBotSubscription{ + BotId: fmt.Sprintf("0xbot%d", i), }, - client: successfulMockClient, - }, + Subscriber: &domain.Subscriber{BotID: subscriberBot, BotOwner: "0x", BotImage: "0x123"}, + }) } - for _, tt := range tests { - t.Run( - tt.name, func(t *testing.T) { - r := require.New(t) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() + // + // Test Case 1: can fetch subscriptions without any issues + // Setup: 1 subscriber bot, has 20 subscriptions + // Batch size is 10, meaning there should be 2 GetAlertsBatch requests + successfulMockClient := mock_graphql.NewMockClient(ctrl) + successfulMockClient.EXPECT().GetAlertsBatch(gomock.Any(), gomock.Any(), gomock.Any()). + Return(successfulAlertResponse, nil). + Times(2) - cf, err := NewCombinerFeedWithClient( - ctx, CombinerFeedConfig{ - QueryInterval: tt.args.rate, - }, tt.client, - ) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*50) + defer cancel() - r.NoError(err) + cf, err := NewCombinerFeedWithClient( + ctx, CombinerFeedConfig{ + QueryInterval: rate, + }, successfulMockClient, + ) - err = cf.AddSubscription( - &domain.CombinerBotSubscription{ - Subscription: &protocol.CombinerBotSubscription{ - BotId: subscribeeBot, - }, - Subscriber: &domain.Subscriber{BotID: subscriberBot, BotOwner: "0x", BotImage: "0x123"}, - }, - ) - r.NoError(err) - - errCh := cf.RegisterHandler( - func(evt *domain.AlertEvent) error { - if tt.args.stopAfterFirstAlert { - cancel() - } - return nil - }, - ) - cf.Start() - r.Error(tt.args.expectErr, <-errCh) + assert.NoError(t, err) + for _, subscription := range subscriptions { + err = cf.AddSubscription(subscription) + assert.NoError(t, err) + } + + errCh := cf.RegisterHandler( + func(evt *domain.AlertEvent) error { + cancel() + return nil + }, + ) + cf.Start() + assert.Error(t, context.Canceled, <-errCh) + + // + // + // Test Case 2: Retries in smaller chunks if there is a response size error + // Test Setup: 1 bot with 20 subscriptions + // Batch size is 10, there should be 2 requests. However, first request fails due to response size + // Resulting in 4 GetAlertsBatch calls. + // + // + ctrl = gomock.NewController(t) + responseTooBigClient := mock_graphql.NewMockClient(ctrl) + responseTooBigClient.EXPECT().GetAlertsBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return( + nil, graphql.ErrResponseSizeTooBig, + ) + responseTooBigClient.EXPECT().GetAlertsBatch(gomock.Any(), gomock.Any(), gomock.Any()).Return( + successfulAlertResponse, nil, + ).Times(3) + + ctxTooBig, cancelTooBig := context.WithTimeout(context.Background(), time.Second*50) + defer cancelTooBig() + + cfTooBig, err := NewCombinerFeedWithClient( + ctxTooBig, CombinerFeedConfig{ + QueryInterval: rate, + }, responseTooBigClient, + ) + assert.NoError(t, err) + + for i := 0; i < 20; i++ { + err = cfTooBig.AddSubscription(&domain.CombinerBotSubscription{ + Subscription: &protocol.CombinerBotSubscription{ + BotId: fmt.Sprintf("0xbot%d", i), }, - ) + Subscriber: &domain.Subscriber{BotID: subscriberBot, BotOwner: "0x", BotImage: "0x123"}, + }) + assert.NoError(t, err) } + + errChTooBig := cfTooBig.RegisterHandler( + func(evt *domain.AlertEvent) error { + cancelTooBig() + return nil + }, + ) + cfTooBig.Start() + assert.Error(t, context.Canceled, <-errChTooBig) } func Test_combinerFeed_AddSubscription(t *testing.T) { diff --git a/go.mod b/go.mod index f9cc680d..e0cc8192 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 + github.com/vektah/gqlparser/v2 v2.4.5 github.com/wealdtech/go-ens/v3 v3.5.2 golang.org/x/sync v0.1.0 google.golang.org/grpc v1.47.0 @@ -209,7 +210,6 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/tklauser/go-sysconf v0.3.9 // indirect github.com/tklauser/numcpus v0.4.0 // indirect - github.com/vektah/gqlparser/v2 v2.4.5 // indirect github.com/wI2L/jsondiff v0.2.0 // indirect github.com/wealdtech/go-multicodec v1.4.0 // indirect github.com/whyrusleeping/base32 v0.0.0-20170828182744-c30ac30633cc // indirect