Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate published envelope topics #233

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions pkg/api/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@ import (
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
"github.com/xmtp/xmtpd/pkg/topic"
"google.golang.org/protobuf/proto"
)

func TestPublishEnvelope(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()

payerEnvelope := envelopeTestUtils.CreatePayerEnvelope(t)

resp, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{testutils.CreatePayerEnvelope(t)},
PayerEnvelopes: []*message_api.PayerEnvelope{payerEnvelope},
},
)
require.NoError(t, err)
Expand All @@ -39,7 +42,9 @@ func TestPublishEnvelope(t *testing.T) {
t,
proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv),
)
require.Equal(t, uint8(0x5), clientEnv.Aad.GetTargetTopic()[0])

_, err = topic.ParseTopic(clientEnv.Aad.GetTargetTopic())
require.NoError(t, err)

// Check that the envelope was published to the database after a delay
require.Eventually(t, func() bool {
Expand All @@ -61,28 +66,28 @@ func TestUnmarshalErrorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()

envelope := testutils.CreatePayerEnvelope(t)
envelope := envelopeTestUtils.CreatePayerEnvelope(t)
envelope.UnsignedClientEnvelope = []byte("invalidbytes")
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{envelope},
},
)
require.ErrorContains(t, err, "unmarshal")
require.ErrorContains(t, err, "invalid wire-format data")
}

func TestMismatchingOriginatorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()

clientEnv := testutils.CreateClientEnvelope()
clientEnv := envelopeTestUtils.CreateClientEnvelope()
clientEnv.Aad.TargetOriginator = 2
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
testutils.CreatePayerEnvelope(t, clientEnv),
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
)
Expand All @@ -93,13 +98,13 @@ func TestMissingTopicOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()

clientEnv := testutils.CreateClientEnvelope()
clientEnv := envelopeTestUtils.CreateClientEnvelope()
clientEnv.Aad.TargetTopic = nil
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
testutils.CreatePayerEnvelope(t, clientEnv),
envelopeTestUtils.CreatePayerEnvelope(t, clientEnv),
},
},
)
Expand Down
11 changes: 6 additions & 5 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
)

func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams {
Expand All @@ -21,7 +22,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 1),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1),
),
},
{
Expand All @@ -30,7 +31,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 2, 1),
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1),
),
},
{
Expand All @@ -39,7 +40,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
Topic: []byte("topicB"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 2),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2),
),
},
{
Expand All @@ -48,7 +49,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
Topic: []byte("topicB"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 2, 2),
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2),
),
},
{
Expand All @@ -57,7 +58,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 3),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3),
),
},
}
Expand Down
59 changes: 25 additions & 34 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/xmtp/xmtpd/pkg/blockchain"
"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/registrant"
Expand Down Expand Up @@ -219,17 +220,13 @@ func (s *Service) PublishEnvelopes(
if len(req.GetPayerEnvelopes()) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing payer envelope")
}
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelopes()[0])
if err != nil {
return nil, err
}

topic, err := s.validateClientInfo(clientEnv)
payerEnv, err := s.validatePayerEnvelope(req.GetPayerEnvelopes()[0])
if err != nil {
return nil, err
}

didPublish, err := s.maybePublishToBlockchain(ctx, clientEnv)
didPublish, err := s.maybePublishToBlockchain(ctx, &payerEnv.ClientEnvelope)
if err != nil {
return nil, err
}
Expand All @@ -238,14 +235,16 @@ func (s *Service) PublishEnvelopes(
}

// TODO(rich): Properly support batch publishing
payerBytes, err := proto.Marshal(req.GetPayerEnvelopes()[0])
payerBytes, err := payerEnv.Bytes()
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}

targetTopic := payerEnv.ClientEnvelope.TargetTopic()

