diff --git a/header/p2p/exchange.go b/header/p2p/exchange.go index 55abceaec4..a488fc6f97 100644 --- a/header/p2p/exchange.go +++ b/header/p2p/exchange.go @@ -5,10 +5,10 @@ import ( "context" "fmt" "math/rand" + "sort" "time" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" @@ -29,6 +29,8 @@ const ( writeDeadline = time.Second * 5 // readDeadline sets timeout for reading messages from the stream readDeadline = time.Minute + // the target minimum amount of responses with the same chain head + minResponses = 2 ) // PubSubTopic hardcodes the name of the ExtendedHeader @@ -61,11 +63,35 @@ func (ex *Exchange) Head(ctx context.Context) (*header.ExtendedHeader, error) { Data: &p2p_pb.ExtendedHeaderRequest_Origin{Origin: uint64(0)}, Amount: 1, } - headers, err := ex.performRequest(ctx, req) - if err != nil { - return nil, err + + headerCh := make(chan *header.ExtendedHeader) + // request head from each trusted peer + for _, from := range ex.trustedPeers { + go func(from peer.ID) { + headers, err := request(ctx, from, ex.host, req) + if err != nil { + log.Errorw("head request to trusted peer failed", "trustedPeer", from, "err", err) + headerCh <- nil + return + } + // doRequest ensures that the result slice will have at least one ExtendedHeader + headerCh <- headers[0] + }(from) + } + + result := make([]*header.ExtendedHeader, 0, len(ex.trustedPeers)) + for range ex.trustedPeers { + select { + case h := <-headerCh: + if h != nil { + result = append(result, h) + } + case <-ctx.Done(): + return nil, ctx.Err() + } } - return headers[0], nil + + return bestHead(result) } // GetByHeight performs a request for the ExtendedHeader at the given @@ -135,7 +161,17 @@ func (ex *Exchange) performRequest( //nolint:gosec // G404: Use of weak random number generator index := rand.Intn(len(ex.trustedPeers)) - stream, err := ex.host.NewStream(ctx, ex.trustedPeers[index], exchangeProtocolID) + return request(ctx, ex.trustedPeers[index], ex.host, req) +} + +// request sends the ExtendedHeaderRequest to a remote peer. +func request( + ctx context.Context, + to peer.ID, + host host.Host, + req *p2p_pb.ExtendedHeaderRequest, +) ([]*header.ExtendedHeader, error) { + stream, err := host.NewStream(ctx, to, exchangeProtocolID) if err != nil { return nil, err } @@ -172,9 +208,41 @@ func (ex *Exchange) performRequest( headers[i] = header } + if err = stream.Close(); err != nil { + log.Errorw("closing stream", "err", err) + } // ensure at least one header was retrieved if len(headers) == 0 { return nil, header.ErrNotFound } - return headers, stream.Close() + return headers, nil +} + +// bestHead chooses ExtendedHeader that matches the conditions: +// * should have max height among received; +// * should be received at least from 2 peers; +// If neither condition is met, then latest ExtendedHeader will be returned (header of the highest height). +func bestHead(result []*header.ExtendedHeader) (*header.ExtendedHeader, error) { + if len(result) == 0 { + return nil, header.ErrNotFound + } + counter := make(map[string]int) + // go through all of ExtendedHeaders and count the number of headers with a specific hash + for _, res := range result { + counter[res.Hash().String()]++ + } + // sort results in a decreasing order + sort.Slice(result, func(i, j int) bool { + return result[i].Height > result[j].Height + }) + + // try to find ExtendedHeader with the maximum height that was received at least from 2 peers + for _, res := range result { + if counter[res.Hash().String()] >= minResponses { + return res, nil + } + } + log.Debug("could not find latest header received from at least two peers, returning header with the max height") + // otherwise return header with the max height + return result[0], nil } diff --git a/header/p2p/exchange_test.go b/header/p2p/exchange_test.go index a2c28d6194..af07e6ac12 100644 --- a/header/p2p/exchange_test.go +++ b/header/p2p/exchange_test.go @@ -96,6 +96,71 @@ func TestExchange_RequestByHash(t *testing.T) { assert.Equal(t, store.headers[reqHeight].Hash(), eh.Hash()) } +func Test_bestHead(t *testing.T) { + gen := func() []*header.ExtendedHeader { + suite := header.NewTestSuite(t, 3) + res := make([]*header.ExtendedHeader, 0) + for i := 0; i < 3; i++ { + res = append(res, suite.GenExtendedHeader()) + } + return res + } + testCases := []struct { + precondition func() []*header.ExtendedHeader + expectedHeight int64 + }{ + /* + Height -> Amount + headerHeight[0]=1 -> 1 + headerHeight[1]=2 -> 1 + headerHeight[2]=3 -> 1 + result -> headerHeight[2] + */ + { + precondition: gen, + expectedHeight: 3, + }, + /* + Height -> Amount + headerHeight[0]=1 -> 2 + headerHeight[1]=2 -> 1 + headerHeight[2]=3 -> 1 + result -> headerHeight[0] + */ + { + precondition: func() []*header.ExtendedHeader { + res := gen() + res = append(res, res[0]) + return res + }, + expectedHeight: 1, + }, + /* + Height -> Amount + headerHeight[0]=1 -> 3 + headerHeight[1]=2 -> 2 + headerHeight[2]=3 -> 1 + result -> headerHeight[1] + */ + { + precondition: func() []*header.ExtendedHeader { + res := gen() + res = append(res, res[0]) + res = append(res, res[0]) + res = append(res, res[1]) + return res + }, + expectedHeight: 2, + }, + } + for _, tt := range testCases { + res := tt.precondition() + header, err := bestHead(res) + require.NoError(t, err) + require.True(t, header.Height == tt.expectedHeight) + } +} + func createMocknet(t *testing.T) (libhost.Host, libhost.Host) { net, err := mocknet.FullMeshConnected(2) require.NoError(t, err)