Skip to content

Commit

Permalink
Reduce slice memory allocation when generating a proof
Browse files Browse the repository at this point in the history
  • Loading branch information
mininny committed Oct 5, 2024
1 parent f8f596b commit c5a040b
Showing 1 changed file with 104 additions and 89 deletions.
193 changes: 104 additions & 89 deletions rvgo/fast/radix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type RadixNode interface {
// InvalidateNode invalidates the hash cache along the path to the specified address.
InvalidateNode(addr uint64)
// GenerateProof generates the Merkle proof for the given address.
GenerateProof(addr uint64) [][32]byte
GenerateProof(addr uint64, proofs [][32]byte)
// MerkleizeNode computes the Merkle root hash for the node at the given generalized index.
MerkleizeNode(addr, gindex uint64) [32]byte
}
Expand Down Expand Up @@ -84,51 +84,52 @@ func (m *Memory) InvalidateNode(addr uint64) {

// GenerateProof generates the Merkle proof for the given address.
// It collects the necessary sibling hashes along the path to reconstruct the Merkle proof.
func (n *SmallRadixNode[C]) GenerateProof(addr uint64) [][32]byte {
var proofs [][32]byte
func (n *SmallRadixNode[C]) GenerateProof(addr uint64, proofs [][32]byte) {
path := addressToRadixPath(addr, n.Depth, 4)

if n.Children[path] == nil {
// When no child exists at this path, the rest of the proofs are zero hashes.
proofs = zeroHashRange(0, 60-n.Depth-4)
fillZeroHashRange(proofs, 0, 60-n.Depth-4)
} else {
// Recursively generate proofs from the child node.
proofs = (*n.Children[path]).GenerateProof(addr)
(*n.Children[path]).GenerateProof(addr, proofs)
}

// Collect sibling hashes along the path for the proof.
proofIndex := 60 - n.Depth - 4
for idx := path + 1<<4; idx > 1; idx >>= 1 {
sibling := idx ^ 1 // Get the sibling index.
proofs = append(proofs, n.MerkleizeNode(addr>>(64-n.Depth), sibling))
proofs[proofIndex] = n.MerkleizeNode(addr>>(64-n.Depth), sibling)
proofIndex += 1
}

return proofs
}

func (n *LargeRadixNode[C]) GenerateProof(addr uint64) [][32]byte {
var proofs [][32]byte
func (n *LargeRadixNode[C]) GenerateProof(addr uint64, proofs [][32]byte) {
path := addressToRadixPath(addr, n.Depth, 8)

if n.Children[path] == nil {
proofs = zeroHashRange(0, 60-n.Depth-8)
fillZeroHashRange(proofs, 0, 60-n.Depth-8)
} else {
proofs = (*n.Children[path]).GenerateProof(addr)
(*n.Children[path]).GenerateProof(addr, proofs)
}

proofIndex := 60 - n.Depth - 8
for idx := path + 1<<8; idx > 1; idx >>= 1 {
sibling := idx ^ 1
proofs = append(proofs, n.MerkleizeNode(addr>>(64-n.Depth), sibling))
proofs[proofIndex] = n.MerkleizeNode(addr>>(64-n.Depth), sibling)
proofIndex += 1
}
return proofs
}

func (m *Memory) GenerateProof(addr uint64) [][32]byte {
func (m *Memory) GenerateProof(addr uint64, proofs [][32]byte) {
pageIndex := addr >> PageAddrSize

// number of proof for a page is 8
// 0: leaf page data, 7: page's root
if p, ok := m.pages[pageIndex]; ok {
return p.GenerateProof(addr) // Generate proof from the page.
pageProofs := p.GenerateProof(addr) // Generate proof from the page.
copy(proofs[:8], pageProofs)
} else {
return zeroHashRange(0, 8) // Return zero hashes if the page does not exist.
fillZeroHashRange(proofs, 0, 8) // Return zero hashes if the page does not exist.
}
}

Expand All @@ -138,82 +139,86 @@ func (m *Memory) GenerateProof(addr uint64) [][32]byte {
func (n *SmallRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
depth := uint64(bits.Len64(gindex)) // Get the depth of the current gindex.

if depth <= 4 {
hashBit := gindex & 15

if (n.ChildExists & (1 << hashBit)) != 0 {
if (n.HashValid & (1 << hashBit)) != 0 {
// Return the cached hash if valid.
return n.Hashes[gindex]
} else {
left := n.MerkleizeNode(addr, gindex<<1)
right := n.MerkleizeNode(addr, (gindex<<1)|1)

// Hash the pair and cache the result.
r := HashPair(left, right)
n.Hashes[gindex] = r
n.HashValid |= 1 << hashBit
return r
}
} else {
// Return zero hash for non-existent child.
if depth > 5 {
panic("gindex too deep")
}

// Leaf node of the radix trie (17~32)
if depth > 4 {
childIndex := gindex - 1<<4

if n.Children[childIndex] == nil {
// Return zero hash if child does not exist.
return zeroHashes[64-5+1-(depth+n.Depth)]
}
}

if depth > 5 {
panic("gindex too deep")
// Update the partial address by appending the child index bits.
// This accumulates the address as we traverse deeper into the trie.
addr <<= 4
addr |= childIndex
return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
}

childIndex := gindex - 1<<4
// Intermediate node of the radix trie (0~16)
hashBit := gindex & 15

if n.Children[childIndex] == nil {
// Return zero hash if child does not exist.
if (n.ChildExists & (1 << hashBit)) != 0 {
if (n.HashValid & (1 << hashBit)) != 0 {
// Return the cached hash if valid.
return n.Hashes[gindex]
} else {
left := n.MerkleizeNode(addr, gindex<<1)
right := n.MerkleizeNode(addr, (gindex<<1)|1)

// Hash the pair and cache the result.
r := HashPair(left, right)
n.Hashes[gindex] = r
n.HashValid |= 1 << hashBit
return r
}
} else {
// Return zero hash for non-existent child.
return zeroHashes[64-5+1-(depth+n.Depth)]
}

// Update the partial address by appending the child index bits.
// This accumulates the address as we traverse deeper into the trie.
addr <<= 4
addr |= childIndex
return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
}

func (n *LargeRadixNode[C]) MerkleizeNode(addr, gindex uint64) [32]byte {
depth := uint64(bits.Len64(gindex))

if depth <= 8 {
hashIndex := gindex >> 6
hashBit := gindex & 63
if (n.ChildExists[hashIndex] & (1 << hashBit)) != 0 {
if (n.HashValid[hashIndex] & (1 << hashBit)) != 0 {
return n.Hashes[gindex]
} else {
left := n.MerkleizeNode(addr, gindex<<1)
right := n.MerkleizeNode(addr, (gindex<<1)|1)

r := HashPair(left, right)
n.Hashes[gindex] = r
n.HashValid[hashIndex] |= 1 << hashBit
return r
}
} else {
if depth > 9 {
panic("gindex too deep")
}

// Leaf node of the radix trie (2^8~2^16)
if depth > 8 {
childIndex := gindex - 1<<8
if n.Children[int(childIndex)] == nil {
return zeroHashes[64-5+1-(depth+n.Depth)]
}
}

if depth > 16 {
panic("gindex too deep")
addr <<= 8
addr |= childIndex
return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
}

childIndex := gindex - 1<<8
if n.Children[int(childIndex)] == nil {
// Intermediate node of the radix trie (0~2^7)
hashIndex := gindex >> 6
hashBit := gindex & 63
if (n.ChildExists[hashIndex] & (1 << hashBit)) != 0 {
if (n.HashValid[hashIndex] & (1 << hashBit)) != 0 {
return n.Hashes[gindex]
} else {
left := n.MerkleizeNode(addr, gindex<<1)
right := n.MerkleizeNode(addr, (gindex<<1)|1)

r := HashPair(left, right)
n.Hashes[gindex] = r
n.HashValid[hashIndex] |= 1 << hashBit
return r
}
} else {
return zeroHashes[64-5+1-(depth+n.Depth)]
}

addr <<= 8
addr |= childIndex
return (*n.Children[childIndex]).MerkleizeNode(addr, 1)
}

func (m *Memory) MerkleizeNode(addr, gindex uint64) [32]byte {
Expand All @@ -234,21 +239,20 @@ func (m *Memory) MerkleRoot() [32]byte {

// MerkleProof generates the Merkle proof for the specified address in memory.
func (m *Memory) MerkleProof(addr uint64) [ProofLen * 32]byte {
proofs := m.radix.GenerateProof(addr)
proofs := make([][32]byte, 60)
m.radix.GenerateProof(addr, proofs)
return encodeProofs(proofs)
}

// zeroHashRange returns a slice of zero hashes from start to end.
func zeroHashRange(start, end uint64) [][32]byte {
proofs := make([][32]byte, end-start)
func fillZeroHashRange(slice [][32]byte, start, end uint64) {
if start == 0 {
proofs[0] = zeroHashes[0]
slice[0] = zeroHashes[0]
start++
}
for i := start; i < end; i++ {
proofs[i] = zeroHashes[i-1]
slice[i] = zeroHashes[i-1]
}
return proofs
}

// encodeProofs encodes the list of proof hashes into a byte array.
Expand Down Expand Up @@ -293,67 +297,78 @@ func (m *Memory) AllocPage(pageIndex uint64) *CachedPage {

addr := pageIndex << PageAddrSize
branchPaths := m.addressToRadixPaths(addr)
depth := uint64(0)

// Build the radix trie path to the new page, creating nodes as necessary.
// This code is a bit repetitive, but better for the compiler to optimize.
radixLevel1 := m.radix
depth += m.branchFactors[0]
if (*radixLevel1).Children[branchPaths[0]] == nil {
node := &SmallRadixNode[L3]{Depth: 4}
node := &SmallRadixNode[L3]{Depth: depth}
(*radixLevel1).Children[branchPaths[0]] = &node
}
radixLevel1.InvalidateNode(addr)

radixLevel2 := (*radixLevel1).Children[branchPaths[0]]
depth += m.branchFactors[1]
if (*radixLevel2).Children[branchPaths[1]] == nil {
node := &SmallRadixNode[L4]{Depth: 8}
node := &SmallRadixNode[L4]{Depth: depth}
(*radixLevel2).Children[branchPaths[1]] = &node
}
(*radixLevel2).InvalidateNode(addr)

radixLevel3 := (*radixLevel2).Children[branchPaths[1]]
depth += m.branchFactors[2]
if (*radixLevel3).Children[branchPaths[2]] == nil {
node := &SmallRadixNode[L5]{Depth: 12}
node := &SmallRadixNode[L5]{Depth: depth}
(*radixLevel3).Children[branchPaths[2]] = &node
}
(*radixLevel3).InvalidateNode(addr)

radixLevel4 := (*radixLevel3).Children[branchPaths[2]]
depth += m.branchFactors[3]
if (*radixLevel4).Children[branchPaths[3]] == nil {
node := &SmallRadixNode[L6]{Depth: 16}
node := &SmallRadixNode[L6]{Depth: depth}
(*radixLevel4).Children[branchPaths[3]] = &node
}
(*radixLevel4).InvalidateNode(addr)

radixLevel5 := (*radixLevel4).Children[branchPaths[3]]
depth += m.branchFactors[4]
if (*radixLevel5).Children[branchPaths[4]] == nil {
node := &SmallRadixNode[L7]{Depth: 20}
node := &SmallRadixNode[L7]{Depth: depth}
(*radixLevel5).Children[branchPaths[4]] = &node
}
(*radixLevel5).InvalidateNode(addr)

radixLevel6 := (*radixLevel5).Children[branchPaths[4]]
depth += m.branchFactors[5]
if (*radixLevel6).Children[branchPaths[5]] == nil {
node := &SmallRadixNode[L8]{Depth: 24}
node := &SmallRadixNode[L8]{Depth: depth}
(*radixLevel6).Children[branchPaths[5]] = &node
}
(*radixLevel6).InvalidateNode(addr)

radixLevel7 := (*radixLevel6).Children[branchPaths[5]]
depth += m.branchFactors[6]
if (*radixLevel7).Children[branchPaths[6]] == nil {
node := &LargeRadixNode[L9]{Depth: 28}
node := &LargeRadixNode[L9]{Depth: depth}
(*radixLevel7).Children[branchPaths[6]] = &node
}
(*radixLevel7).InvalidateNode(addr)

radixLevel8 := (*radixLevel7).Children[branchPaths[6]]
depth += m.branchFactors[7]
if (*radixLevel8).Children[branchPaths[7]] == nil {
node := &LargeRadixNode[L10]{Depth: 36}
node := &LargeRadixNode[L10]{Depth: depth}
(*radixLevel8).Children[branchPaths[7]] = &node
}
(*radixLevel8).InvalidateNode(addr)

radixLevel9 := (*radixLevel8).Children[branchPaths[7]]
depth += m.branchFactors[8]
if (*radixLevel9).Children[branchPaths[8]] == nil {
node := &LargeRadixNode[L11]{Depth: 44}
node := &LargeRadixNode[L11]{Depth: depth}
(*radixLevel9).Children[branchPaths[8]] = &node
}
(*radixLevel9).InvalidateNode(addr)
Expand Down

0 comments on commit c5a040b

Please sign in to comment.