Skip to content

Commit

Permalink
Add node connection pool
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Oct 18, 2024
1 parent d0201d6 commit 70837ce
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 50 deletions.
59 changes: 59 additions & 0 deletions pkg/api/payer/clientManager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package payer

import (
"sync"

"github.com/xmtp/xmtpd/pkg/registry"
"go.uber.org/zap"
"google.golang.org/grpc"
)

/*
*
The ClientManager contains a mapping of nodeIDs to gRPC client connections.
These client connections are safe to be shared and re-used and will automatically attempt
to reconnect if the underlying socket connection is lost.
*
*/
type ClientManager struct {
log *zap.Logger
connections sync.Map // map[uint32]*grpc.ClientConn
nodeRegistry registry.NodeRegistry
}

func NewClientManager(log *zap.Logger, nodeRegistry registry.NodeRegistry) *ClientManager {
return &ClientManager{log: log, nodeRegistry: nodeRegistry}
}

func (c *ClientManager) GetClient(nodeID uint32) (*grpc.ClientConn, error) {
existing, ok := c.connections.Load(nodeID)
if ok {
return existing.(*grpc.ClientConn), nil
}

conn, err := c.newClientConnection(nodeID)
if err != nil {
return nil, err
}
// Store the connection
c.connections.Store(nodeID, conn)

return conn, nil
}

func (c *ClientManager) newClientConnection(
nodeID uint32,
) (*grpc.ClientConn, error) {
c.log.Info("connecting to node", zap.Uint32("nodeID", nodeID))
node, err := c.nodeRegistry.GetNode(nodeID)
if err != nil {
return nil, err
}
conn, err := node.BuildClient()
if err != nil {
return nil, err
}

return conn, nil
}
60 changes: 60 additions & 0 deletions pkg/api/payer/clientManager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package payer_test

import (
"context"
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api/payer"
"github.com/xmtp/xmtpd/pkg/registry"
"github.com/xmtp/xmtpd/pkg/testutils"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
"google.golang.org/grpc/health/grpc_health_v1"
)

func formatAddress(addr string) string {
chunks := strings.Split(addr, ":")
return fmt.Sprintf("http://localhost:%s", chunks[len(chunks)-1])
}

func TestClientManager(t *testing.T) {
server1, _, cleanup1 := apiTestUtils.NewTestAPIServer(t)
defer cleanup1()
server2, _, cleanup2 := apiTestUtils.NewTestAPIServer(t)
defer cleanup2()

nodeRegistry := registry.NewFixedNodeRegistry([]registry.Node{
{
NodeID: 100,
HttpAddress: formatAddress(server1.Addr().String()),
},
{
NodeID: 200,
HttpAddress: formatAddress(server2.Addr().String()),
},
})

cm := payer.NewClientManager(testutils.NewLog(t), nodeRegistry)

client1, err := cm.GetClient(100)
require.NoError(t, err)
require.NotNil(t, client1)

healthClient := grpc_health_v1.NewHealthClient(client1)
healthResponse, err := healthClient.Check(
context.Background(),
&grpc_health_v1.HealthCheckRequest{},
)
require.NoError(t, err)
require.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, healthResponse.Status)

client2, err := cm.GetClient(200)
require.NoError(t, err)
require.NotNil(t, client2)

