Skip to content

Commit

Permalink
Add auth token to GRPC requests (#236)
Browse files Browse the repository at this point in the history
## tl;dr

- Creates a client interceptor for attaching auth tokens to requests
- Actually attaches JWTs to requests coming from the sync service

## AI Generated Summary

Implemented node-to-node authentication using JWT tokens for secure
communication between XMTP nodes.

### What changed?

- Added an `AuthInterceptor` for client-side gRPC calls to inject JWT
tokens into request headers.
- Updated the `TokenFactory` to use `uint32` for node IDs instead of
`int32`.
- Modified the `Registrant` to include a `TokenFactory`.
- Updated the `syncWorker` to use the new `AuthInterceptor` when
connecting to other nodes.
- Added a new constant `NODE_AUTHORIZATION_HEADER_NAME` for the
authentication header.
- Created mock implementations and tests for the new authentication
system.

### How to test?

1. Run the existing test suite to ensure no regressions.
2. Check the new test file `pkg/interceptors/client/auth_test.go` for
specific tests of the authentication system.
3. Manually test node-to-node communication to verify that
authentication is working as expected.

### Why make this change?

This change enhances the security of the XMTP network by implementing
node-to-node authentication. It ensures that only authorized nodes can
communicate with each other, preventing unauthorized access to sensitive
data and operations. This is a crucial step in maintaining the integrity
and confidentiality of the XMTP network.
  • Loading branch information
neekolas authored Oct 18, 2024
1 parent 161823d commit d3fe5c1
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 16 deletions.
6 changes: 3 additions & 3 deletions pkg/authn/tokenFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ const (

type TokenFactory struct {
privateKey *ecdsa.PrivateKey
nodeID int32
nodeID uint32
}

func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID int32) *TokenFactory {
func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID uint32) *TokenFactory {
return &TokenFactory{
privateKey: privateKey,
nodeID: nodeID,
}
}

func (f *TokenFactory) CreateToken(forNodeID int32) (*Token, error) {
func (f *TokenFactory) CreateToken(forNodeID uint32) (*Token, error) {
now := time.Now()
expiresAt := now.Add(TOKEN_DURATION)

Expand Down
18 changes: 9 additions & 9 deletions pkg/authn/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func buildJwt(
func TestVerifier(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
Expand All @@ -62,14 +62,14 @@ func TestVerifier(t *testing.T) {
}, nil)

// Create a token targeting the verifier's node as the audience
token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID))
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)
// This should verify correctly
verificationError := verifier.Verify(token.SignedString)
require.NoError(t, verificationError)

// Create a token targeting a different node as the audience
tokenForWrongNode, err := tokenFactory.CreateToken(int32(300))
tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300))
require.NoError(t, err)
// This should not verify correctly
verificationError = verifier.Verify(tokenForWrongNode.SignedString)
Expand All @@ -79,15 +79,15 @@ func TestVerifier(t *testing.T) {
func TestWrongAudience(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
SigningKey: &signerPrivateKey.PublicKey,
NodeID: uint32(SIGNER_NODE_ID),
}, nil)
// Create a token targeting a different node as the audience
tokenForWrongNode, err := tokenFactory.CreateToken(int32(300))
tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300))
require.NoError(t, err)
// This should not verify correctly
verificationError := verifier.Verify(tokenForWrongNode.SignedString)
Expand All @@ -97,12 +97,12 @@ func TestWrongAudience(t *testing.T) {
func TestUnknownNode(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(nil, errors.New("node not found"))

token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID))
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)

verificationError := verifier.Verify(token.SignedString)
Expand All @@ -112,7 +112,7 @@ func TestUnknownNode(t *testing.T) {
func TestWrongPublicKey(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))

Expand All @@ -122,7 +122,7 @@ func TestWrongPublicKey(t *testing.T) {
NodeID: uint32(SIGNER_NODE_ID),
}, nil)

token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID))
token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)

verificationError := verifier.Verify(token.SignedString)
Expand Down
1 change: 1 addition & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ const (
JWT_DOMAIN_SEPARATION_LABEL = "jwt|"
PAYER_DOMAIN_SEPARATION_LABEL = "payer|"
ORIGINATOR_DOMAIN_SEPARATION_LABEL = "originator|"
NODE_AUTHORIZATION_HEADER_NAME = "node-authorization"
)
98 changes: 98 additions & 0 deletions pkg/interceptors/client/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package client

import (
"context"
"log"
"time"

"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// AuthInterceptor is a struct for holding the token and adding it to each request.
type AuthInterceptor struct {
tokenFactory *authn.TokenFactory
targetNodeID uint32
currentToken *authn.Token
}

func NewAuthInterceptor(
tokenFactory *authn.TokenFactory,
targetNodeID uint32,
) *AuthInterceptor {
// This should never happen
if tokenFactory == nil {
log.Fatal("tokenFactory is required")
}
return &AuthInterceptor{
tokenFactory: tokenFactory,
targetNodeID: targetNodeID,
}
}

func (i *AuthInterceptor) getToken() (*authn.Token, error) {
// If we have a token that is not expired (or nearing expiry) then return it
if i.currentToken != nil &&
i.currentToken.ExpiresAt.After(time.Now().Add(authn.MAX_CLOCK_SKEW)) {
return i.currentToken, nil
}
token, err := i.tokenFactory.CreateToken(i.targetNodeID)
if err != nil {
return nil, err
}

i.currentToken = token
return token, nil
}

// Unary method to intercept requests and inject the token into headers.
func (i *AuthInterceptor) Unary() grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req interface{},
reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
token, err := i.getToken()
if err != nil {
return status.Errorf(codes.Unauthenticated, "failed to get token: %v", err)
}
// Create the metadata with the token
md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString)
// Attach metadata to the outgoing context
ctx = metadata.NewOutgoingContext(ctx, md)

