diff --git a/privval/load_balancer.go b/privval/load_balancer.go index c9bc576..7e84430 100644 --- a/privval/load_balancer.go +++ b/privval/load_balancer.go @@ -2,7 +2,6 @@ package privval import ( "errors" - "sync" cometlog "github.com/cometbft/cometbft/libs/log" privvalproto "github.com/cometbft/cometbft/proto/tendermint/privval" @@ -12,33 +11,32 @@ import ( type RemoteSignerLoadBalancer struct { logger cometlog.Logger listeners []SignerListener + avail chan SignerListener // Available listeners that are ready to accept requests. } func NewRemoteSignerLoadBalancer(logger cometlog.Logger, listeners []SignerListener) *RemoteSignerLoadBalancer { + ch := make(chan SignerListener, len(listeners)) + for i := range listeners { + ch <- listeners[i] + } return &RemoteSignerLoadBalancer{ logger: logger, listeners: listeners, + avail: ch, } } // SendRequest sends a request to the first available listener. -func (sl *RemoteSignerLoadBalancer) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) { - var r racer - var res signerListenerEndpointResponse - - r.wg.Add(1) - - for _, listener := range sl.listeners { - go sl.sendRequestIfFirst(listener, &r, request, &res) - } +func (lb *RemoteSignerLoadBalancer) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) { + lis := <-lb.avail + defer func() { lb.avail <- lis }() - r.wg.Wait() - - return res.res, res.err + lb.logger.Debug("Sent request to listener", "address", lis.address) + return lis.SendRequest(request) } -func (sl *RemoteSignerLoadBalancer) Start() error { - for _, listener := range sl.listeners { +func (lb *RemoteSignerLoadBalancer) Start() error { + for _, listener := range lb.listeners { if err := listener.Start(); err != nil { return err } @@ -46,46 +44,10 @@ func (sl *RemoteSignerLoadBalancer) Start() error { return nil } -func (sl *RemoteSignerLoadBalancer) Stop() error { - var errs []error - for _, listener := range sl.listeners { - if err := listener.Stop(); err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) -} - -type signerListenerEndpointResponse struct { - res *privvalproto.Message - err error -} - -func (l *RemoteSignerLoadBalancer) sendRequestIfFirst(listener SignerListener, r *racer, request privvalproto.Message, res *signerListenerEndpointResponse) { - listener.instanceMtx.Lock() - defer listener.instanceMtx.Unlock() - first := r.race() - if !first { - return - } - res.res, res.err = listener.SendRequestLocked(request) - r.wg.Done() - l.logger.Debug("Sent request to listener", "address", listener.address) -} - -type racer struct { - mu sync.Mutex - wg sync.WaitGroup - handled bool -} - -// returns true if first -func (r *racer) race() bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.handled { - return false +func (lb *RemoteSignerLoadBalancer) Stop() error { + var err error + for _, listener := range lb.listeners { + err = errors.Join(err, listener.Stop()) } - r.handled = true - return true + return err } diff --git a/privval/load_balancer_test.go b/privval/load_balancer_test.go index 23e7b1a..924be5a 100644 --- a/privval/load_balancer_test.go +++ b/privval/load_balancer_test.go @@ -1,6 +1,7 @@ package privval_test import ( + "io" "net" "testing" "time" @@ -14,12 +15,6 @@ import ( "github.com/strangelove-ventures/horcrux-proxy/privval" ) -type devNull struct{} - -func (devNull) Write(p []byte) (int, error) { - return len(p), nil -} - func TestLoadBalancer(t *testing.T) { var listenAddrs = []string{ "tcp://127.0.0.1:37321", @@ -28,7 +23,7 @@ func TestLoadBalancer(t *testing.T) { "tcp://127.0.0.1:37324", } - logger := log.NewTMJSONLogger(devNull{}) + logger := log.NewTMJSONLogger(io.Discard) listeners := make([]privval.SignerListener, len(listenAddrs)) for i, addr := range listenAddrs { @@ -37,13 +32,11 @@ func TestLoadBalancer(t *testing.T) { lb := privval.NewRemoteSignerLoadBalancer(logger, listeners) - err := lb.Start() - t.Cleanup(func() { _ = lb.Stop() }) - require.NoError(t, err) + require.NoError(t, lb.Start()) remoteSigners := make([]*MockRemoteSigner, len(listenAddrs)) @@ -70,8 +63,7 @@ func TestLoadBalancer(t *testing.T) { }) } - err = eg.Wait() - require.NoError(t, err) + require.NoError(t, eg.Wait()) total := 0 for i := range listenAddrs { diff --git a/privval/remote_signer_test.go b/privval/remote_signer_test.go index 0aecbf5..f92a2ba 100644 --- a/privval/remote_signer_test.go +++ b/privval/remote_signer_test.go @@ -31,8 +31,8 @@ type MockRemoteSigner struct { dialer net.Dialer } -func (m *MockRemoteSigner) Counter() Counter { - return m.counter.Copy() +func (rs *MockRemoteSigner) Counter() Counter { + return rs.counter.Copy() } // NewMockRemoteSigner return a MockRemoteSigner that will dial using the given diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 3a6d454..9b6b033 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -108,11 +108,6 @@ func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*pr sl.instanceMtx.Lock() defer sl.instanceMtx.Unlock() - return sl.SendRequestLocked(request) -} - -// SendRequest ensures there is a connection, sends a request and waits for a response -func (sl *SignerListenerEndpoint) SendRequestLocked(request privvalproto.Message) (*privvalproto.Message, error) { err := sl.ensureConnection(sl.timeoutAccept) if err != nil { return nil, err