stagedEnv, err := queries.New(s.store).
InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{
Topic: topic,
Topic: targetTopic.Bytes(),
PayerEnvelope: payerBytes,
})
if err != nil {
Expand All @@ -265,9 +264,9 @@ func (s *Service) PublishEnvelopes(

func (s *Service) maybePublishToBlockchain(
ctx context.Context,
clientEnv *message_api.ClientEnvelope,
clientEnv *envelopes.ClientEnvelope,
) (didPublish bool, err error) {
payload, ok := clientEnv.GetPayload().(*message_api.ClientEnvelope_IdentityUpdate)
payload, ok := clientEnv.Payload().(*message_api.ClientEnvelope_IdentityUpdate)
if ok && payload.IdentityUpdate != nil {
if err = s.publishIdentityUpdate(ctx, payload.IdentityUpdate); err != nil {
s.log.Error("could not publish identity update", zap.Error(err))
Expand Down Expand Up @@ -340,42 +339,34 @@ func (s *Service) GetInboxIds(
}, nil
}

func (s *Service) validatePayerInfo(
payerEnv *message_api.PayerEnvelope,
) (*message_api.ClientEnvelope, error) {
clientBytes := payerEnv.GetUnsignedClientEnvelope()
sig := payerEnv.GetPayerSignature()
if clientBytes == nil || sig == nil {
return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature")
func (s *Service) validatePayerEnvelope(
rawEnv *message_api.PayerEnvelope,
) (*envelopes.PayerEnvelope, error) {
payerEnv, err := envelopes.NewPayerEnvelope(rawEnv)
if err != nil {
return nil, err
}
// TODO(rich): Verify payer signature

clientEnv := &message_api.ClientEnvelope{}
err := proto.Unmarshal(clientBytes, clientEnv)
if err != nil {
return nil, status.Errorf(
codes.InvalidArgument,
"could not unmarshal client envelope: %v",
err,
)
if err := s.validateClientInfo(&payerEnv.ClientEnvelope); err != nil {
return nil, err
}

return clientEnv, nil
return payerEnv, nil
}

func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) {
if clientEnv.GetAad().GetTargetOriginator() != s.registrant.NodeID() {
return nil, status.Errorf(codes.InvalidArgument, "invalid target originator")
func (s *Service) validateClientInfo(clientEnv *envelopes.ClientEnvelope) error {
aad := clientEnv.Aad()
if aad.GetTargetOriginator() != s.registrant.NodeID() {
return status.Errorf(codes.InvalidArgument, "invalid target originator")
}

topic := clientEnv.GetAad().GetTargetTopic()
if len(topic) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing target topic")
if !clientEnv.TopicMatchesPayload() {
return status.Errorf(codes.InvalidArgument, "topic does not match payload")
}

// TODO(rich): Verify all originators have synced past `last_seen`
// TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group
// TODO(rich): Perform any payload-specific validation (e.g. identity updates)

return topic, nil
return nil
}
13 changes: 7 additions & 6 deletions pkg/api/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
)

var allRows []queries.InsertGatewayEnvelopeParams
Expand All @@ -26,7 +27,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 1),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1),
),
},
{
Expand All @@ -35,7 +36,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 2, 1),
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1),
),
},
// Later rows
Expand All @@ -45,7 +46,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
Topic: []byte("topicB"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 2),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2),
),
},
{
Expand All @@ -54,7 +55,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
Topic: []byte("topicB"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 2, 2),
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2),
),
},
{
Expand All @@ -63,7 +64,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
Topic: []byte("topicA"),
OriginatorEnvelope: testutils.Marshal(
t,
testutils.CreateOriginatorEnvelope(t, 1, 3),
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3),
),
},
}
Expand Down Expand Up @@ -94,7 +95,7 @@ func validateUpdates(
require.NoError(t, err)
for _, env := range envs.Envelopes {
expected := allRows[expectedIndices[i]]
actual := testutils.UnmarshalUnsignedOriginatorEnvelope(
actual := envelopeTestUtils.UnmarshalUnsignedOriginatorEnvelope(
t,
env.UnsignedOriginatorEnvelope,
)
Expand Down
38 changes: 35 additions & 3 deletions pkg/envelopes/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"errors"

"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/topic"
"google.golang.org/protobuf/proto"
)

type ClientEnvelope struct {
proto *message_api.ClientEnvelope
proto *message_api.ClientEnvelope
targetTopic topic.Topic
}

func NewClientEnvelope(proto *message_api.ClientEnvelope) (*ClientEnvelope, error) {
Expand All @@ -24,9 +26,12 @@ func NewClientEnvelope(proto *message_api.ClientEnvelope) (*ClientEnvelope, erro
return nil, errors.New("payload is missing")
}

// TODO:(nm) Validate topic
targetTopic, err := topic.ParseTopic(proto.Aad.TargetTopic)
if err != nil {
return nil, err
}

return &ClientEnvelope{proto: proto}, nil
return &ClientEnvelope{proto: proto, targetTopic: *targetTopic}, nil
}

func NewClientEnvelopeFromBytes(bytes []byte) (*ClientEnvelope, error) {
Expand All @@ -45,10 +50,37 @@ func (c *ClientEnvelope) Bytes() ([]byte, error) {
return bytes, nil
}

func (c *ClientEnvelope) TargetTopic() topic.Topic {
return c.targetTopic
}

func (c *ClientEnvelope) Payload() interface{} {
return c.proto.Payload
}

func (c *ClientEnvelope) Aad() *message_api.AuthenticatedData {
return c.proto.Aad
}

func (c *ClientEnvelope) Proto() *message_api.ClientEnvelope {
return c.proto
}

func (c *ClientEnvelope) TopicMatchesPayload() bool {
targetTopic := c.TargetTopic()
targetTopicKind := targetTopic.Kind()
payload := c.proto.Payload

switch payload.(type) {
case *message_api.ClientEnvelope_WelcomeMessage:
return targetTopicKind == topic.TOPIC_KIND_WELCOME_MESSAGES_V1
case *message_api.ClientEnvelope_GroupMessage:
return targetTopicKind == topic.TOPIC_KIND_GROUP_MESSAGES_V1
case *message_api.ClientEnvelope_IdentityUpdate:
return targetTopicKind == topic.TOPIC_KIND_IDENTITY_UPDATES_V1
case *message_api.ClientEnvelope_UploadKeyPackage:
return targetTopicKind == topic.TOPIC_KIND_KEY_PACKAGES_V1
default:
return false
}
}
Loading
Loading