diff --git a/dev/generate b/dev/generate index c4ce285b..eb85f658 100755 --- a/dev/generate +++ b/dev/generate @@ -3,8 +3,8 @@ set -euo pipefail ./dev/gen_protos +sqlc generate go generate ./... rm -rf pkg/mocks/* ./dev/abigen mockery -sqlc generate diff --git a/pkg/api/query_test.go b/pkg/api/query_test.go index 55cf6a1c..bd7c232f 100644 --- a/pkg/api/query_test.go +++ b/pkg/api/query_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/testutils" @@ -116,15 +117,15 @@ func TestQueryEnvelopesByOriginator(t *testing.T) { } func TestQueryEnvelopesByTopic(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) defer cleanup() - db_rows := setupQueryTest(t, db) + db_rows := setupQueryTest(t, store) resp, err := api.QueryEnvelopes( context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: [][]byte{[]byte("topicA")}, + Topics: []db.Topic{db.Topic("topicA")}, LastSeen: nil, }, Limit: 0, @@ -152,16 +153,79 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) { checkRowsMatchProtos(t, db_rows, []int{1, 3, 4}, resp.GetEnvelopes()) } +func TestQueryTopicFromLastSeen(t *testing.T) { + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + defer cleanup() + db_rows := setupQueryTest(t, store) + + resp, err := api.QueryEnvelopes( + context.Background(), + &message_api.QueryEnvelopesRequest{ + Query: &message_api.EnvelopesQuery{ + Topics: []db.Topic{db.Topic("topicA")}, + LastSeen: &message_api.VectorClock{ + NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1}, + }, + }, + Limit: 0, + }, + ) + require.NoError(t, err) + checkRowsMatchProtos(t, db_rows, []int{4}, resp.GetEnvelopes()) +} + +func TestQueryMultipleTopicsFromLastSeen(t *testing.T) { + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + defer cleanup() + db_rows := setupQueryTest(t, store) + + resp, err := api.QueryEnvelopes( + context.Background(), + &message_api.QueryEnvelopesRequest{ + Query: &message_api.EnvelopesQuery{ + Topics: []db.Topic{db.Topic("topicA"), db.Topic("topicB")}, + LastSeen: &message_api.VectorClock{ + NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1}, + }, + }, + Limit: 0, + }, + ) + require.NoError(t, err) + checkRowsMatchProtos(t, db_rows, []int{3, 4}, resp.GetEnvelopes()) +} + +func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) { + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + defer cleanup() + db_rows := setupQueryTest(t, store) + + resp, err := api.QueryEnvelopes( + context.Background(), + &message_api.QueryEnvelopesRequest{ + Query: &message_api.EnvelopesQuery{ + OriginatorNodeIds: []uint32{1, 2}, + LastSeen: &message_api.VectorClock{ + NodeIdToSequenceId: map[uint32]uint64{1: 1, 2: 1}, + }, + }, + Limit: 0, + }, + ) + require.NoError(t, err) + checkRowsMatchProtos(t, db_rows, []int{2, 3, 4}, resp.GetEnvelopes()) +} + func TestQueryEnvelopesWithEmptyResult(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) defer cleanup() - db_rows := setupQueryTest(t, db) + db_rows := setupQueryTest(t, store) resp, err := api.QueryEnvelopes( context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: [][]byte{[]byte("topicC")}, + Topics: []db.Topic{db.Topic("topicC")}, }, Limit: 0, }, @@ -170,6 +234,24 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) { checkRowsMatchProtos(t, db_rows, []int{}, resp.GetEnvelopes()) } +func TestInvalidQuery(t *testing.T) { + api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + defer cleanup() + _ = setupQueryTest(t, store) + + _, err := api.QueryEnvelopes( + context.Background(), + &message_api.QueryEnvelopesRequest{ + Query: &message_api.EnvelopesQuery{ + Topics: []db.Topic{db.Topic("topicA")}, + OriginatorNodeIds: []uint32{1}, + }, + Limit: 0, + }, + ) + require.Error(t, err) +} + func checkRowsMatchProtos( t *testing.T, allRows []queries.InsertGatewayEnvelopeParams, diff --git a/pkg/api/service.go b/pkg/api/service.go index cd343afc..5ba3dbbc 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -176,46 +176,38 @@ func (s *Service) validateQuery( } } + vc := query.GetLastSeen().GetNodeIdToSequenceId() + if len(vc) > maxVectorClockLength { + return fmt.Errorf( + "vector clock length exceeds maximum of %d", + maxVectorClockLength, + ) + } + return nil } func (s *Service) queryReqToDBParams( req *message_api.QueryEnvelopesRequest, ) (*queries.SelectGatewayEnvelopesParams, error) { - params := queries.SelectGatewayEnvelopesParams{ - Topic: nil, - OriginatorNodeID: sql.NullInt32{}, - RowLimit: sql.NullInt32{}, - CursorNodeIds: nil, - CursorSequenceIds: nil, - } - query := req.GetQuery() if err := s.validateQuery(query); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid query: %v", err) } - // TODO(rich): Properly support batch queries - if len(query.GetTopics()) > 0 { - params.Topic = query.GetTopics()[0] - } else if len(query.GetOriginatorNodeIds()) > 0 { - params.OriginatorNodeID = db.NullInt32(int32(query.GetOriginatorNodeIds()[0])) + params := queries.SelectGatewayEnvelopesParams{ + Topics: query.GetTopics(), + OriginatorNodeIds: make([]int32, 0, len(query.GetOriginatorNodeIds())), + RowLimit: int32(req.GetLimit()), + CursorNodeIds: nil, + CursorSequenceIds: nil, } - vc := query.GetLastSeen().GetNodeIdToSequenceId() - if len(vc) > maxVectorClockLength { - return nil, status.Errorf( - codes.InvalidArgument, - "vector clock length exceeds maximum of %d", - maxVectorClockLength, - ) + for _, o := range query.GetOriginatorNodeIds() { + params.OriginatorNodeIds = append(params.OriginatorNodeIds, int32(o)) } - db.SetVectorClock(¶ms, vc) - limit := req.GetLimit() - if limit > 0 && limit <= maxRequestedRows { - params.RowLimit = db.NullInt32(int32(limit)) - } + db.SetVectorClock(¶ms, query.GetLastSeen().GetNodeIdToSequenceId()) return ¶ms, nil } diff --git a/pkg/api/subscribe_test.go b/pkg/api/subscribe_test.go index 851eed6b..af6a1490 100644 --- a/pkg/api/subscribe_test.go +++ b/pkg/api/subscribe_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/api" + "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/testutils" @@ -127,9 +128,9 @@ func TestSubscribeEnvelopesAll(t *testing.T) { } func TestSubscribeEnvelopesByTopic(t *testing.T) { - client, db, cleanup := setupTest(t) + client, store, cleanup := setupTest(t) defer cleanup() - insertInitialRows(t, db) + insertInitialRows(t, store) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -137,14 +138,14 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: [][]byte{[]byte("topicA"), []byte("topicC")}, + Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")}, LastSeen: nil, }, }, ) require.NoError(t, err) - insertAdditionalRows(t, db) + insertAdditionalRows(t, store) validateUpdates(t, stream, []int{4}) } @@ -171,9 +172,9 @@ func TestSubscribeEnvelopesByOriginator(t *testing.T) { } func TestSimultaneousSubscriptions(t *testing.T) { - client, db, cleanup := setupTest(t) + client, store, cleanup := setupTest(t) defer cleanup() - insertInitialRows(t, db) + insertInitialRows(t, store) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -189,7 +190,7 @@ func TestSimultaneousSubscriptions(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: [][]byte{[]byte("topicB")}, + Topics: []db.Topic{db.Topic("topicB")}, LastSeen: nil, }, }, @@ -207,7 +208,7 @@ func TestSimultaneousSubscriptions(t *testing.T) { ) require.NoError(t, err) - insertAdditionalRows(t, db) + insertAdditionalRows(t, store) validateUpdates(t, stream1, []int{2, 3, 4}) validateUpdates(t, stream2, []int{2, 3}) validateUpdates(t, stream3, []int{3}) @@ -221,7 +222,7 @@ func TestSubscribeEnvelopesInvalidRequest(t *testing.T) { context.Background(), &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: [][]byte{[]byte("topicA")}, + Topics: []db.Topic{db.Topic("topicA")}, OriginatorNodeIds: []uint32{1}, LastSeen: nil, }, diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index b6c5e457..69cc2b20 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -19,31 +19,10 @@ ON CONFLICT DO NOTHING; -- name: SelectGatewayEnvelopes :many -WITH cursors AS ( - SELECT - UNNEST(@cursor_node_ids::INT[]) AS cursor_node_id, - UNNEST(@cursor_sequence_ids::BIGINT[]) AS cursor_sequence_id -) SELECT - gateway_envelopes.* + * FROM - gateway_envelopes - -- Assumption: There is only one cursor per node ID. Caller must verify this - LEFT JOIN cursors ON gateway_envelopes.originator_node_id = cursors.cursor_node_id -WHERE (sqlc.narg('topic')::BYTEA IS NULL - OR length(@topic) = 0 - OR topic = @topic) -AND (sqlc.narg('originator_node_id')::INT IS NULL - OR originator_node_id = @originator_node_id) -AND (cursor_sequence_id IS NULL - OR originator_sequence_id > cursor_sequence_id) -ORDER BY - -- Assumption: envelopes are inserted in sequence_id order per originator, therefore - -- gateway_time preserves sequence_id order - gateway_time, - originator_node_id, - originator_sequence_id ASC -LIMIT sqlc.narg('row_limit')::INT; + select_gateway_envelopes(@cursor_node_ids::INT[], @cursor_sequence_ids::BIGINT[], @topics::BYTEA[], @originator_node_ids::INT[], @row_limit::INT); -- name: InsertStagedOriginatorEnvelope :one SELECT diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index adbe9133..d737282e 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -212,48 +212,27 @@ func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFro } const selectGatewayEnvelopes = `-- name: SelectGatewayEnvelopes :many -WITH cursors AS ( - SELECT - UNNEST($4::INT[]) AS cursor_node_id, - UNNEST($5::BIGINT[]) AS cursor_sequence_id -) SELECT - gateway_envelopes.gateway_time, gateway_envelopes.originator_node_id, gateway_envelopes.originator_sequence_id, gateway_envelopes.topic, gateway_envelopes.originator_envelope + gateway_time, originator_node_id, originator_sequence_id, topic, originator_envelope FROM - gateway_envelopes - -- Assumption: There is only one cursor per node ID. Caller must verify this - LEFT JOIN cursors ON gateway_envelopes.originator_node_id = cursors.cursor_node_id -WHERE ($1::BYTEA IS NULL - OR length($1) = 0 - OR topic = $1) -AND ($2::INT IS NULL - OR originator_node_id = $2) -AND (cursor_sequence_id IS NULL - OR originator_sequence_id > cursor_sequence_id) -ORDER BY - -- Assumption: envelopes are inserted in sequence_id order per originator, therefore - -- gateway_time preserves sequence_id order - gateway_time, - originator_node_id, - originator_sequence_id ASC -LIMIT $3::INT + select_gateway_envelopes($1::INT[], $2::BIGINT[], $3::BYTEA[], $4::INT[], $5::INT) ` type SelectGatewayEnvelopesParams struct { - Topic []byte - OriginatorNodeID sql.NullInt32 - RowLimit sql.NullInt32 CursorNodeIds []int32 CursorSequenceIds []int64 + Topics [][]byte + OriginatorNodeIds []int32 + RowLimit int32 } func (q *Queries) SelectGatewayEnvelopes(ctx context.Context, arg SelectGatewayEnvelopesParams) ([]GatewayEnvelope, error) { rows, err := q.db.QueryContext(ctx, selectGatewayEnvelopes, - arg.Topic, - arg.OriginatorNodeID, - arg.RowLimit, pq.Array(arg.CursorNodeIds), pq.Array(arg.CursorSequenceIds), + pq.Array(arg.Topics), + pq.Array(arg.OriginatorNodeIds), + arg.RowLimit, ) if err != nil { return nil, err diff --git a/pkg/db/subscription_test.go b/pkg/db/subscription_test.go index 42e52057..03c8ff6f 100644 --- a/pkg/db/subscription_test.go +++ b/pkg/db/subscription_test.go @@ -44,8 +44,8 @@ func envelopesQuery(store *sql.DB) db.PollableDBQuery[queries.GatewayEnvelope, d return func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) { envs, err := queries.New(store). SelectGatewayEnvelopes(ctx, *db.SetVectorClock(&queries.SelectGatewayEnvelopesParams{ - OriginatorNodeID: db.NullInt32(1), - RowLimit: db.NullInt32(numRows), + OriginatorNodeIds: []int32{1}, + RowLimit: numRows, }, lastSeen)) if err != nil { return nil, lastSeen, err diff --git a/pkg/db/types.go b/pkg/db/types.go index 59789305..a0b4dc93 100644 --- a/pkg/db/types.go +++ b/pkg/db/types.go @@ -7,6 +7,7 @@ import ( ) type VectorClock = map[uint32]uint64 +type Topic = []byte func NullInt32(v int32) sql.NullInt32 { return sql.NullInt32{Int32: v, Valid: true} diff --git a/pkg/indexer/e2e_test.go b/pkg/indexer/e2e_test.go index 935109c7..e6643f17 100644 --- a/pkg/indexer/e2e_test.go +++ b/pkg/indexer/e2e_test.go @@ -67,7 +67,7 @@ func TestStoreMessages(t *testing.T) { envelopes, err := querier.SelectGatewayEnvelopes( context.Background(), queries.SelectGatewayEnvelopesParams{ - Topic: topic, + Topics: [][]byte{topic}, }, ) require.NoError(t, err) diff --git a/pkg/indexer/storer/groupMessage_test.go b/pkg/indexer/storer/groupMessage_test.go index 4d9241b7..5d425996 100644 --- a/pkg/indexer/storer/groupMessage_test.go +++ b/pkg/indexer/storer/groupMessage_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/abis" "github.com/xmtp/xmtpd/pkg/blockchain" - "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" "github.com/xmtp/xmtpd/pkg/utils" @@ -57,7 +56,7 @@ func TestStoreGroupMessages(t *testing.T) { envelopes, queryErr := storer.queries.SelectGatewayEnvelopes( ctx, - queries.SelectGatewayEnvelopesParams{OriginatorNodeID: db.NullInt32(0)}, + queries.SelectGatewayEnvelopesParams{OriginatorNodeIds: []int32{0}}, ) require.NoError(t, queryErr) @@ -93,7 +92,7 @@ func TestStoreGroupMessageDuplicate(t *testing.T) { envelopes, queryErr := storer.queries.SelectGatewayEnvelopes( ctx, - queries.SelectGatewayEnvelopesParams{OriginatorNodeID: db.NullInt32(0)}, + queries.SelectGatewayEnvelopesParams{OriginatorNodeIds: []int32{0}}, ) require.NoError(t, queryErr) diff --git a/pkg/indexer/storer/identityUpdate.go b/pkg/indexer/storer/identityUpdate.go index 831aea09..8b79d97a 100644 --- a/pkg/indexer/storer/identityUpdate.go +++ b/pkg/indexer/storer/identityUpdate.go @@ -195,9 +195,9 @@ func (s *IdentityUpdateStorer) validateIdentityUpdate( gatewayEnvelopes, err := querier.SelectGatewayEnvelopes( ctx, queries.SelectGatewayEnvelopesParams{ - Topic: []byte(BuildInboxTopic(inboxId)), - OriginatorNodeID: sql.NullInt32{Int32: IDENTITY_UPDATE_ORIGINATOR_ID, Valid: true}, - RowLimit: sql.NullInt32{Int32: 256, Valid: true}, + Topics: []db.Topic{db.Topic(BuildInboxTopic(inboxId))}, + OriginatorNodeIds: []int32{IDENTITY_UPDATE_ORIGINATOR_ID}, + RowLimit: 256, }, ) if err != nil { diff --git a/pkg/indexer/storer/identityUpdate_test.go b/pkg/indexer/storer/identityUpdate_test.go index f4f28bfc..aa44df0e 100644 --- a/pkg/indexer/storer/identityUpdate_test.go +++ b/pkg/indexer/storer/identityUpdate_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/abis" "github.com/xmtp/xmtpd/pkg/blockchain" - "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/mlsvalidate" mlsvalidateMock "github.com/xmtp/xmtpd/pkg/mocks/mlsvalidate" @@ -83,7 +82,8 @@ func TestStoreIdentityUpdate(t *testing.T) { envelopes, queryErr := querier.SelectGatewayEnvelopes( ctx, queries.SelectGatewayEnvelopesParams{ - OriginatorNodeID: db.NullInt32(IDENTITY_UPDATE_ORIGINATOR_ID), + OriginatorNodeIds: []int32{IDENTITY_UPDATE_ORIGINATOR_ID}, + RowLimit: 10, }, ) require.NoError(t, queryErr) diff --git a/pkg/migrations/00002_select-gateway-envelopes.down.sql b/pkg/migrations/00002_select-gateway-envelopes.down.sql new file mode 100644 index 00000000..31652d99 --- /dev/null +++ b/pkg/migrations/00002_select-gateway-envelopes.down.sql @@ -0,0 +1,2 @@ +DROP FUNCTION select_gateway_envelopes; + diff --git a/pkg/migrations/00002_select-gateway-envelopes.up.sql b/pkg/migrations/00002_select-gateway-envelopes.up.sql new file mode 100644 index 00000000..44a97144 --- /dev/null +++ b/pkg/migrations/00002_select-gateway-envelopes.up.sql @@ -0,0 +1,34 @@ +-- pgFormatter-ignore +CREATE FUNCTION select_gateway_envelopes(cursor_node_ids INT[], cursor_sequence_ids BIGINT[], topics BYTEA[], originator_node_ids INT[], row_limit INT) + RETURNS SETOF gateway_envelopes + AS $$ +DECLARE + num_topics INT := COALESCE(ARRAY_LENGTH(topics, 1), 0); + num_originators INT := COALESCE(ARRAY_LENGTH(originator_node_ids, 1), 0); +BEGIN + RETURN QUERY + WITH cursors AS ( + SELECT + UNNEST(cursor_node_ids) AS cursor_node_id, + UNNEST(cursor_sequence_ids) AS cursor_sequence_id + ) + SELECT + gateway_envelopes.* + FROM + gateway_envelopes + -- Assumption: There is only one cursor per node ID. Caller must verify this + LEFT JOIN cursors ON gateway_envelopes.originator_node_id = cursors.cursor_node_id + WHERE (num_topics = 0 OR topic = ANY (topics)) + AND (num_originators = 0 OR originator_node_id = ANY (originator_node_ids)) + AND originator_sequence_id > COALESCE(cursor_sequence_id, 0) + ORDER BY + -- Assumption: envelopes are inserted in sequence_id order per originator, therefore + -- gateway_time preserves sequence_id order + gateway_time, + originator_node_id, + originator_sequence_id ASC + LIMIT NULLIF(row_limit, 0); +END; +$$ +LANGUAGE plpgsql; +