Skip to content

Commit

Permalink
Add support for batch query (#219)
Browse files Browse the repository at this point in the history
Implements batch querying and unit tests, which will also be used in the
next PR for catch-up of historical messages when streaming. Also got rid
of our use of NullInt32 in queries, which simplifies the go code a
little.

For now we let up to 10k topics be specified in these queries - we can
do proper benchmarking and SQL analyzing later to see how we should
adjust it.

Closes #218
  • Loading branch information
richardhuaaa authored Oct 15, 2024
1 parent 4047800 commit 53e0598
Show file tree
Hide file tree
Showing 14 changed files with 173 additions and 104 deletions.
2 changes: 1 addition & 1 deletion dev/generate
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
set -euo pipefail

./dev/gen_protos
sqlc generate
go generate ./...
rm -rf pkg/mocks/*
./dev/abigen
mockery
sqlc generate
94 changes: 88 additions & 6 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
Expand Down Expand Up @@ -116,15 +117,15 @@ func TestQueryEnvelopesByOriginator(t *testing.T) {
}

func TestQueryEnvelopesByTopic(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)
db_rows := setupQueryTest(t, store)

resp, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: [][]byte{[]byte("topicA")},
Topics: []db.Topic{db.Topic("topicA")},
LastSeen: nil,
},
Limit: 0,
Expand Down Expand Up @@ -152,16 +153,79 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
checkRowsMatchProtos(t, db_rows, []int{1, 3, 4}, resp.GetEnvelopes())
}

func TestQueryTopicFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

resp, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA")},
LastSeen: &message_api.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1},
},
},
Limit: 0,
},
)
require.NoError(t, err)
checkRowsMatchProtos(t, db_rows, []int{4}, resp.GetEnvelopes())
}

func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

resp, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), db.Topic("topicB")},
LastSeen: &message_api.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1},
},
},
Limit: 0,
},
)
require.NoError(t, err)
checkRowsMatchProtos(t, db_rows, []int{3, 4}, resp.GetEnvelopes())
}

func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

resp, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
OriginatorNodeIds: []uint32{1, 2},
LastSeen: &message_api.VectorClock{
NodeIdToSequenceId: map[uint32]uint64{1: 1, 2: 1},
},
},
Limit: 0,
},
)
require.NoError(t, err)
checkRowsMatchProtos(t, db_rows, []int{2, 3, 4}, resp.GetEnvelopes())
}

func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)
db_rows := setupQueryTest(t, store)

resp, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: [][]byte{[]byte("topicC")},
Topics: []db.Topic{db.Topic("topicC")},
},
Limit: 0,
},
Expand All @@ -170,6 +234,24 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
checkRowsMatchProtos(t, db_rows, []int{}, resp.GetEnvelopes())
}

func TestInvalidQuery(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()
_ = setupQueryTest(t, store)

_, err := api.QueryEnvelopes(
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA")},
OriginatorNodeIds: []uint32{1},
},
Limit: 0,
},
)
require.Error(t, err)
}

func checkRowsMatchProtos(
t *testing.T,
allRows []queries.InsertGatewayEnvelopeParams,
Expand Down
42 changes: 17 additions & 25 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,46 +176,38 @@ func (s *Service) validateQuery(
}
}

vc := query.GetLastSeen().GetNodeIdToSequenceId()
if len(vc) > maxVectorClockLength {
return fmt.Errorf(
"vector clock length exceeds maximum of %d",
maxVectorClockLength,
)
}

return nil
}

func (s *Service) queryReqToDBParams(
req *message_api.QueryEnvelopesRequest,
) (*queries.SelectGatewayEnvelopesParams, error) {
params := queries.SelectGatewayEnvelopesParams{
Topic: nil,
OriginatorNodeID: sql.NullInt32{},
RowLimit: sql.NullInt32{},
CursorNodeIds: nil,
CursorSequenceIds: nil,
}

query := req.GetQuery()
if err := s.validateQuery(query); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid query: %v", err)
}

// TODO(rich): Properly support batch queries
if len(query.GetTopics()) > 0 {
params.Topic = query.GetTopics()[0]
} else if len(query.GetOriginatorNodeIds()) > 0 {
params.OriginatorNodeID = db.NullInt32(int32(query.GetOriginatorNodeIds()[0]))
params := queries.SelectGatewayEnvelopesParams{
Topics: query.GetTopics(),
OriginatorNodeIds: make([]int32, 0, len(query.GetOriginatorNodeIds())),
RowLimit: int32(req.GetLimit()),
CursorNodeIds: nil,
CursorSequenceIds: nil,
}

vc := query.GetLastSeen().GetNodeIdToSequenceId()
if len(vc) > maxVectorClockLength {
return nil, status.Errorf(
codes.InvalidArgument,
"vector clock length exceeds maximum of %d",
maxVectorClockLength,
)
for _, o := range query.GetOriginatorNodeIds() {
params.OriginatorNodeIds = append(params.OriginatorNodeIds, int32(o))
}
db.SetVectorClock(&params, vc)

limit := req.GetLimit()
if limit > 0 && limit <= maxRequestedRows {
params.RowLimit = db.NullInt32(int32(limit))
}
db.SetVectorClock(&params, query.GetLastSeen().GetNodeIdToSequenceId())

return &params, nil
}
Expand Down
19 changes: 10 additions & 9 deletions pkg/api/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api"
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
Expand Down Expand Up @@ -127,24 +128,24 @@ func TestSubscribeEnvelopesAll(t *testing.T) {
}

func TestSubscribeEnvelopesByTopic(t *testing.T) {
client, db, cleanup := setupTest(t)
client, store, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)
insertInitialRows(t, store)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := client.SubscribeEnvelopes(
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: [][]byte{[]byte("topicA"), []byte("topicC")},
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
LastSeen: nil,
},
},
)
require.NoError(t, err)

insertAdditionalRows(t, db)
insertAdditionalRows(t, store)
validateUpdates(t, stream, []int{4})
}

Expand All @@ -171,9 +172,9 @@ func TestSubscribeEnvelopesByOriginator(t *testing.T) {
}

func TestSimultaneousSubscriptions(t *testing.T) {
client, db, cleanup := setupTest(t)
client, store, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)
insertInitialRows(t, store)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -189,7 +190,7 @@ func TestSimultaneousSubscriptions(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: [][]byte{[]byte("topicB")},
Topics: []db.Topic{db.Topic("topicB")},
LastSeen: nil,
},
},
Expand All @@ -207,7 +208,7 @@ func TestSimultaneousSubscriptions(t *testing.T) {
)
require.NoError(t, err)

insertAdditionalRows(t, db)
insertAdditionalRows(t, store)
validateUpdates(t, stream1, []int{2, 3, 4})
validateUpdates(t, stream2, []int{2, 3})
validateUpdates(t, stream3, []int{3})
Expand All @@ -221,7 +222,7 @@ func TestSubscribeEnvelopesInvalidRequest(t *testing.T) {
context.Background(),
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: [][]byte{[]byte("topicA")},
Topics: []db.Topic{db.Topic("topicA")},
OriginatorNodeIds: []uint32{1},
LastSeen: nil,
},
Expand Down
25 changes: 2 additions & 23 deletions pkg/db/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,10 @@ ON CONFLICT
DO NOTHING;

-- name: SelectGatewayEnvelopes :many
WITH cursors AS (
SELECT
UNNEST(@cursor_node_ids::INT[]) AS cursor_node_id,
UNNEST(@cursor_sequence_ids::BIGINT[]) AS cursor_sequence_id
)
SELECT
gateway_envelopes.*
*
FROM
gateway_envelopes
-- Assumption: There is only one cursor per node ID. Caller must verify this
LEFT JOIN cursors ON gateway_envelopes.originator_node_id = cursors.cursor_node_id
WHERE (sqlc.narg('topic')::BYTEA IS NULL
OR length(@topic) = 0
OR topic = @topic)
AND (sqlc.narg('originator_node_id')::INT IS NULL
OR originator_node_id = @originator_node_id)
AND (cursor_sequence_id IS NULL
OR originator_sequence_id > cursor_sequence_id)
ORDER BY
-- Assumption: envelopes are inserted in sequence_id order per originator, therefore
-- gateway_time preserves sequence_id order
gateway_time,
originator_node_id,
originator_sequence_id ASC
LIMIT sqlc.narg('row_limit')::INT;
select_gateway_envelopes(@cursor_node_ids::INT[], @cursor_sequence_ids::BIGINT[], @topics::BYTEA[], @originator_node_ids::INT[], @row_limit::INT);

-- name: InsertStagedOriginatorEnvelope :one
SELECT
Expand Down
37 changes: 8 additions & 29 deletions pkg/db/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/db/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func envelopesQuery(store *sql.DB) db.PollableDBQuery[queries.GatewayEnvelope, d
return func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) {
envs, err := queries.New(store).
SelectGatewayEnvelopes(ctx, *db.SetVectorClock(&queries.SelectGatewayEnvelopesParams{
OriginatorNodeID: db.NullInt32(1),
RowLimit: db.NullInt32(numRows),
OriginatorNodeIds: []int32{1},
RowLimit: numRows,
}, lastSeen))
if err != nil {
return nil, lastSeen, err
Expand Down
1 change: 1 addition & 0 deletions pkg/db/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
)

type VectorClock = map[uint32]uint64
type Topic = []byte

func NullInt32(v int32) sql.NullInt32 {
return sql.NullInt32{Int32: v, Valid: true}
Expand Down
Loading

0 comments on commit 53e0598

Please sign in to comment.