_, err = cm.GetClient(300)
require.Error(t, err)

}
17 changes: 12 additions & 5 deletions pkg/api/payer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
"github.com/xmtp/xmtpd/pkg/registry"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -12,14 +13,20 @@ import (
type Service struct {
payer_api.UnimplementedPayerApiServer

ctx context.Context
log *zap.Logger
ctx context.Context
log *zap.Logger
clientManager *ClientManager
}

func NewPayerApiService(ctx context.Context, log *zap.Logger) (*Service, error) {
func NewPayerApiService(
ctx context.Context,
log *zap.Logger,
registry registry.NodeRegistry,
) (*Service, error) {
return &Service{
ctx: ctx,
log: log,
ctx: ctx,
log: log,
clientManager: NewClientManager(log, registry),
}, nil
}

Expand Down
44 changes: 8 additions & 36 deletions pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"database/sql"
"fmt"
"net"
"strings"
Expand All @@ -12,12 +11,6 @@ import (
"google.golang.org/grpc/reflection"

"github.com/pires/go-proxyproto"
"github.com/xmtp/xmtpd/pkg/api/message"
"github.com/xmtp/xmtpd/pkg/api/payer"
"github.com/xmtp/xmtpd/pkg/blockchain"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/tracing"
"go.uber.org/zap"
"google.golang.org/grpc"
Expand All @@ -27,24 +20,22 @@ import (
"google.golang.org/grpc/keepalive"
)

type RegistrationFunc func(server *grpc.Server) error

type ApiServer struct {
ctx context.Context
db *sql.DB
grpcListener net.Listener
grpcServer *grpc.Server
log *zap.Logger
registrant *registrant.Registrant
wg sync.WaitGroup
}

func NewAPIServer(
ctx context.Context,
writerDB *sql.DB,
log *zap.Logger,
port int,
registrant *registrant.Registrant,
enableReflection bool,
messagePublisher blockchain.IBlockchainPublisher,
registrationFunc RegistrationFunc,
) (*ApiServer, error) {
grpcListener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port))

Expand All @@ -53,18 +44,15 @@ func NewAPIServer(
}
s := &ApiServer{
ctx: ctx,
db: writerDB,
grpcListener: &proxyproto.Listener{
Listener: grpcListener,
ReadHeaderTimeout: 10 * time.Second,
},
log: log.Named("api"),
registrant: registrant,
wg: sync.WaitGroup{},
log: log.Named("api"),
wg: sync.WaitGroup{},
}

// TODO: Add interceptors

options := []grpc.ServerOption{
grpc.Creds(insecure.NewCredentials()),
grpc.KeepaliveParams(keepalive.ServerParameters{
Expand All @@ -78,6 +66,9 @@ func NewAPIServer(
}

s.grpcServer = grpc.NewServer(options...)
if err := registrationFunc(s.grpcServer); err != nil {
return nil, err
}

if enableReflection {
// Register reflection service on gRPC server.
Expand All @@ -88,25 +79,6 @@ func NewAPIServer(
healthcheck := health.NewServer()
healthgrpc.RegisterHealthServer(s.grpcServer, healthcheck)

replicationService, err := message.NewReplicationApiService(
ctx,
log,
registrant,
writerDB,
messagePublisher,
)
if err != nil {
return nil, err
}

message_api.RegisterReplicationApiServer(s.grpcServer, replicationService)

payerService, err := payer.NewPayerApiService(ctx, log)
if err != nil {
return nil, err
}
payer_api.RegisterPayerApiServer(s.grpcServer, payerService)

tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) {
s.log.Info("serving grpc", zap.String("address", s.grpcListener.Addr().String()))
if err = s.grpcServer.Serve(s.grpcListener); err != nil &&
Expand Down
40 changes: 39 additions & 1 deletion pkg/registry/node.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
package registry

import "crypto/ecdsa"
import (
"crypto/ecdsa"
"fmt"

"github.com/xmtp/xmtpd/pkg/utils"
"google.golang.org/grpc"
)

type DialOptionFunc func(node Node) []grpc.DialOption

type Node struct {
NodeID uint32
Expand All @@ -24,3 +32,33 @@ func (n *Node) Equals(other Node) bool {
n.IsHealthy == other.IsHealthy &&
n.IsValidConfig == other.IsValidConfig
}

func (node *Node) BuildClient(
extraDialOpts ...grpc.DialOption,
) (*grpc.ClientConn, error) {
target, isTLS, err := utils.HttpAddressToGrpcTarget(node.HttpAddress)
if err != nil {
return nil, fmt.Errorf("Failed to convert HTTP address to gRPC target: %v", err)
}

creds, err := utils.GetCredentialsForAddress(isTLS)
if err != nil {
return nil, fmt.Errorf("Failed to get credentials: %v", err)
}

dialOpts := append([]grpc.DialOption{
grpc.WithTransportCredentials(creds),
grpc.WithDefaultCallOptions(),
}, extraDialOpts...)

conn, err := grpc.NewClient(
target,
dialOpts...,
)

if err != nil {
return nil, fmt.Errorf("Failed to create channel at %s: %v", target, err)
}

return conn, nil
}
32 changes: 29 additions & 3 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ import (

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/xmtp/xmtpd/pkg/api/message"
"github.com/xmtp/xmtpd/pkg/api/payer"
"github.com/xmtp/xmtpd/pkg/blockchain"
"github.com/xmtp/xmtpd/pkg/indexer"
"github.com/xmtp/xmtpd/pkg/metrics"
"github.com/xmtp/xmtpd/pkg/mlsvalidate"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
"github.com/xmtp/xmtpd/pkg/sync"
"google.golang.org/grpc"

"github.com/xmtp/xmtpd/pkg/api"
"github.com/xmtp/xmtpd/pkg/config"
Expand Down Expand Up @@ -92,6 +97,7 @@ func NewReplicationServer(
if err != nil {
return nil, err
}

err = indexer.StartIndexer(
s.ctx,
log,
Expand All @@ -103,15 +109,35 @@ func NewReplicationServer(
return nil, err
}

serviceRegistrationFunc := func(grpcServer *grpc.Server) error {
replicationService, err := message.NewReplicationApiService(
ctx,
log,
s.registrant,
writerDB,
blockchainPublisher,
)
if err != nil {
return err
}
message_api.RegisterReplicationApiServer(grpcServer, replicationService)

payerService, err := payer.NewPayerApiService(ctx, log, s.nodeRegistry)
if err != nil {
return err
}
payer_api.RegisterPayerApiServer(grpcServer, payerService)

return nil
}

// TODO(rich): Add configuration to specify whether to run API/sync server
s.apiServer, err = api.NewAPIServer(
s.ctx,
s.writerDB,
log,
options.API.Port,
s.registrant,
options.Reflection.Enable,
blockchainPublisher,
serviceRegistrationFunc,
)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 70837ce

Please sign in to comment.