Skip to content

Commit

Permalink
Add publish worker (#115)
Browse files Browse the repository at this point in the history
- Adds a publish worker that performs in-order insertion into the
`gateway_envelopes` table
- Adds basic validation of the client envelope on publish
- Store the topic on the `staged_originated_envelopes` table - we
extract this during the API call, so that the publish worker doesn't
need to do any additional unmarshaling or validation.

Would particularly love feedback on the error handling in the worker,
and if there's any test cases I should add to `service_test.go`.
  • Loading branch information
richardhuaaa authored Aug 27, 2024
1 parent 9771fb2 commit 650bf24
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 41 deletions.
147 changes: 147 additions & 0 deletions pkg/api/publishWorker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package api

import (
"context"
"database/sql"
"time"

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/registrant"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
)

type PublishWorker struct {
ctx context.Context
log *zap.Logger
listener <-chan []queries.StagedOriginatorEnvelope
notifier chan<- bool
registrant *registrant.Registrant
store *sql.DB
subscription db.DBSubscription[queries.StagedOriginatorEnvelope]
}

func StartPublishWorker(
ctx context.Context,
log *zap.Logger,
reg *registrant.Registrant,
store *sql.DB,
) (*PublishWorker, error) {
q := queries.New(store)
query := func(ctx context.Context, lastSeenID int64, numRows int32) ([]queries.StagedOriginatorEnvelope, int64, error) {
results, err := q.SelectStagedOriginatorEnvelopes(
ctx,
queries.SelectStagedOriginatorEnvelopesParams{
LastSeenID: lastSeenID,
NumRows: numRows,
},
)
if err != nil {
return nil, 0, err
}
if len(results) > 0 {
lastSeenID = results[len(results)-1].ID
}
return results, lastSeenID, nil
}
notifier := make(chan bool, 1)
subscription := db.NewDBSubscription(
ctx,
log,
query,
0, // lastSeenID
db.PollingOptions{Interval: time.Second, Notifier: notifier, NumRows: 100},
)
listener, err := subscription.Start()
if err != nil {
return nil, err
}

worker := &PublishWorker{
ctx: ctx,
log: log,
notifier: notifier,
subscription: *subscription,
listener: listener,
registrant: reg,
store: store,
}
go worker.start()

return worker, nil
}

func (p *PublishWorker) NotifyStagedPublish() {
select {
case p.notifier <- true:
default:
}
}

func (p *PublishWorker) start() {
for {
select {
case <-p.ctx.Done():
return
case new_batch := <-p.listener:
for _, stagedEnv := range new_batch {
for !p.publishStagedEnvelope(stagedEnv) {
// Infinite retry on failure to publish; we cannot
// continue to the next envelope until this one is processed
time.Sleep(time.Second)
}
}
}
}
}

func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool {
logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID))
originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
logger.Error(
"Failed to sign staged envelope",
zap.Error(err),
)
return false
}
originatorBytes, err := proto.Marshal(originatorEnv)
if err != nil {
logger.Error("Failed to marshal originator envelope", zap.Error(err))
return false
}

q := queries.New(p.store)

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
p.ctx,
queries.InsertGatewayEnvelopeParams{
OriginatorID: int32(p.registrant.NodeID()),
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
},
)
if err != nil {
logger.Error("Failed to insert gateway envelope", zap.Error(err))
return false
} else if inserted == 0 {
// Envelope was already inserted by another worker
logger.Debug("Envelope already inserted")
}

// Try to delete the row regardless of if the gateway envelope was inserted elsewhere
deleted, err := q.DeleteStagedOriginatorEnvelope(context.Background(), stagedEnv.ID)
if err != nil {
logger.Error("Failed to delete staged envelope", zap.Error(err))
// Envelope is already inserted, so it is safe to continue
return true
} else if deleted == 0 {
// Envelope was already deleted by another worker
logger.Debug("Envelope already deleted")
}

return true
}
84 changes: 70 additions & 14 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,27 @@ type Service struct {
ctx context.Context
log *zap.Logger
registrant *registrant.Registrant
queries *queries.Queries
store *sql.DB
worker *PublishWorker
}

func NewReplicationApiService(
ctx context.Context,
log *zap.Logger,
registrant *registrant.Registrant,
writerDB *sql.DB,
store *sql.DB,
) (*Service, error) {
return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil
worker, err := StartPublishWorker(ctx, log, registrant, store)
if err != nil {
return nil, err
}
return &Service{
ctx: ctx,
log: log,
registrant: registrant,
store: store,
worker: worker,
}, nil
}

