Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/go_modules/main/github.com/rs/zer…
Browse files Browse the repository at this point in the history
…olog-1.33.0
  • Loading branch information
ac4ch authored Jun 3, 2024
2 parents 5701cbf + e90e707 commit 2eedd30
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 65 deletions.
27 changes: 27 additions & 0 deletions database/repository/header_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,30 @@ func (r *HeaderRepository) GetChainBetweenTwoHashes(low string, high string) ([]
}
return nil, err
}

// GetHeadersStartHeight returns height of the highest header from the list of hashes.
func (r *HeaderRepository) GetHeadersStartHeight(hashtable []string) (int, error) {
sh, err := r.db.GetHeadersStartHeight(hashtable)
if err != nil {
return 0, err
}
return sh, nil
}

// GetHeadersByHeightRange returns headers from db in specified height range.
func (r *HeaderRepository) GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error) {
bh, err := r.db.GetHeadersByHeightRange(from, to)
if err != nil {
return nil, err
}
return dto.ConvertToBlockHeader(bh), nil
}

// GetHeadersStopHeight returns height of hashstop header from db.
func (r *HeaderRepository) GetHeadersStopHeight(hashStop string) (int, error) {
hs, err := r.db.GetHeadersStopHeight(hashStop)
if err != nil {
return 0, err
}
return hs, nil
}
61 changes: 61 additions & 0 deletions database/sql/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
const (
HeadersTableName = "headers"

longestChainState = "LONGEST_CHAIN"

sqlInsertHeader = `
INSERT INTO headers(hash, height, version, merkleroot, nonce, bits, header_state, chainwork, previous_block, timestamp , cumulated_work)
VALUES(:hash, :height, :version, :merkleroot, :nonce, :bits, :header_state, :chainwork, :previous_block, :timestamp, :cumulated_work)
Expand All @@ -34,6 +36,12 @@ const (
WHERE hash = ?
`

sqlHeaderHeightFromHashAndState = `
SELECT height
FROM headers
WHERE hash = ? AND header_state = ?
`

sqlHeaderByHeight = `
SELECT hash, height, version, merkleroot, nonce, bits, chainwork, previous_block, timestamp, header_state, cumulated_work
FROM headers
Expand Down Expand Up @@ -160,6 +168,20 @@ const (
sqlTipOfChainHeight = `SELECT MAX(height) FROM headers WHERE header_state = 'LONGEST_CHAIN'`

sqlVerifyHash = `SELECT hash FROM headers WHERE merkleroot = $1 AND height = $2 AND header_state = 'LONGEST_CHAIN'`

sqlGetHeadersHeight = `
SELECT COALESCE(MAX(height), 0) AS startHeight
FROM headers
WHERE header_state = 'LONGEST_CHAIN'
AND hash IN (?)
`

sqlHeaderByHeightRangeLongestChain = `
SELECT
hash, height, version, merkleroot, nonce, bits, chainwork, previous_block, timestamp, header_state, cumulated_work
FROM headers
WHERE height BETWEEN ? AND ? AND header_state = 'LONGEST_CHAIN';
`
)

// HeadersDb represents a database connection and map of related sql queries.
Expand Down Expand Up @@ -396,6 +418,45 @@ func (h *HeadersDb) GetMerkleRootsConfirmations(
return confirmations, nil
}

// GetHashStartHeight returns hash and height from db with given locators.
func (h *HeadersDb) GetHeadersStartHeight(hashTable []string) (int, error) {
query, args, err := sqlx.In(sqlGetHeadersHeight, hashTable)
if err != nil {
h.log.Error().Err(err).Msg("Error while constructing query")
return 0, err
}

var heightStart int
if err := h.db.Get(&heightStart, h.db.Rebind(query), args...); err != nil {
h.log.Error().Err(err).Msg("Failed to get headers by locators")
return 0, err
}

return heightStart, nil
}

// GetHeadersStopHeight will return header from db with given hash.
func (h *HeadersDb) GetHeadersStopHeight(hashStop string) (int, error) {
var dbHashStopHeight int
if err := h.db.Get(&dbHashStopHeight, h.db.Rebind(sqlHeaderHeightFromHashAndState), hashStop, longestChainState); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return 0, errors.Wrapf(err, "failed to get stophash %s", hashStop)
}

return dbHashStopHeight, nil
}

// GetHeadersByHeightRange returns headers from db in specified height range.
func (h *HeadersDb) GetHeadersByHeightRange(from int, to int) ([]*dto.DbBlockHeader, error) {
var listOfHeaders []*dto.DbBlockHeader
if err := h.db.Select(&listOfHeaders, h.db.Rebind(sqlHeaderByHeightRangeLongestChain), from, to); err != nil {
return nil, errors.Wrapf(err, "failed to get headers using given range from: %d to: %d", from, to)
}
return listOfHeaders, nil
}

func (h *HeadersDb) getChainTipHeight() (int32, error) {
var tipHeight int32
err := h.db.Get(&tipHeight, sqlTipOfChainHeight)
Expand Down
7 changes: 5 additions & 2 deletions internal/tests/fixtures/blockheader_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ const (
DefaultChainWork = 4295032833
)

// HashOf returns chainhash.Hash representation of string, ignoring errors.
// HashOf returns chainhash.Hash representation of string, panic when error occurs.
func HashOf(s string) *chainhash.Hash {
h, _ := chainhash.NewHashFromStr(s)
h, err := chainhash.NewHashFromStr(s)
if err != nil {
panic("Invalid hash string")
}
return h
}

Expand Down
33 changes: 33 additions & 0 deletions internal/tests/testrepository/header_testrepository.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,39 @@ func (r *HeaderTestRepository) GetMerkleRootsConfirmations(
return mrcfs, nil
}

func (r *HeaderTestRepository) GetHeadersStartHeight(hashtable []string) (int, error) {
for i := len(*r.db) - 1; i >= 0; i-- {
header := (*r.db)[i]
for j := len(hashtable) - 1; j >= 0; j-- {
if header.Hash.String() == hashtable[j] && header.State == domains.LongestChain {
return int(header.Height), nil
}
}
}
return 0, nil
}

func (r *HeaderTestRepository) GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error) {
filteredHeaders := make([]*domains.BlockHeader, 0)
for _, header := range *r.db {
if header.Height >= int32(from) && header.Height <= int32(to) {
headerCopy := header
filteredHeaders = append(filteredHeaders, &headerCopy)
}
}
return filteredHeaders, nil
}

func (r *HeaderTestRepository) GetHeadersStopHeight(hashStop string) (int, error) {
for i := len(*r.db) - 1; i >= 0; i-- {
header := (*r.db)[i]
if header.Hash.String() == hashStop {
return int(header.Height), nil
}
}
return 0, errors.New("could not find stop height")
}

// FillWithLongestChain fills the test header repository
// with 4 additional blocks to create a longest chain.
func (r *HeaderTestRepository) FillWithLongestChain() {
Expand Down
20 changes: 20 additions & 0 deletions internal/transports/p2p/peer/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ func (p *Peer) readMsgHandler() {
p.handleHeadersMsg(msg)
case *wire.MsgInv:
p.handleInvMsg(msg)
case *wire.MsgGetHeaders:
p.handleGetHeadersMsg(msg)
default:
p.log.Info().Msgf("received msg of type: %T", msg)
}
Expand Down Expand Up @@ -555,6 +557,24 @@ func (p *Peer) handleHeadersMsg(msg *wire.MsgHeaders) {
}
}

func (p *Peer) handleGetHeadersMsg(msg *wire.MsgGetHeaders) {
p.log.Info().Msgf("received getheaders msg from peer %s", p)
if !p.syncedCheckpoints {
p.log.Info().Msgf("we are still syncing, ignoring getHeaders msg from peer %s", p)
return
}

bh, err := p.headersService.LocateHeadersGetHeaders(msg.BlockLocatorHashes, &msg.HashStop)
if err != nil {
p.log.Error().Msgf("error locating headers for getheaders msg from peer %s, reason: %v", p, err)
return
}

msgHeaders := wire.NewMsgHeaders()
msgHeaders.Headers = bh
p.queueMessage(msgHeaders)
}

func (p *Peer) switchToSendHeadersMode() {
if !p.sendHeadersMode && p.protocolVersion >= wire.SendHeadersVersion {
p.log.Info().Msgf("switching to send headers mode - requesting peer %s to send us headers directly instead of inv msg", p)
Expand Down
3 changes: 3 additions & 0 deletions repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ type Headers interface {
GetAllTips() ([]*domains.BlockHeader, error)
GetAncestorOnHeight(hash string, height int32) (*domains.BlockHeader, error)
GetChainBetweenTwoHashes(low string, high string) ([]*domains.BlockHeader, error)
GetHeadersStartHeight(hashtable []string) (int, error)
GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error)
GetHeadersStopHeight(hashStop string) (int, error)
}

// Tokens is a interface which represents methods performed on tokens table in defined storage.
Expand Down
130 changes: 67 additions & 63 deletions service/header_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,86 +291,90 @@ func (hs *HeaderService) GetMerkleRootsConfirmations(
return hs.repo.Headers.GetMerkleRootsConfirmations(request, hs.merkleCfg.MaxBlockHeightExcess)
}

// LocateHeaders fetches headers for a number of blocks after the most recent known block
// in the best chain, based on the provided block locator and stop hash, and defaults to the
// genesis block if the locator is unknown.
func (hs *HeaderService) LocateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash) []wire.BlockHeader {
headers := hs.locateHeaders(locator, hashStop, wire.MaxBlockHeadersPerMsg)
return headers
// LocateHeadersGetHeaders returns headers with given hashes.
func (hs *HeaderService) LocateHeadersGetHeaders(locators []*chainhash.Hash, hashstop *chainhash.Hash) ([]*wire.BlockHeader, error) {
headers, err := hs.locateHeadersGetHeaders(locators, hashstop)
if err != nil {
return nil, err
}
return headers, nil
}

func (hs *HeaderService) locateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash, maxHeaders uint32) []wire.BlockHeader {
// Find the node after the first known block in the locator and the
// total number of nodes after it needed while respecting the stop hash
// and max entries.
node, total := hs.locateInventory(locator, hashStop, maxHeaders)
if total == 0 {
return nil
func (hs *HeaderService) locateHeadersGetHeaders(locators []*chainhash.Hash, hashstop *chainhash.Hash) ([]*wire.BlockHeader, error) {

if len(locators) == 0 {
return nil, errors.New("no locators provided")
}

// Populate and return the found headers.
headers := make([]wire.BlockHeader, 0, total)
for i := uint32(0); i < total; i++ {
header := wire.BlockHeader{
Version: node.Version,
PrevBlock: node.PreviousBlock,
MerkleRoot: node.MerkleRoot,
Timestamp: node.Timestamp,
Bits: node.Bits,
Nonce: node.Nonce,
}
headers = append(headers, header)
node = hs.nodeByHeight(node.Height + 1)
hashes := make([]string, len(locators))
for i, v := range locators {
hashes[i] = v.String()
}
return headers
}

func (hs *HeaderService) locateInventory(locator domains.BlockLocator, hashStop *chainhash.Hash, maxEntries uint32) (*domains.BlockHeader, uint32) {
// There are no block locators so a specific block is being requested
// as identified by the stop hash.
stopNode, _ := hs.GetHeaderByHash(hashStop.String())
if len(locator) == 0 {
if stopNode == nil {
// No blocks with the stop hash were found so there is
// nothing to do.
return nil, 0
startHeight, err := hs.repo.Headers.GetHeadersStartHeight(hashes)
if err != nil {
return nil, fmt.Errorf("error getting headers of locators: %v", err)
}
var stopHeight int
if hashstop.IsEqual(&chainhash.Hash{}) {
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
} else {
stopHeight, err = hs.repo.Headers.GetHeadersStopHeight(hashstop.String())
if err != nil {
return nil, fmt.Errorf("error getting hashstop height: %v", err)
}
return stopNode, 1
}

// Find the most recent locator block hash in the main chain. In the
// case none of the hashes in the locator are in the main chain, fall
// back to the genesis block.
startNode, _ := hs.repo.Headers.GetHeaderByHeight(0)
for _, hash := range locator {
node, _ := hs.GetHeaderByHash(hash.String())
if node != nil && hs.Contains(node) {
startNode = node
break
}
if stopHeight == 0 {
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
}

// Start at the block after the most recently known block. When there
// is no next block it means the most recently known block is the tip of
// the best chain, so there is nothing more to do.
next := hs.Next(startNode)
if next == nil {
return nil, 0
if stopHeight <= startHeight {
return nil, errors.New("hashStop is lower than first valid height")
}
startNode = next

// Calculate how many entries are needed.
total := uint32((hs.GetTipHeight() - startNode.Height) + 1)
if stopNode != nil && hs.Contains(stopNode) &&
stopNode.Height >= startNode.Height {
// Check if peer requested number of headers is higher than the maximum number of headers per message
if wire.MaxCFHeadersPerMsg < stopHeight-startHeight {
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
}

total = uint32((stopNode.Height - startNode.Height) + 1)
dbHeaders, err := hs.repo.Headers.GetHeadersByHeightRange(startHeight+1, stopHeight)
if err != nil {
return nil, fmt.Errorf("error getting headers between heights: %v", err)
}

headers := make([]*wire.BlockHeader, 0, len(dbHeaders))
for _, dbHeader := range dbHeaders {
header := &wire.BlockHeader{
Version: dbHeader.Version,
PrevBlock: dbHeader.PreviousBlock,
MerkleRoot: dbHeader.MerkleRoot,
Timestamp: dbHeader.Timestamp,
Bits: dbHeader.Bits,
Nonce: dbHeader.Nonce,
}
headers = append(headers, header)
}
if total > maxEntries {
total = maxEntries

return headers, nil
}

// LocateHeaders fetches headers for a number of blocks after the most recent known block
// in the best chain, based on the provided block locator and stop hash, and defaults to the
// genesis block if the locator is unknown.
func (hs *HeaderService) LocateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash) []wire.BlockHeader {
headers, err := hs.locateHeadersGetHeaders(locator, hashStop)
if err != nil {
hs.log.Error().Msg(err.Error())
return nil
}

result := make([]wire.BlockHeader, 0, len(headers))
for _, header := range headers {
result = append(result, *header)
}

return startNode, total
return result
}

// Contains checks if given header is stored in db.
Expand Down
Loading

0 comments on commit 2eedd30

Please sign in to comment.