From 2b75efbf7470f6a7fe7a1dffbd4ebce2790a95e2 Mon Sep 17 00:00:00 2001 From: Richard Hua Date: Sun, 25 Aug 2024 22:33:02 -0700 Subject: [PATCH] Add publish worker --- pkg/api/publishWorker.go | 137 ++++++++++++++++++++++++ pkg/api/service.go | 97 ++++++++++++++--- pkg/api/service_test.go | 102 ++++++++++++++++-- pkg/db/queries.sql | 2 +- pkg/db/queries/models.go | 1 + pkg/db/queries/queries.sql.go | 29 +++-- pkg/migrations/00001_init-schema.up.sql | 18 ++-- pkg/registrant/registrant.go | 4 + 8 files changed, 346 insertions(+), 44 deletions(-) create mode 100644 pkg/api/publishWorker.go diff --git a/pkg/api/publishWorker.go b/pkg/api/publishWorker.go new file mode 100644 index 00000000..c07c0caf --- /dev/null +++ b/pkg/api/publishWorker.go @@ -0,0 +1,137 @@ +package api + +import ( + "context" + "database/sql" + "time" + + "github.com/xmtp/xmtpd/pkg/db" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/registrant" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" +) + +type PublishWorker struct { + ctx context.Context + log *zap.Logger + listener <-chan []queries.StagedOriginatorEnvelope + registrant *registrant.Registrant + store *sql.DB + subscription db.DBSubscription[queries.StagedOriginatorEnvelope] +} + +func StartPublishWorker( + ctx context.Context, + log *zap.Logger, + reg *registrant.Registrant, + store *sql.DB, + notifier <-chan bool, +) (*PublishWorker, error) { + query := func(ctx context.Context, lastSeenID int64, numRows int32) ([]queries.StagedOriginatorEnvelope, int64, error) { + results, err := queries.New(store).SelectStagedOriginatorEnvelopes( + ctx, + queries.SelectStagedOriginatorEnvelopesParams{ + LastSeenID: lastSeenID, + NumRows: numRows, + }, + ) + if err != nil { + return nil, 0, err + } + if len(results) > 0 { + lastSeenID = results[len(results)-1].ID + } + return results, lastSeenID, nil + } + subscription := db.NewDBSubscription( + ctx, + log, + query, + 0, // lastSeenID + db.PollingOptions{Interval: time.Second, Notifier: notifier, NumRows: 100}, + ) + listener, err := subscription.Start() + if err != nil { + return nil, err + } + + worker := &PublishWorker{ + ctx: ctx, + log: log, + subscription: *subscription, + listener: listener, + registrant: reg, + store: store, + } + go worker.start() + + return worker, nil +} + +func (p *PublishWorker) start() { + for { + select { + case <-p.ctx.Done(): + return + case new_batch := <-p.listener: + for _, stagedEnv := range new_batch { + for !p.publishStagedEnvelope(stagedEnv) { + // Infinite retry on failure to publish; we cannot + // continue to the next envelope until this one is processed + time.Sleep(time.Second) + } + } + } + } +} + +func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool { + logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID)) + originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv) + if err != nil { + logger.Error( + "Failed to sign staged envelope", + zap.Error(err), + ) + return false + } + originatorBytes, err := proto.Marshal(originatorEnv) + if err != nil { + logger.Error("Failed to marshal originator envelope", zap.Error(err)) + return false + } + + q := queries.New(p.store) + + // On unique constraint conflicts, no error is thrown, but numRows is 0 + inserted, err := q.InsertGatewayEnvelope( + p.ctx, + queries.InsertGatewayEnvelopeParams{ + OriginatorID: int32(p.registrant.NodeID()), + OriginatorSequenceID: stagedEnv.ID, + Topic: stagedEnv.Topic, + OriginatorEnvelope: originatorBytes, + }, + ) + if err != nil { + logger.Error("Failed to insert gateway envelope", zap.Error(err)) + return false + } else if inserted == 0 { + // Envelope was already inserted by another worker + logger.Debug("Envelope already inserted") + } + + // Try to delete the row regardless of if the gateway envelope was inserted elsewhere + deleted, err := q.DeleteStagedOriginatorEnvelope(context.Background(), stagedEnv.ID) + if err != nil { + logger.Error("Failed to delete staged envelope", zap.Error(err)) + // Envelope is already inserted, so it is safe to continue + return true + } else if deleted == 0 { + // Envelope was already deleted by another worker + logger.Debug("Envelope already deleted") + } + + return true +} diff --git a/pkg/api/service.go b/pkg/api/service.go index 32367558..c20ab12c 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -17,19 +17,33 @@ import ( type Service struct { message_api.UnimplementedReplicationApiServer - ctx context.Context - log *zap.Logger - registrant *registrant.Registrant - queries *queries.Queries + ctx context.Context + log *zap.Logger + notifyStagedPublish chan<- bool + registrant *registrant.Registrant + store *sql.DB + worker *PublishWorker } func NewReplicationApiService( ctx context.Context, log *zap.Logger, registrant *registrant.Registrant, - writerDB *sql.DB, + store *sql.DB, ) (*Service, error) { - return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil + notifier := make(chan bool, 1) + worker, err := StartPublishWorker(ctx, log, registrant, store, notifier) + if err != nil { + return nil, err + } + return &Service{ + ctx: ctx, + log: log, + notifyStagedPublish: notifier, + registrant: registrant, + store: store, + worker: worker, + }, nil } func (s *Service) Close() { @@ -54,28 +68,37 @@ func (s *Service) PublishEnvelope( ctx context.Context, req *message_api.PublishEnvelopeRequest, ) (*message_api.PublishEnvelopeResponse, error) { - payerEnv := req.GetPayerEnvelope() - clientBytes := payerEnv.GetUnsignedClientEnvelope() - sig := payerEnv.GetPayerSignature() - if (clientBytes == nil) || (sig == nil) { - return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature") + clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope()) + if err != nil { + return nil, err } - // TODO(rich): Verify payer signature - // TODO(rich): Verify all originators have synced past `last_originator_sids` - // TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group - // TODO(rich): Perform any payload-specific validation (e.g. identity updates) + + topic, err := s.validateClientInfo(clientEnv) + if err != nil { + return nil, err + } + // TODO(rich): If it is a commit, publish it to blockchain instead - payerBytes, err := proto.Marshal(payerEnv) + payerBytes, err := proto.Marshal(req.GetPayerEnvelope()) if err != nil { return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err) } - stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes) + stagedEnv, err := queries.New(s.store). + InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{ + Topic: topic, + PayerEnvelope: payerBytes, + }) if err != nil { return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err) } + select { + case s.notifyStagedPublish <- true: + default: + } + originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv) if err != nil { return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err) @@ -83,3 +106,43 @@ func (s *Service) PublishEnvelope( return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil } + +func (s *Service) validatePayerInfo( + payerEnv *message_api.PayerEnvelope, +) (*message_api.ClientEnvelope, error) { + clientBytes := payerEnv.GetUnsignedClientEnvelope() + sig := payerEnv.GetPayerSignature() + if (clientBytes == nil) || (sig == nil) { + return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature") + } + // TODO(rich): Verify payer signature + + clientEnv := &message_api.ClientEnvelope{} + err := proto.Unmarshal(clientBytes, clientEnv) + if err != nil { + return nil, status.Errorf( + codes.InvalidArgument, + "could not unmarshal client envelope: %v", + err, + ) + } + + return clientEnv, nil +} + +func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) { + if clientEnv.GetAad().GetTargetOriginator() != uint32(s.registrant.NodeID()) { + return nil, status.Errorf(codes.InvalidArgument, "invalid target originator") + } + + topic := clientEnv.GetAad().GetTargetTopic() + if len(topic) == 0 { + return nil, status.Errorf(codes.InvalidArgument, "missing target topic") + } + + // TODO(rich): Verify all originators have synced past `last_originator_sids` + // TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group + // TODO(rich): Perform any payload-specific validation (e.g. identity updates) + + return topic, nil +} diff --git a/pkg/api/service_test.go b/pkg/api/service_test.go index e33fa4b8..c91ba4cc 100644 --- a/pkg/api/service_test.go +++ b/pkg/api/service_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "testing" + "time" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" @@ -40,17 +41,41 @@ func newTestService(t *testing.T) (*Service, *sql.DB, func()) { } } +func createClientEnvelope() *message_api.ClientEnvelope { + return &message_api.ClientEnvelope{ + Payload: nil, + Aad: &message_api.AuthenticatedData{ + TargetOriginator: 1, + TargetTopic: []byte{0x5}, + LastOriginatorSids: []uint64{}, + }, + } +} + +func createPayerEnvelope( + t *testing.T, + clientEnv ...*message_api.ClientEnvelope, +) *message_api.PayerEnvelope { + if len(clientEnv) == 0 { + clientEnv = append(clientEnv, createClientEnvelope()) + } + clientEnvBytes, err := proto.Marshal(clientEnv[0]) + require.NoError(t, err) + + return &message_api.PayerEnvelope{ + UnsignedClientEnvelope: clientEnvBytes, + PayerSignature: &associations.RecoverableEcdsaSignature{}, + } +} + func TestSimplePublish(t *testing.T) { - svc, _, cleanup := newTestService(t) + svc, db, cleanup := newTestService(t) defer cleanup() resp, err := svc.PublishEnvelope( context.Background(), &message_api.PublishEnvelopeRequest{ - PayerEnvelope: &message_api.PayerEnvelope{ - UnsignedClientEnvelope: []byte{0x5}, - PayerSignature: &associations.RecoverableEcdsaSignature{}, - }, + PayerEnvelope: createPayerEnvelope(t), }, ) require.NoError(t, err) @@ -61,7 +86,70 @@ func TestSimplePublish(t *testing.T) { t, proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv), ) - require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0]) + clientEnv := &message_api.ClientEnvelope{} + require.NoError( + t, + proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv), + ) + require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0]) - // TODO(rich) Test that the published envelope is retrievable via the query API + // Check that the envelope was published to the database after a delay + require.Eventually(t, func() bool { + envs, err := queries.New(db). + SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{}) + require.NoError(t, err) + + if len(envs) != 1 { + return false + } + + originatorEnv := &message_api.OriginatorEnvelope{} + require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv)) + return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope()) + }, 500*time.Millisecond, 50*time.Millisecond) +} + +func TestUnmarshalError(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + envelope := createPayerEnvelope(t) + envelope.UnsignedClientEnvelope = []byte("invalidbytes") + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: envelope, + }, + ) + require.ErrorContains(t, err, "unmarshal") +} + +func TestMismatchingOriginator(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + clientEnv := createClientEnvelope() + clientEnv.Aad.TargetOriginator = 2 + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: createPayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "originator") +} + +func TestMissingTopic(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + clientEnv := createClientEnvelope() + clientEnv.Aad.TargetTopic = nil + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: createPayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "topic") } diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index 45f18a05..cc380f7f 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -35,7 +35,7 @@ LIMIT sqlc.narg('row_limit')::INT; SELECT * FROM - insert_staged_originator_envelope(@payer_envelope); + insert_staged_originator_envelope(@topic, @payer_envelope); -- name: SelectStagedOriginatorEnvelopes :many SELECT diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index c6c3d4d6..36178292 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -33,5 +33,6 @@ type NodeInfo struct { type StagedOriginatorEnvelope struct { ID int64 OriginatorTime time.Time + Topic []byte PayerEnvelope []byte } diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 7244bb48..f68153fd 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -70,15 +70,25 @@ func (q *Queries) InsertNodeInfo(ctx context.Context, arg InsertNodeInfoParams) const insertStagedOriginatorEnvelope = `-- name: InsertStagedOriginatorEnvelope :one SELECT - id, originator_time, payer_envelope + id, originator_time, topic, payer_envelope FROM - insert_staged_originator_envelope($1) + insert_staged_originator_envelope($1, $2) ` -func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, payerEnvelope []byte) (StagedOriginatorEnvelope, error) { - row := q.db.QueryRowContext(ctx, insertStagedOriginatorEnvelope, payerEnvelope) +type InsertStagedOriginatorEnvelopeParams struct { + Topic []byte + PayerEnvelope []byte +} + +func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, arg InsertStagedOriginatorEnvelopeParams) (StagedOriginatorEnvelope, error) { + row := q.db.QueryRowContext(ctx, insertStagedOriginatorEnvelope, arg.Topic, arg.PayerEnvelope) var i StagedOriginatorEnvelope - err := row.Scan(&i.ID, &i.OriginatorTime, &i.PayerEnvelope) + err := row.Scan( + &i.ID, + &i.OriginatorTime, + &i.Topic, + &i.PayerEnvelope, + ) return i, err } @@ -159,7 +169,7 @@ func (q *Queries) SelectNodeInfo(ctx context.Context) (NodeInfo, error) { const selectStagedOriginatorEnvelopes = `-- name: SelectStagedOriginatorEnvelopes :many SELECT - id, originator_time, payer_envelope + id, originator_time, topic, payer_envelope FROM staged_originator_envelopes WHERE @@ -183,7 +193,12 @@ func (q *Queries) SelectStagedOriginatorEnvelopes(ctx context.Context, arg Selec var items []StagedOriginatorEnvelope for rows.Next() { var i StagedOriginatorEnvelope - if err := rows.Scan(&i.ID, &i.OriginatorTime, &i.PayerEnvelope); err != nil { + if err := rows.Scan( + &i.ID, + &i.OriginatorTime, + &i.Topic, + &i.PayerEnvelope, + ); err != nil { return nil, err } items = append(items, i) diff --git a/pkg/migrations/00001_init-schema.up.sql b/pkg/migrations/00001_init-schema.up.sql index b08e8d45..21cf35d0 100644 --- a/pkg/migrations/00001_init-schema.up.sql +++ b/pkg/migrations/00001_init-schema.up.sql @@ -41,30 +41,24 @@ END; $$ LANGUAGE plpgsql; --- Process for originating envelopes: --- 1. Perform any necessary validation --- 2. Insert into originated_envelopes --- 3. Singleton background task will continuously query (or subscribe to) --- staged_originated_envelopes, and for each envelope in order of ID: --- 2.1. Construct and sign OriginatorEnvelope proto --- 2.2. Atomically insert into all_envelopes and delete from originated_envelopes, --- ignoring unique index violations on originator_sid --- This preserves total ordering, while avoiding gaps in sequence ID's. +-- Newly published envelopes will be queued here first (and assigned an originator +-- sequence ID), before being inserted in-order into the gateway_envelopes table. CREATE TABLE staged_originator_envelopes( -- used to construct originator_sid id BIGSERIAL PRIMARY KEY, originator_time TIMESTAMP NOT NULL DEFAULT now(), + topic BYTEA NOT NULL, payer_envelope BYTEA NOT NULL ); -CREATE FUNCTION insert_staged_originator_envelope(payer_envelope BYTEA) +CREATE FUNCTION insert_staged_originator_envelope(topic BYTEA, payer_envelope BYTEA) RETURNS SETOF staged_originator_envelopes AS $$ BEGIN PERFORM pg_advisory_xact_lock(hashtext('staged_originator_envelopes_sequence')); - RETURN QUERY INSERT INTO staged_originator_envelopes(payer_envelope) - VALUES(payer_envelope) + RETURN QUERY INSERT INTO staged_originator_envelopes(topic, payer_envelope) + VALUES(topic, payer_envelope) ON CONFLICT DO NOTHING RETURNING diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index e642a7e1..38fdb458 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -60,6 +60,10 @@ func (r *Registrant) signKeccak256(data []byte) ([]byte, error) { return crypto.Sign(hash, r.privateKey) } +func (r *Registrant) NodeID() uint16 { + return r.record.NodeID +} + func (r *Registrant) SignStagedEnvelope( stagedEnv queries.StagedOriginatorEnvelope, ) (*message_api.OriginatorEnvelope, error) {