From 80fcbb53209598737f141a0caa1327af2ff1f9d8 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Wed, 16 Oct 2024 16:35:55 -0700 Subject: [PATCH] Add auth token to GRPC requests --- pkg/authn/tokenFactory.go | 6 +- pkg/authn/verifier_test.go | 18 ++-- pkg/constants/constants.go | 1 + pkg/interceptors/client/auth.go | 98 ++++++++++++++++++++++ pkg/interceptors/client/auth_test.go | 120 +++++++++++++++++++++++++++ pkg/registrant/registrant.go | 17 +++- pkg/sync/syncWorker.go | 4 + 7 files changed, 248 insertions(+), 16 deletions(-) create mode 100644 pkg/interceptors/client/auth.go create mode 100644 pkg/interceptors/client/auth_test.go diff --git a/pkg/authn/tokenFactory.go b/pkg/authn/tokenFactory.go index 69a9b2f5..67eb5b7a 100644 --- a/pkg/authn/tokenFactory.go +++ b/pkg/authn/tokenFactory.go @@ -14,17 +14,17 @@ const ( type TokenFactory struct { privateKey *ecdsa.PrivateKey - nodeID int32 + nodeID uint32 } -func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID int32) *TokenFactory { +func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID uint32) *TokenFactory { return &TokenFactory{ privateKey: privateKey, nodeID: nodeID, } } -func (f *TokenFactory) CreateToken(forNodeID int32) (*Token, error) { +func (f *TokenFactory) CreateToken(forNodeID uint32) (*Token, error) { now := time.Now() expiresAt := now.Add(TOKEN_DURATION) diff --git a/pkg/authn/verifier_test.go b/pkg/authn/verifier_test.go index 9599eadd..944fff69 100644 --- a/pkg/authn/verifier_test.go +++ b/pkg/authn/verifier_test.go @@ -53,7 +53,7 @@ func buildJwt( func TestVerifier(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID)) + tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID)) verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ @@ -62,14 +62,14 @@ func TestVerifier(t *testing.T) { }, nil) // Create a token targeting the verifier's node as the audience - token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID)) + token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID)) require.NoError(t, err) // This should verify correctly verificationError := verifier.Verify(token.SignedString) require.NoError(t, verificationError) // Create a token targeting a different node as the audience - tokenForWrongNode, err := tokenFactory.CreateToken(int32(300)) + tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300)) require.NoError(t, err) // This should not verify correctly verificationError = verifier.Verify(tokenForWrongNode.SignedString) @@ -79,7 +79,7 @@ func TestVerifier(t *testing.T) { func TestWrongAudience(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID)) + tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID)) verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ @@ -87,7 +87,7 @@ func TestWrongAudience(t *testing.T) { NodeID: uint32(SIGNER_NODE_ID), }, nil) // Create a token targeting a different node as the audience - tokenForWrongNode, err := tokenFactory.CreateToken(int32(300)) + tokenForWrongNode, err := tokenFactory.CreateToken(uint32(300)) require.NoError(t, err) // This should not verify correctly verificationError := verifier.Verify(tokenForWrongNode.SignedString) @@ -97,12 +97,12 @@ func TestWrongAudience(t *testing.T) { func TestUnknownNode(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID)) + tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID)) verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(nil, errors.New("node not found")) - token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID)) + token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID)) require.NoError(t, err) verificationError := verifier.Verify(token.SignedString) @@ -112,7 +112,7 @@ func TestUnknownNode(t *testing.T) { func TestWrongPublicKey(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, int32(SIGNER_NODE_ID)) + tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID)) verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) @@ -122,7 +122,7 @@ func TestWrongPublicKey(t *testing.T) { NodeID: uint32(SIGNER_NODE_ID), }, nil) - token, err := tokenFactory.CreateToken(int32(VERIFIER_NODE_ID)) + token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID)) require.NoError(t, err) verificationError := verifier.Verify(token.SignedString) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 2efe6fac..f34d1733 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -4,4 +4,5 @@ const ( JWT_DOMAIN_SEPARATION_LABEL = "jwt|" PAYER_DOMAIN_SEPARATION_LABEL = "payer|" ORIGINATOR_DOMAIN_SEPARATION_LABEL = "originator|" + NODE_AUTHORIZATION_HEADER_NAME = "node-authorization" ) diff --git a/pkg/interceptors/client/auth.go b/pkg/interceptors/client/auth.go new file mode 100644 index 00000000..e9a0ac93 --- /dev/null +++ b/pkg/interceptors/client/auth.go @@ -0,0 +1,98 @@ +package client + +import ( + "context" + "log" + "time" + + "github.com/xmtp/xmtpd/pkg/authn" + "github.com/xmtp/xmtpd/pkg/constants" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// AuthInterceptor is a struct for holding the token and adding it to each request. +type AuthInterceptor struct { + tokenFactory *authn.TokenFactory + targetNodeID uint32 + currentToken *authn.Token +} + +func NewAuthInterceptor( + tokenFactory *authn.TokenFactory, + targetNodeID uint32, +) *AuthInterceptor { + // This should never happen + if tokenFactory == nil { + log.Fatal("tokenFactory is required") + } + return &AuthInterceptor{ + tokenFactory: tokenFactory, + targetNodeID: targetNodeID, + } +} + +func (i *AuthInterceptor) getToken() (*authn.Token, error) { + // If we have a token that is not expired (or nearing expiry) then return it + if i.currentToken != nil && + i.currentToken.ExpiresAt.After(time.Now().Add(authn.MAX_CLOCK_SKEW)) { + return i.currentToken, nil + } + token, err := i.tokenFactory.CreateToken(i.targetNodeID) + if err != nil { + return nil, err + } + + i.currentToken = token + return token, nil +} + +// Unary method to intercept requests and inject the token into headers. +func (i *AuthInterceptor) Unary() grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req interface{}, + reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + token, err := i.getToken() + if err != nil { + return status.Errorf(codes.Unauthenticated, "failed to get token: %v", err) + } + // Create the metadata with the token + md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString) + // Attach metadata to the outgoing context + ctx = metadata.NewOutgoingContext(ctx, md) + + // Proceed with the request + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func (i *AuthInterceptor) Stream() grpc.StreamClientInterceptor { + return func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + token, err := i.getToken() + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "failed to get token: %v", err) + } + // Create the metadata with the token + md := metadata.Pairs(constants.NODE_AUTHORIZATION_HEADER_NAME, token.SignedString) + // Attach the metadata to the outgoing context + ctx = metadata.NewOutgoingContext(ctx, md) + + // Proceed with the stream + return streamer(ctx, desc, cc, method, opts...) + } +} diff --git a/pkg/interceptors/client/auth_test.go b/pkg/interceptors/client/auth_test.go new file mode 100644 index 00000000..6e2796c4 --- /dev/null +++ b/pkg/interceptors/client/auth_test.go @@ -0,0 +1,120 @@ +package client + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/api" + "github.com/xmtp/xmtpd/pkg/authn" + "github.com/xmtp/xmtpd/pkg/constants" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/testutils" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" +) + +// Create a mock implementation of the ReplicationApiServer interface +// but that embeds `UnimplementedReplicationApiServer` (which mockery won't do for us) +type mockReplicationApiServer struct { + api.Service + expectedToken string +} + +func (s *mockReplicationApiServer) QueryEnvelopes( + ctx context.Context, + req *message_api.QueryEnvelopesRequest, +) (*message_api.QueryEnvelopesResponse, error) { + // Get metadata from the context + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "metadata is not provided") + } + + // Extract and verify the token + tokens := md.Get(constants.NODE_AUTHORIZATION_HEADER_NAME) + if len(tokens) == 0 { + return nil, status.Error(codes.Unauthenticated, "authorization token is not provided") + } + token := tokens[0] + if token != s.expectedToken { + return nil, status.Error(codes.Unauthenticated, "invalid authorization token") + } + + // You can add more assertions here to verify the token's content + // For example, you might want to decode the token and check its claims + return &message_api.QueryEnvelopesResponse{}, nil +} + +func TestAuthInterceptor(t *testing.T) { + privateKey := testutils.RandomPrivateKey(t) + myNodeID := uint32(100) + targetNodeID := uint32(200) + tokenFactory := authn.NewTokenFactory(privateKey, myNodeID) + interceptor := NewAuthInterceptor(tokenFactory, targetNodeID) + token, err := interceptor.getToken() + require.NoError(t, err) + + // Use a bufconn listener to simulate a gRPC connection without actually dialing + listener := bufconn.Listen(1024 * 1024) + + // Register the mock service on the server + server := grpc.NewServer() + message_api.RegisterReplicationApiServer( + server, + &mockReplicationApiServer{expectedToken: token.SignedString}, + ) + + // Start the gRPC server in a goroutine + go func() { + if err := server.Serve(listener); err != nil { + t.Fail() + } + }() + + t.Cleanup(func() { + server.Stop() + listener.Close() + }) + + // Connect to the fake server and set the right interceptors + conn, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor(interceptor.Unary()), + ) + require.NoError(t, err) + defer conn.Close() + + // Create a client with the connection + client := message_api.NewReplicationApiClient(conn) + + // Call the unary method and check the response + _, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{}) + require.NoError(t, err) + + // Create another client without the interceptor + connWithoutInterceptor, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer connWithoutInterceptor.Close() + + client = message_api.NewReplicationApiClient(connWithoutInterceptor) + + // Call the unary method and check the response + _, err = client.QueryEnvelopes(context.Background(), &message_api.QueryEnvelopesRequest{}) + require.Error(t, err) +} diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index f38889c2..5c3dae25 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -10,6 +10,7 @@ import ( "go.uber.org/zap" "github.com/ethereum/go-ethereum/crypto" + "github.com/xmtp/xmtpd/pkg/authn" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/identity/associations" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -19,8 +20,9 @@ import ( ) type Registrant struct { - record *registry.Node - privateKey *ecdsa.PrivateKey + record *registry.Node + privateKey *ecdsa.PrivateKey + tokenFactory *authn.TokenFactory } func NewRegistrant( @@ -45,9 +47,12 @@ func NewRegistrant( return nil, err } + tokenFactory := authn.NewTokenFactory(privateKey, record.NodeID) + return &Registrant{ - record: record, - privateKey: privateKey, + record: record, + privateKey: privateKey, + tokenFactory: tokenFactory, }, nil } @@ -59,6 +64,10 @@ func (r *Registrant) NodeID() uint32 { return r.record.NodeID } +func (r *Registrant) TokenFactory() *authn.TokenFactory { + return r.tokenFactory +} + func (r *Registrant) SignStagedEnvelope( stagedEnv queries.StagedOriginatorEnvelope, ) (*message_api.OriginatorEnvelope, error) { diff --git a/pkg/sync/syncWorker.go b/pkg/sync/syncWorker.go index 02dc6bbd..f0754efd 100644 --- a/pkg/sync/syncWorker.go +++ b/pkg/sync/syncWorker.go @@ -8,6 +8,7 @@ import ( "time" "github.com/xmtp/xmtpd/pkg/db/queries" + clientInterceptors "github.com/xmtp/xmtpd/pkg/interceptors/client" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/registrant" "github.com/xmtp/xmtpd/pkg/registry" @@ -111,10 +112,13 @@ func (s *syncWorker) connectToNode(node registry.Node) (*grpc.ClientConn, error) return nil, fmt.Errorf("Failed to get credentials: %v", err) } + interceptor := clientInterceptors.NewAuthInterceptor(s.registrant.TokenFactory(), node.NodeID) conn, err := grpc.NewClient( target, grpc.WithTransportCredentials(creds), grpc.WithDefaultCallOptions(), + grpc.WithUnaryInterceptor(interceptor.Unary()), + grpc.WithStreamInterceptor(interceptor.Stream()), ) if err != nil { return nil, fmt.Errorf("Failed to connect to peer at %s: %v", target, err)