diff --git a/p2p/node/api.go b/p2p/node/api.go
index f692e3992..6a259a840 100644
--- a/p2p/node/api.go
+++ b/p2p/node/api.go
@@ -44,7 +44,7 @@ func (p *P2PNode) Start() error {
 
 	// Register the Quai protocol handler
 	p.peerManager.GetHost().SetStreamHandler(quaiprotocol.ProtocolVersion, func(s network.Stream) {
-		quaiprotocol.QuaiProtocolHandler(s, p)
+		quaiprotocol.QuaiProtocolHandler(p.ctx, s, p)
 	})
 
 	// Start the pubsub manager
diff --git a/p2p/node/pubsubManager/gossipsub_test.go b/p2p/node/pubsubManager/gossipsub_test.go
index 012416666..e72e4403c 100644
--- a/p2p/node/pubsubManager/gossipsub_test.go
+++ b/p2p/node/pubsubManager/gossipsub_test.go
@@ -174,7 +174,7 @@ func TestMultipleRequests(t *testing.T) {
 		return pubsub.ValidationAccept
 	}
 
-	// SUBSCRIBE
+	// SUBSCRIBE to all topics
 	for i, topic := range topics {
 		err := ps.SubscribeAndRegisterValidator(common.Location{0, 0}, topic, validatorFunc)
 		require.NoError(t, err, "Failed to subscribe to topic %d", topic)
@@ -183,7 +183,7 @@ func TestMultipleRequests(t *testing.T) {
 		}
 	}
 
-	// BROADCAST
+	// Create a buffered channel large enough to hold all sent messages
 	testCh := make(chan interface{}, n*len(topics))
 	ps.SetReceiveHandler(func(receivedFrom peer.ID, msgId string, msgTopic string, data interface{}, location common.Location) {
 		select {
@@ -196,6 +196,7 @@ func TestMultipleRequests(t *testing.T) {
 	var messages []interface{}
 	var wg sync.WaitGroup
 
+	// BROADCAST messages concurrently
 	for i := 0; i < n; i++ {
 		newWo := types.CopyWorkObject(wo)
 		newWo.WorkObjectHeader().SetNonce(types.EncodeNonce(uint64(i)))
@@ -212,18 +213,22 @@ func TestMultipleRequests(t *testing.T) {
 			messages = append(messages, msg)
 			wg.Add(1)
 
+			// Add a sleep to not overwhelm the gossipSub broadcasts
+			time.Sleep(10 * time.Millisecond)
+			// Broadcast each message in its own goroutine
 			go func(msg interface{}) {
 				defer wg.Done()
-				err = ps.Broadcast(common.Location{0, 0}, msg)
+				err := ps.Broadcast(common.Location{0, 0}, msg)
 				require.NoError(t, err, "Failed to broadcast message")
 			}(msg)
 		}
 	}
 
-	// VERIFY
+	// VERIFY receiving concurrently
 	var mu sync.Mutex
 	receivedMessages := make([]interface{}, 0, n*len(topics))
+
 	for i := 0; i < (n * len(topics)); i++ {
 		wg.Add(1)
 		go func(j int) {
@@ -238,11 +243,11 @@ func TestMultipleRequests(t *testing.T) {
 			}
 		}(i)
 	}
-
+	// Wait for all broadcasts to complete
 	wg.Wait()
 
 	// Ensure all broadcasted messages were received
-	require.Len(t, receivedMessages, len(messages), "The number of received messages does not match the number of broadcasted messages. expected: %d, got: %d", len(messages), len(receivedMessages))
+	require.Len(t, receivedMessages, len(messages), "The number of received messages does not match the number of broadcasted messages. sent: %d, received: %d", len(messages), len(receivedMessages))
 
 	ps.Stop()
 	if len(ps.GetTopics()) != 0 {
diff --git a/p2p/node/streamManager/streamManager.go b/p2p/node/streamManager/streamManager.go
index 567562b8e..6ff46d46d 100644
--- a/p2p/node/streamManager/streamManager.go
+++ b/p2p/node/streamManager/streamManager.go
@@ -73,9 +73,10 @@ type basicStreamManager struct {
 }
 
 type streamWrapper struct {
-	stream    network.Stream
-	semaphore chan struct{}
-	errCount  int
+	stream                network.Stream
+	cancelProtocolHandler context.CancelFunc
+	semaphore             chan struct{}
+	errCount              int
 }
 
 func NewStreamManager(node quaiprotocol.QuaiP2PNode, host host.Host) (*basicStreamManager, error) {
@@ -110,6 +111,8 @@ func severStream(key p2p.PeerID, wrappedStream streamWrapper) {
 	if streamMetrics != nil {
 		streamMetrics.WithLabelValues("NumStreams").Dec()
 	}
+	// Clean up the protocolHandler
+	wrappedStream.cancelProtocolHandler()
 }
 
 func (sm *basicStreamManager) Start() {
@@ -135,7 +138,7 @@ func (sm *basicStreamManager) listenForNewStreamRequest() {
 			go func(peerID peer.ID) {
 				err := sm.OpenStream(peerID)
 				if err != nil {
-					log.Global.WithFields(log.Fields{"peerId": peerID, "err": err}).Warn("Error opening new strean into peer")
+					log.Global.WithFields(log.Fields{"peerId": peerID, "err": err}).Warn("Error opening new stream into peer")
 				}
 			}(peerID)
 
@@ -146,34 +149,40 @@ func (sm *basicStreamManager) OpenStream(peerID p2p.PeerID) error {
-	// check if there is an existing stream
+	// Check if there is an existing stream
 	if _, ok := sm.streamCache.Get(peerID); ok {
 		return nil
 	}
-	timer := time.NewTimer(c_stream_timeout)
-	defer timer.Stop()
-	select {
-	case <-timer.C:
-		return errors.New("stream creation timeout")
-	default:
-		// Create a new stream to the peer and register it in the cache
-		stream, err := sm.host.NewStream(sm.ctx, peerID, quaiprotocol.ProtocolVersion)
-		if err != nil {
-			return fmt.Errorf("error opening new stream with peer %s err: %s", peerID, err)
-		}
-		wrappedStream := streamWrapper{
-			stream:    stream,
-			semaphore: make(chan struct{}, c_maxPendingRequests),
-			errCount:  0,
-		}
-		sm.streamCache.Add(peerID, wrappedStream)
-		go quaiprotocol.QuaiProtocolHandler(stream, sm.p2pBackend)
-		log.Global.WithField("PeerID", peerID).Info("Had to create new stream")
-		if streamMetrics != nil {
-			streamMetrics.WithLabelValues("NumStreams").Inc()
+
+	streamCtx, streamCancel := context.WithTimeout(sm.ctx, c_stream_timeout)
+	defer streamCancel()
+
+	// Attempt to create the new stream to the peer
+	stream, err := sm.host.NewStream(streamCtx, peerID, quaiprotocol.ProtocolVersion)
+	if err != nil {
+		if streamCtx.Err() == context.DeadlineExceeded {
+			return errors.New("stream creation timeout")
 		}
-		return nil
+		return fmt.Errorf("error opening new stream with peer %s err: %s", peerID, err)
 	}
+
+	handlerCtx, handlerCancel := context.WithCancel(sm.ctx)
+
+	wrappedStream := streamWrapper{
+		stream:                stream,
+		cancelProtocolHandler: handlerCancel,
+		semaphore:             make(chan struct{}, c_maxPendingRequests),
+		errCount:              0,
+	}
+	sm.streamCache.Add(peerID, wrappedStream)
+
+	go quaiprotocol.QuaiProtocolHandler(handlerCtx, stream, sm.p2pBackend)
+	log.Global.WithField("PeerID", peerID).Info("Had to create new stream")
+	if streamMetrics != nil {
+		streamMetrics.WithLabelValues("NumStreams").Inc()
+	}
+
+	return nil
 }
 
 func (sm *basicStreamManager) CloseStream(peerID p2p.PeerID) error {
diff --git a/p2p/node/streamManager/streamManager_test.go b/p2p/node/streamManager/streamManager_test.go
index 2624dec6a..e5c8cccdb 100644
--- a/p2p/node/streamManager/streamManager_test.go
+++ b/p2p/node/streamManager/streamManager_test.go
@@ -1,6 +1,7 @@
 package streamManager
 
 import (
+	"context"
 	"testing"
 
 	gomock "go.uber.org/mock/gomock"
@@ -129,10 +130,12 @@ func TestWriteMessageToStream(t *testing.T) {
 
 	t.Run("Too many pending requests", func(t *testing.T) {
 		// small semaphore to block the stream
+		_, cancelFunc := context.WithCancel(context.Background())
 		wrappedStream := streamWrapper{
-			stream:    mockLibp2pStream,
-			semaphore: make(chan struct{}, 1),
-			errCount:  0,
+			stream:                mockLibp2pStream,
+			cancelProtocolHandler: cancelFunc,
+			semaphore:             make(chan struct{}, 1),
+			errCount:              0,
 		}
 		// block semaphore
 		wrappedStream.semaphore <- struct{}{}
diff --git a/p2p/protocol/handler.go b/p2p/protocol/handler.go
index c96b72dc2..ca2973895 100644
--- a/p2p/protocol/handler.go
+++ b/p2p/protocol/handler.go
@@ -80,7 +80,7 @@ func ProcRequestRate(peerId peer.ID, inbound bool) error {
 }
 
 // QuaiProtocolHandler handles all the incoming requests and responds with corresponding data
-func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
+func QuaiProtocolHandler(ctx context.Context, stream network.Stream, node QuaiP2PNode) {
 	defer func() {
 		if r := recover(); r != nil {
 			log.Global.WithFields(log.Fields{
@@ -102,8 +102,6 @@ func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
 	// Create a channel for messages
 	msgChan := make(chan []byte, msgChanSize)
 	full := 0
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
 	go func() {
 		defer func() {
 			if r := recover(); r != nil {
@@ -148,7 +146,6 @@ func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
 			}
 			full++
 		}
-	}
 }
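The common thread in this diff is ownership of the protocol handler's lifetime: QuaiProtocolHandler no longer derives its own context from context.Background(), it receives one from the caller, and OpenStream stores the matching CancelFunc in streamWrapper so that severStream can stop the handler goroutine when the stream is torn down. The sketch below illustrates that cancellation pattern in isolation, assuming nothing beyond the standard library; streamEntry, handle, and the timings are illustrative stand-ins, not code from this repository.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// streamEntry mirrors the idea behind streamWrapper in the diff: the cancel
// function for the handler goroutine lives next to the stream it serves, so
// removing the entry can also stop the handler. (Hypothetical type.)
type streamEntry struct {
	id            string
	cancelHandler context.CancelFunc
}

// handle stands in for the protocol handler: it runs until its context is
// cancelled, which now happens when the owning stream is severed.
func handle(ctx context.Context, id string, wg *sync.WaitGroup) {
	defer wg.Done()
	for {
		select {
		case <-ctx.Done():
			fmt.Printf("handler for %s stopped: %v\n", id, ctx.Err())
			return
		case <-time.After(50 * time.Millisecond):
			// ... read and dispatch messages here ...
		}
	}
}

func main() {
	var wg sync.WaitGroup
	managerCtx := context.Background() // stands in for the stream manager's ctx

	// Open: derive a per-stream context from the manager's context and keep
	// its cancel function in the cache entry, as OpenStream does in the diff.
	ctx, cancel := context.WithCancel(managerCtx)
	entry := streamEntry{id: "peer-1", cancelHandler: cancel}
	wg.Add(1)
	go handle(ctx, entry.id, &wg)

	// Sever: calling the stored cancel function stops the handler goroutine,
	// mirroring the new wrappedStream.cancelProtocolHandler() call.
	time.Sleep(200 * time.Millisecond)
	entry.cancelHandler()
	wg.Wait()
}
```

Deriving the per-stream context from the manager's own context (as OpenStream does with sm.ctx) has the additional effect that every handler stops automatically when the manager itself shuts down, instead of leaking goroutines rooted in context.Background().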