func (s *Service) Close() {
Expand All @@ -54,27 +65,32 @@ func (s *Service) PublishEnvelope(
ctx context.Context,
req *message_api.PublishEnvelopeRequest,
) (*message_api.PublishEnvelopeResponse, error) {
payerEnv := req.GetPayerEnvelope()
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope())
if err != nil {
return nil, err
}
// TODO(rich): Verify payer signature
// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

topic, err := s.validateClientInfo(clientEnv)
if err != nil {
return nil, err
}

// TODO(rich): If it is a commit, publish it to blockchain instead

payerBytes, err := proto.Marshal(payerEnv)
payerBytes, err := proto.Marshal(req.GetPayerEnvelope())
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}

stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes)
stagedEnv, err := queries.New(s.store).
InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{
Topic: topic,
PayerEnvelope: payerBytes,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err)
}
s.worker.NotifyStagedPublish()

originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv)
if err != nil {
Expand All @@ -83,3 +99,43 @@ func (s *Service) PublishEnvelope(

return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil
}

func (s *Service) validatePayerInfo(
payerEnv *message_api.PayerEnvelope,
) (*message_api.ClientEnvelope, error) {
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if (clientBytes == nil) || (sig == nil) {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
}
// TODO(rich): Verify payer signature

clientEnv := &message_api.ClientEnvelope{}
err := proto.Unmarshal(clientBytes, clientEnv)
if err != nil {
return nil, status.Errorf(
codes.InvalidArgument,
"could not unmarshal client envelope: %v",
err,
)
}

return clientEnv, nil
}

func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) {
if clientEnv.GetAad().GetTargetOriginator() != uint32(s.registrant.NodeID()) {
return nil, status.Errorf(codes.InvalidArgument, "invalid target originator")
}

topic := clientEnv.GetAad().GetTargetTopic()
if len(topic) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing target topic")
}

// TODO(rich): Verify all originators have synced past `last_originator_sids`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

return topic, nil
}
102 changes: 95 additions & 7 deletions pkg/api/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -40,17 +41,41 @@ func newTestService(t *testing.T) (*Service, *sql.DB, func()) {
}
}

func createClientEnvelope() *message_api.ClientEnvelope {
return &message_api.ClientEnvelope{
Payload: nil,
Aad: &message_api.AuthenticatedData{
TargetOriginator: 1,
TargetTopic: []byte{0x5},
LastOriginatorSids: []uint64{},
},
}
}

func createPayerEnvelope(
t *testing.T,
clientEnv ...*message_api.ClientEnvelope,
) *message_api.PayerEnvelope {
if len(clientEnv) == 0 {
clientEnv = append(clientEnv, createClientEnvelope())
}
clientEnvBytes, err := proto.Marshal(clientEnv[0])
require.NoError(t, err)

return &message_api.PayerEnvelope{
UnsignedClientEnvelope: clientEnvBytes,
PayerSignature: &associations.RecoverableEcdsaSignature{},
}
}

func TestSimplePublish(t *testing.T) {
svc, _, cleanup := newTestService(t)
svc, db, cleanup := newTestService(t)
defer cleanup()

resp, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: &message_api.PayerEnvelope{
UnsignedClientEnvelope: []byte{0x5},
PayerSignature: &associations.RecoverableEcdsaSignature{},
},
PayerEnvelope: createPayerEnvelope(t),
},
)
require.NoError(t, err)
Expand All @@ -61,7 +86,70 @@ func TestSimplePublish(t *testing.T) {
t,
proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv),
)
require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0])
clientEnv := &message_api.ClientEnvelope{}
require.NoError(
t,
proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv),
)
require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0])

// TODO(rich) Test that the published envelope is retrievable via the query API
// Check that the envelope was published to the database after a delay
require.Eventually(t, func() bool {
envs, err := queries.New(db).
SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{})
require.NoError(t, err)

if len(envs) != 1 {
return false
}

originatorEnv := &message_api.OriginatorEnvelope{}
require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv))
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope())
}, 500*time.Millisecond, 50*time.Millisecond)
}

func TestUnmarshalError(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

envelope := createPayerEnvelope(t)
envelope.UnsignedClientEnvelope = []byte("invalidbytes")
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: envelope,
},
)
require.ErrorContains(t, err, "unmarshal")
}

func TestMismatchingOriginator(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

clientEnv := createClientEnvelope()
clientEnv.Aad.TargetOriginator = 2
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: createPayerEnvelope(t, clientEnv),
},
)
require.ErrorContains(t, err, "originator")
}

func TestMissingTopic(t *testing.T) {
svc, _, cleanup := newTestService(t)
defer cleanup()

clientEnv := createClientEnvelope()
clientEnv.Aad.TargetTopic = nil
_, err := svc.PublishEnvelope(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: createPayerEnvelope(t, clientEnv),
},
)
require.ErrorContains(t, err, "topic")
}
Loading

0 comments on commit 650bf24

Please sign in to comment.