diff --git a/internal/transports/p2p/peer/peer.go b/internal/transports/p2p/peer/peer.go index 593ae4f7..483d415c 100644 --- a/internal/transports/p2p/peer/peer.go +++ b/internal/transports/p2p/peer/peer.go @@ -7,6 +7,7 @@ import ( "math/big" "net" "strconv" + "sync" "time" "github.com/bitcoin-sv/block-headers-service/config" @@ -38,8 +39,10 @@ type Peer struct { verAckReceived bool synced bool - msgChan chan wire.Message - quit chan struct{} + wg sync.WaitGroup + msgChan chan wire.Message + quitting bool + quit chan struct{} } func NewPeer( @@ -80,7 +83,9 @@ func NewPeer( services: wire.SFspv, protocolVersion: initialProtocolVersion, synced: nextCheckpoint == nil, + wg: sync.WaitGroup{}, msgChan: make(chan wire.Message), + quitting: false, quit: make(chan struct{}), } return peer, nil @@ -99,8 +104,16 @@ func (p *Peer) Connect() error { func (p *Peer) Disconnect() error { p.log.Info().Msgf("disconnecting peer: %s", p.addr.String()) - p.conn.Close() + + p.quitting = true close(p.quit) + err := p.conn.Close() + if err != nil { + return err + } + + p.wg.Wait() + p.log.Info().Msgf("successfully disconnected peer %s", p) return nil } @@ -156,9 +169,11 @@ func (p *Peer) negotiateProtocol() error { // PingHandler is a handler for sending ping messages to peers. // Must be run as a goroutine. func (p *Peer) pingHandler() { + p.wg.Add(1) pingTicker := time.NewTicker(pingInterval) defer pingTicker.Stop() +out: for { select { case <-pingTicker.C: @@ -171,23 +186,31 @@ func (p *Peer) pingHandler() { p.queueMessage(wire.NewMsgPing(nonce)) case <-p.quit: - return + break out } } + + p.log.Info().Msgf("ping handler shutdown for peer %s", p) + p.wg.Done() } // MsgHandler is a message handler for incoming messages. // Must be run as a goroutine. func (p *Peer) readMsgHandler() { + p.wg.Add(1) + +out: for { select { case <-p.quit: - return + break out default: remoteMsg, _, err := wire.ReadMessage(p.conn, p.protocolVersion, p.chainParams.Net) if err != nil { - p.log.Error().Msgf("cannot read message from peer %s, reason: %v", p.addr.String(), err) + if !p.quitting { + p.log.Error().Msgf("cannot read message from peer %s, reason: %v", p.addr.String(), err) + } continue } @@ -205,6 +228,9 @@ func (p *Peer) readMsgHandler() { } } } + + p.log.Info().Msgf("read msg handler shutdown for peer %s", p) + p.wg.Done() } func (p *Peer) writeMessage(msg wire.Message) error { @@ -234,14 +260,29 @@ func (p *Peer) queueMessage(msg wire.Message) { // writeMsgHandler serves as a queue for writing messages to peers, // must be run as a goroutine. func (p *Peer) writeMsgHandler() { + p.wg.Add(1) + +out: for { select { case msg := <-p.msgChan: p.writeMessage(msg) case <-p.quit: - return + break out + } + } + +cleanup: + for { + select { + case <-p.msgChan: + default: + break cleanup } } + + p.log.Info().Msgf("write msg handler shutdown for peer %s", p) + p.wg.Done() } func (p *Peer) writeOurVersionMsg() error { @@ -486,3 +527,7 @@ func (p *Peer) verifyCheckpointReached(h *domains.BlockHeader, receivedCheckpoin } return receivedCheckpoint, nil } + +func (p *Peer) String() string { + return p.addr.String() +} diff --git a/internal/transports/p2p/server.go b/internal/transports/p2p/server.go index 3c11e77b..6b551899 100644 --- a/internal/transports/p2p/server.go +++ b/internal/transports/p2p/server.go @@ -73,6 +73,9 @@ func (s *server) Start() error { } func (s *server) Shutdown() error { - s.peer.Disconnect() + err := s.peer.Disconnect() + if err != nil { + return err + } return nil }