From 70837ceeb1d90a093cd7798668cfcfd403a67d33 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:34:54 -0700 Subject: [PATCH] Add node connection pool --- pkg/api/payer/clientManager.go | 59 ++++++++++++++++++++++++++++ pkg/api/payer/clientManager_test.go | 60 +++++++++++++++++++++++++++++ pkg/api/payer/service.go | 17 +++++--- pkg/api/server.go | 44 ++++----------------- pkg/registry/node.go | 40 ++++++++++++++++++- pkg/server/server.go | 32 +++++++++++++-- pkg/testutils/api/api.go | 29 +++++++++++--- 7 files changed, 231 insertions(+), 50 deletions(-) create mode 100644 pkg/api/payer/clientManager.go create mode 100644 pkg/api/payer/clientManager_test.go diff --git a/pkg/api/payer/clientManager.go b/pkg/api/payer/clientManager.go new file mode 100644 index 00000000..e2bb2163 --- /dev/null +++ b/pkg/api/payer/clientManager.go @@ -0,0 +1,59 @@ +package payer + +import ( + "sync" + + "github.com/xmtp/xmtpd/pkg/registry" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +/* +* +The ClientManager contains a mapping of nodeIDs to gRPC client connections. + +These client connections are safe to be shared and re-used and will automatically attempt +to reconnect if the underlying socket connection is lost. +* +*/ +type ClientManager struct { + log *zap.Logger + connections sync.Map // map[uint32]*grpc.ClientConn + nodeRegistry registry.NodeRegistry +} + +func NewClientManager(log *zap.Logger, nodeRegistry registry.NodeRegistry) *ClientManager { + return &ClientManager{log: log, nodeRegistry: nodeRegistry} +} + +func (c *ClientManager) GetClient(nodeID uint32) (*grpc.ClientConn, error) { + existing, ok := c.connections.Load(nodeID) + if ok { + return existing.(*grpc.ClientConn), nil + } + + conn, err := c.newClientConnection(nodeID) + if err != nil { + return nil, err + } + // Store the connection + c.connections.Store(nodeID, conn) + + return conn, nil +} + +func (c *ClientManager) newClientConnection( + nodeID uint32, +) (*grpc.ClientConn, error) { + c.log.Info("connecting to node", zap.Uint32("nodeID", nodeID)) + node, err := c.nodeRegistry.GetNode(nodeID) + if err != nil { + return nil, err + } + conn, err := node.BuildClient() + if err != nil { + return nil, err + } + + return conn, nil +} diff --git a/pkg/api/payer/clientManager_test.go b/pkg/api/payer/clientManager_test.go new file mode 100644 index 00000000..f81b9479 --- /dev/null +++ b/pkg/api/payer/clientManager_test.go @@ -0,0 +1,60 @@ +package payer_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/api/payer" + "github.com/xmtp/xmtpd/pkg/registry" + "github.com/xmtp/xmtpd/pkg/testutils" + apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api" + "google.golang.org/grpc/health/grpc_health_v1" +) + +func formatAddress(addr string) string { + chunks := strings.Split(addr, ":") + return fmt.Sprintf("http://localhost:%s", chunks[len(chunks)-1]) +} + +func TestClientManager(t *testing.T) { + server1, _, cleanup1 := apiTestUtils.NewTestAPIServer(t) + defer cleanup1() + server2, _, cleanup2 := apiTestUtils.NewTestAPIServer(t) + defer cleanup2() + + nodeRegistry := registry.NewFixedNodeRegistry([]registry.Node{ + { + NodeID: 100, + HttpAddress: formatAddress(server1.Addr().String()), + }, + { + NodeID: 200, + HttpAddress: formatAddress(server2.Addr().String()), + }, + }) + + cm := payer.NewClientManager(testutils.NewLog(t), nodeRegistry) + + client1, err := cm.GetClient(100) + require.NoError(t, err) + require.NotNil(t, client1) + + healthClient := grpc_health_v1.NewHealthClient(client1) + healthResponse, err := healthClient.Check( + context.Background(), + &grpc_health_v1.HealthCheckRequest{}, + ) + require.NoError(t, err) + require.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, healthResponse.Status) + + client2, err := cm.GetClient(200) + require.NoError(t, err) + require.NotNil(t, client2) + + _, err = cm.GetClient(300) + require.Error(t, err) + +} diff --git a/pkg/api/payer/service.go b/pkg/api/payer/service.go index c45974d3..8e396f89 100644 --- a/pkg/api/payer/service.go +++ b/pkg/api/payer/service.go @@ -4,6 +4,7 @@ import ( "context" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" + "github.com/xmtp/xmtpd/pkg/registry" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -12,14 +13,20 @@ import ( type Service struct { payer_api.UnimplementedPayerApiServer - ctx context.Context - log *zap.Logger + ctx context.Context + log *zap.Logger + clientManager *ClientManager } -func NewPayerApiService(ctx context.Context, log *zap.Logger) (*Service, error) { +func NewPayerApiService( + ctx context.Context, + log *zap.Logger, + registry registry.NodeRegistry, +) (*Service, error) { return &Service{ - ctx: ctx, - log: log, + ctx: ctx, + log: log, + clientManager: NewClientManager(log, registry), }, nil } diff --git a/pkg/api/server.go b/pkg/api/server.go index 39e700b8..809756e3 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -2,7 +2,6 @@ package api import ( "context" - "database/sql" "fmt" "net" "strings" @@ -12,12 +11,6 @@ import ( "google.golang.org/grpc/reflection" "github.com/pires/go-proxyproto" - "github.com/xmtp/xmtpd/pkg/api/message" - "github.com/xmtp/xmtpd/pkg/api/payer" - "github.com/xmtp/xmtpd/pkg/blockchain" - "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" - "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" - "github.com/xmtp/xmtpd/pkg/registrant" "github.com/xmtp/xmtpd/pkg/tracing" "go.uber.org/zap" "google.golang.org/grpc" @@ -27,24 +20,22 @@ import ( "google.golang.org/grpc/keepalive" ) +type RegistrationFunc func(server *grpc.Server) error + type ApiServer struct { ctx context.Context - db *sql.DB grpcListener net.Listener grpcServer *grpc.Server log *zap.Logger - registrant *registrant.Registrant wg sync.WaitGroup } func NewAPIServer( ctx context.Context, - writerDB *sql.DB, log *zap.Logger, port int, - registrant *registrant.Registrant, enableReflection bool, - messagePublisher blockchain.IBlockchainPublisher, + registrationFunc RegistrationFunc, ) (*ApiServer, error) { grpcListener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port)) @@ -53,18 +44,15 @@ func NewAPIServer( } s := &ApiServer{ ctx: ctx, - db: writerDB, grpcListener: &proxyproto.Listener{ Listener: grpcListener, ReadHeaderTimeout: 10 * time.Second, }, - log: log.Named("api"), - registrant: registrant, - wg: sync.WaitGroup{}, + log: log.Named("api"), + wg: sync.WaitGroup{}, } // TODO: Add interceptors - options := []grpc.ServerOption{ grpc.Creds(insecure.NewCredentials()), grpc.KeepaliveParams(keepalive.ServerParameters{ @@ -78,6 +66,9 @@ func NewAPIServer( } s.grpcServer = grpc.NewServer(options...) + if err := registrationFunc(s.grpcServer); err != nil { + return nil, err + } if enableReflection { // Register reflection service on gRPC server. @@ -88,25 +79,6 @@ func NewAPIServer( healthcheck := health.NewServer() healthgrpc.RegisterHealthServer(s.grpcServer, healthcheck) - replicationService, err := message.NewReplicationApiService( - ctx, - log, - registrant, - writerDB, - messagePublisher, - ) - if err != nil { - return nil, err - } - - message_api.RegisterReplicationApiServer(s.grpcServer, replicationService) - - payerService, err := payer.NewPayerApiService(ctx, log) - if err != nil { - return nil, err - } - payer_api.RegisterPayerApiServer(s.grpcServer, payerService) - tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) { s.log.Info("serving grpc", zap.String("address", s.grpcListener.Addr().String())) if err = s.grpcServer.Serve(s.grpcListener); err != nil && diff --git a/pkg/registry/node.go b/pkg/registry/node.go index 78cb8915..fd93dc01 100644 --- a/pkg/registry/node.go +++ b/pkg/registry/node.go @@ -1,6 +1,14 @@ package registry -import "crypto/ecdsa" +import ( + "crypto/ecdsa" + "fmt" + + "github.com/xmtp/xmtpd/pkg/utils" + "google.golang.org/grpc" +) + +type DialOptionFunc func(node Node) []grpc.DialOption type Node struct { NodeID uint32 @@ -24,3 +32,33 @@ func (n *Node) Equals(other Node) bool { n.IsHealthy == other.IsHealthy && n.IsValidConfig == other.IsValidConfig } + +func (node *Node) BuildClient( + extraDialOpts ...grpc.DialOption, +) (*grpc.ClientConn, error) { + target, isTLS, err := utils.HttpAddressToGrpcTarget(node.HttpAddress) + if err != nil { + return nil, fmt.Errorf("Failed to convert HTTP address to gRPC target: %v", err) + } + + creds, err := utils.GetCredentialsForAddress(isTLS) + if err != nil { + return nil, fmt.Errorf("Failed to get credentials: %v", err) + } + + dialOpts := append([]grpc.DialOption{ + grpc.WithTransportCredentials(creds), + grpc.WithDefaultCallOptions(), + }, extraDialOpts...) + + conn, err := grpc.NewClient( + target, + dialOpts..., + ) + + if err != nil { + return nil, fmt.Errorf("Failed to create channel at %s: %v", target, err) + } + + return conn, nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 022c1581..be35ad01 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -10,11 +10,16 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/xmtp/xmtpd/pkg/api/message" + "github.com/xmtp/xmtpd/pkg/api/payer" "github.com/xmtp/xmtpd/pkg/blockchain" "github.com/xmtp/xmtpd/pkg/indexer" "github.com/xmtp/xmtpd/pkg/metrics" "github.com/xmtp/xmtpd/pkg/mlsvalidate" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" "github.com/xmtp/xmtpd/pkg/sync" + "google.golang.org/grpc" "github.com/xmtp/xmtpd/pkg/api" "github.com/xmtp/xmtpd/pkg/config" @@ -92,6 +97,7 @@ func NewReplicationServer( if err != nil { return nil, err } + err = indexer.StartIndexer( s.ctx, log, @@ -103,15 +109,35 @@ func NewReplicationServer( return nil, err } + serviceRegistrationFunc := func(grpcServer *grpc.Server) error { + replicationService, err := message.NewReplicationApiService( + ctx, + log, + s.registrant, + writerDB, + blockchainPublisher, + ) + if err != nil { + return err + } + message_api.RegisterReplicationApiServer(grpcServer, replicationService) + + payerService, err := payer.NewPayerApiService(ctx, log, s.nodeRegistry) + if err != nil { + return err + } + payer_api.RegisterPayerApiServer(grpcServer, payerService) + + return nil + } + // TODO(rich): Add configuration to specify whether to run API/sync server s.apiServer, err = api.NewAPIServer( s.ctx, - s.writerDB, log, options.API.Port, - s.registrant, options.Reflection.Enable, - blockchainPublisher, + serviceRegistrationFunc, ) if err != nil { return nil, err diff --git a/pkg/testutils/api/api.go b/pkg/testutils/api/api.go index 19e2c1e2..d9887936 100644 --- a/pkg/testutils/api/api.go +++ b/pkg/testutils/api/api.go @@ -9,10 +9,13 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/api" + "github.com/xmtp/xmtpd/pkg/api/message" + "github.com/xmtp/xmtpd/pkg/api/payer" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/mocks/blockchain" mocks "github.com/xmtp/xmtpd/pkg/mocks/registry" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" "github.com/xmtp/xmtpd/pkg/registrant" "github.com/xmtp/xmtpd/pkg/registry" "github.com/xmtp/xmtpd/pkg/testutils" @@ -54,16 +57,32 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { }, nil) registrant, err := registrant.NewRegistrant(ctx, log, queries.New(db), mockRegistry, privKeyStr) require.NoError(t, err) - mockMessagePublsiher := blockchain.NewMockIBlockchainPublisher(t) + mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t) + + serviceRegistrationFunc := func(grpcServer *grpc.Server) error { + replicationService, err := message.NewReplicationApiService( + ctx, + log, + registrant, + db, + mockMessagePublisher, + ) + require.NoError(t, err) + message_api.RegisterReplicationApiServer(grpcServer, replicationService) + + payerService, err := payer.NewPayerApiService(ctx, log, mockRegistry) + require.NoError(t, err) + payer_api.RegisterPayerApiServer(grpcServer, payerService) + + return nil + } svr, err := api.NewAPIServer( ctx, - db, log, - 0, /*port*/ - registrant, + 0, /*port*/ true, /*enableReflection*/ - mockMessagePublsiher, + serviceRegistrationFunc, ) require.NoError(t, err)