// Proceed with the request
return invoker(ctx, method, req, reply, cc, opts...)
}
}

func (i *AuthInterceptor) Stream() grpc.StreamClientInterceptor {
return func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
token, err := i.getToken()
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get token: %v", err)
}
// Create the metadata with the token
md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString)
// Attach the metadata to the outgoing context
ctx = metadata.NewOutgoingContext(ctx, md)

// Proceed with the stream
return streamer(ctx, desc, cc, method, opts...)
}
}
120 changes: 120 additions & 0 deletions pkg/interceptors/client/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package client

import (
"context"
"net"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api"
"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/constants"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/testutils"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
)

// Create a mock implementation of the ReplicationApiServer interface
// but that embeds `UnimplementedReplicationApiServer` (which mockery won't do for us)
type mockReplicationApiServer struct {
api.Service
expectedToken string
}

func (s *mockReplicationApiServer) QueryEnvelopes(
ctx context.Context,
req *message_api.QueryEnvelopesRequest,
) (*message_api.QueryEnvelopesResponse, error) {
// Get metadata from the context
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Unauthenticated, "metadata is not provided")
}

// Extract and verify the token
tokens := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME)
if len(tokens) == 0 {
return nil, status.Error(codes.Unauthenticated, "authorization token is not provided")
}
token := tokens[0]
if token != s.expectedToken {
return nil, status.Error(codes.Unauthenticated, "invalid authorization token")
}

// You can add more assertions here to verify the token's content
// For example, you might want to decode the token and check its claims
return &message_api.QueryEnvelopesResponse{}, nil
}

func TestAuthInterceptor(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
myNodeID := uint32(100)
targetNodeID := uint32(200)
tokenFactory := authn.NewTokenFactory(privateKey, myNodeID)
interceptor := NewAuthInterceptor(tokenFactory, targetNodeID)
token, err := interceptor.getToken()
require.NoError(t, err)

// Use a bufconn listener to simulate a gRPC connection without actually dialing
listener := bufconn.Listen(1024 * 1024)

// Register the mock service on the server
server := grpc.NewServer()
message_api.RegisterReplicationApiServer(
server,
&mockReplicationApiServer{expectedToken: token.SignedString},
)

// Start the gRPC server in a goroutine
go func() {
if err := server.Serve(listener); err != nil {
t.Fail()
}
}()

t.Cleanup(func() {
server.Stop()
listener.Close()
})

// Connect to the fake server and set the right interceptors
conn, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(interceptor.Unary()),
)
require.NoError(t, err)
defer conn.Close()

// Create a client with the connection
client := message_api.NewReplicationApiClient(conn)

// Call the unary method and check the response
_, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{})
require.NoError(t, err)

// Create another client without the interceptor
connWithoutInterceptor, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return listener.Dial()
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
require.NoError(t, err)
defer connWithoutInterceptor.Close()

client = message_api.NewReplicationApiClient(connWithoutInterceptor)

// Call the unary method and check the response
_, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{})
require.Error(t, err)
}
17 changes: 13 additions & 4 deletions pkg/registrant/registrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"go.uber.org/zap"

"github.com/ethereum/go-ethereum/crypto"
"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
Expand All @@ -19,8 +20,9 @@ import (
)

type Registrant struct {
record *registry.Node
privateKey *ecdsa.PrivateKey
record *registry.Node
privateKey *ecdsa.PrivateKey
tokenFactory *authn.TokenFactory
}

func NewRegistrant(
Expand All @@ -45,9 +47,12 @@ func NewRegistrant(
return nil, err
}

tokenFactory := authn.NewTokenFactory(privateKey, record.NodeID)

return &Registrant{
record: record,
privateKey: privateKey,
record: record,
privateKey: privateKey,
tokenFactory: tokenFactory,
}, nil
}

Expand All @@ -59,6 +64,10 @@ func (r *Registrant) NodeID() uint32 {
return r.record.NodeID
}

func (r *Registrant) TokenFactory() *authn.TokenFactory {
return r.tokenFactory
}

func (r *Registrant) SignStagedEnvelope(
stagedEnv queries.StagedOriginatorEnvelope,
) (*message_api.OriginatorEnvelope, error) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/sync/syncWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/xmtp/xmtpd/pkg/db/queries"
clientInterceptors "github.com/xmtp/xmtpd/pkg/interceptors/client"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/registry"
Expand Down Expand Up @@ -111,10 +112,13 @@ func (s *syncWorker) connectToNode(node registry.Node) (*grpc.ClientConn, error)
return nil, fmt.Errorf("Failed to get credentials: %v", err)
}

interceptor := clientInterceptors.NewAuthInterceptor(s.registrant.TokenFactory(), node.NodeID)
conn, err := grpc.NewClient(
target,
grpc.WithTransportCredentials(creds),
grpc.WithDefaultCallOptions(),
grpc.WithUnaryInterceptor(interceptor.Unary()),
grpc.WithStreamInterceptor(interceptor.Stream()),
)
if err != nil {
return nil, fmt.Errorf("Failed to connect to peer at %s: %v", target, err)
Expand Down

0 comments on commit d3fe5c1

Please sign in to comment.