diff --git a/p2p/peering.go b/p2p/peering.go index dbe6d209e..7e1bec982 100644 --- a/p2p/peering.go +++ b/p2p/peering.go @@ -114,9 +114,10 @@ type RemotePeer struct { // headers message management. Headers can either be fetched synchronously // or used to push block notifications with sendheaders. - requestedHeaders chan<- *wire.MsgHeaders // non-nil result chan when synchronous getheaders in process - sendheaders bool // whether a sendheaders message was sent - requestedHeadersMu sync.Mutex + requestedHeaders chan<- *wire.MsgHeaders // non-nil result chan when synchronous getheaders in process + requestedHeadersLoc []*chainhash.Hash // non-nil when requested headers with getheaders + sendheaders bool // whether a sendheaders message was sent + requestedHeadersMu sync.Mutex // init state message management. requestedInitState chan<- *wire.MsgInitState // non-nil result chan when synchronous getinitstate in process @@ -1097,7 +1098,7 @@ func (rp *RemotePeer) receivedCFilterV2(ctx context.Context, msg *wire.MsgCFilte } } -func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheaders, newRequest bool) { +func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders, loc []*chainhash.Hash) (sendheaders, newRequest bool) { rp.requestedHeadersMu.Lock() if rp.sendheaders { rp.requestedHeadersMu.Unlock() @@ -1108,6 +1109,7 @@ func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheader return false, false } rp.requestedHeaders = c + rp.requestedHeadersLoc = loc rp.requestedHeadersMu.Unlock() return false, true } @@ -1115,18 +1117,67 @@ func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheader func (rp *RemotePeer) deleteRequestedHeaders() { rp.requestedHeadersMu.Lock() rp.requestedHeaders = nil + rp.requestedHeadersLoc = nil rp.requestedHeadersMu.Unlock() } func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) { const opf = "remotepeer(%v).receivedHeaders" rp.requestedHeadersMu.Lock() + + // Ensure the wallet requested headers from this peer. + if !rp.sendheaders && rp.requestedHeaders == nil { + op := errors.Opf(opf, rp.raddr) + err := errors.E(op, errors.Protocol, "received unrequested headers") + rp.Disconnect(err) + rp.requestedHeadersMu.Unlock() + return + } + + // Ensure the remote peer sent as many headers as it could. It can only + // send fewer than 2k headers when the last one is >= their advertised + // height (their tip height). This handles cases where a peer might + // drip headers instead of sending a full batch. + tooFewHeaders := len(msg.Headers) > 0 && + len(msg.Headers) < wire.MaxBlockHeadersPerMsg && + msg.Headers[len(msg.Headers)-1].Height < uint32(rp.initHeight) + if tooFewHeaders { + op := errors.Opf(opf, rp.raddr) + err := errors.E(op, errors.Protocol, "peer sent too few headers") + rp.Disconnect(err) + rp.requestedHeadersMu.Unlock() + return + } + + // The parent of the first header (if there is one) MUST be one of the + // block locators we used to request headers from the peer when this + // is a response to a getheaders request. + if len(msg.Headers) > 0 && rp.requestedHeadersLoc != nil { + wantParent := msg.Headers[0].PrevBlock + contains := false + for _, loc := range rp.requestedHeadersLoc { + if *loc == wantParent { + contains = true + break + } + } + if !contains { + op := errors.Opf(opf, rp.raddr) + err := errors.E(op, errors.Protocol, + "peer sent headers that do not connect "+ + "to block locators") + rp.Disconnect(err) + rp.requestedHeadersMu.Unlock() + return + } + } + + // Sanity check the headers connect to each other in sequence. var prevHash chainhash.Hash var prevHeight uint32 for i, h := range msg.Headers { hash := h.BlockHash() - // Sanity check the headers connect to each other in sequence. if i > 0 && (!prevHash.IsEqual(&h.PrevBlock) || h.Height != prevHeight+1) { op := errors.Opf(opf, rp.raddr) err := errors.E(op, errors.Protocol, "received out-of-sequence headers") @@ -1139,6 +1190,7 @@ func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) prevHeight = h.Height } + // Track the height of the last received header. if prevHeight > 0 { rp.lastHeightMu.Lock() if int32(prevHeight) > rp.lastHeight { @@ -1147,6 +1199,7 @@ func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) rp.lastHeightMu.Unlock() } + // Async headers. if rp.sendheaders { rp.requestedHeadersMu.Unlock() select { @@ -1155,15 +1208,11 @@ func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) } return } - if rp.requestedHeaders == nil { - op := errors.Opf(opf, rp.raddr) - err := errors.E(op, errors.Protocol, "received unrequested headers") - rp.Disconnect(err) - rp.requestedHeadersMu.Unlock() - return - } + + // Headers as a response to getheaders. c := rp.requestedHeaders rp.requestedHeaders = nil + rp.requestedHeadersLoc = nil rp.requestedHeadersMu.Unlock() select { case <-ctx.Done(): @@ -1966,7 +2015,7 @@ func (rp *RemotePeer) Headers(ctx context.Context, blockLocators []*chainhash.Ha HashStop: *hashStop, } c := make(chan *wire.MsgHeaders, 1) - sendheaders, newRequest := rp.addRequestedHeaders(c) + sendheaders, newRequest := rp.addRequestedHeaders(c, blockLocators) if sendheaders { op := errors.Opf(opf, rp.raddr) return nil, errors.E(op, errors.Invalid, "synchronous getheaders after sendheaders is unsupported") @@ -1998,28 +2047,6 @@ func (rp *RemotePeer) Headers(ctx context.Context, blockLocators []*chainhash.Ha case m := <-c: stalled.Stop() - // The parent of the first header (if there is one) MUST - // be one of the block locators we used to request - // headers from the peer. - if len(m.Headers) > 0 { - wantParent := m.Headers[0].PrevBlock - contains := false - for _, loc := range blockLocators { - if *loc == wantParent { - contains = true - break - } - } - if !contains { - op := errors.Opf(opf, rp.raddr) - err := errors.E(op, errors.Protocol, - "peer sent headers that do not connect "+ - "to block locators") - rp.Disconnect(err) - return nil, err - } - } - return m.Headers, nil } }