From c35df8aa1f52e9a5f86dab77bd703f6b3656d338 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 9 Jan 2025 19:13:31 +0000 Subject: [PATCH] accelerated-dht: cleanup peer from message sender on disconnection (#1009) * accelerated-dht: cleanup peers from message sender on disconnection --------- Co-authored-by: Marco Munizaga --- fullrt/dht.go | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/fullrt/dht.go b/fullrt/dht.go index a630d995..295358a1 100644 --- a/fullrt/dht.go +++ b/fullrt/dht.go @@ -15,12 +15,14 @@ import ( "github.com/multiformats/go-multihash" "github.com/libp2p/go-libp2p-routing-helpers/tracing" + "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/routing" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" swarm "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/gogo/protobuf/proto" @@ -98,6 +100,8 @@ type FullRT struct { bulkSendParallelism int self peer.ID + + peerConnectednessSubscriber event.Subscription } // NewFullRT creates a DHT client that tracks the full network. It takes a protocol prefix for the given network, @@ -151,6 +155,11 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful } } + sub, err := h.EventBus().Subscribe(new(event.EvtPeerConnectednessChanged), eventbus.Name("fullrt-dht")) + if err != nil { + return nil, fmt.Errorf("peer connectedness subscription failed: %w", err) + } + ctx, cancel := context.WithCancel(context.Background()) self := h.ID() @@ -195,14 +204,14 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful crawlerInterval: fullrtcfg.crawlInterval, - bulkSendParallelism: fullrtcfg.bulkSendParallelism, - - self: self, + bulkSendParallelism: fullrtcfg.bulkSendParallelism, + self: self, + peerConnectednessSubscriber: sub, } - rt.wg.Add(1) + rt.wg.Add(2) go rt.runCrawler(ctx) - + go rt.runSubscriber() return rt, nil } @@ -211,6 +220,31 @@ type crawlVal struct { key kadkey.Key } +func (dht *FullRT) runSubscriber() { + defer dht.wg.Done() + ms, ok := dht.messageSender.(dht_pb.MessageSenderWithDisconnect) + defer dht.peerConnectednessSubscriber.Close() + if !ok { + return + } + for { + select { + case e := <-dht.peerConnectednessSubscriber.Out(): + pc, ok := e.(event.EvtPeerConnectednessChanged) + if !ok { + logger.Errorf("invalid event message type: %T", e) + continue + } + + if pc.Connectedness != network.Connected { + ms.OnDisconnect(dht.ctx, pc.Peer) + } + case <-dht.ctx.Done(): + return + } + } +} + func (dht *FullRT) TriggerRefresh(ctx context.Context) error { select { case <-ctx.Done():