Skip to content

Commit

Permalink
add safety checks
Browse files Browse the repository at this point in the history
  • Loading branch information
rauljordan committed Sep 18, 2024
1 parent 5915a85 commit 220ff27
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions state-commitments/optimized/inclusion_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"math"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)

// Computes the Merkle proof for a leaf at a given index.
Expand All @@ -14,12 +13,13 @@ func (h *HistoryCommitter) computeMerkleProof(leafIndex uint64, leaves []common.
if len(leaves) == 0 {
return nil, nil
}
// TODO: Add all other safety conditions.
if virtual == 0 {
return nil, errors.New("virtual size must be greater than 0")
if leafIndex >= uint64(len(leaves)) {
return nil, errors.New("leaf index out of bounds")
}
if virtual < uint64(len(leaves)) {
return nil, errors.New("virtual size must be greater than or equal to the number of leaves")
}
numRealLeaves := uint64(len(leaves))
// Last leaf used for padding.
lastLeaf := leaves[numRealLeaves-1]
depth := int(math.Ceil(math.Log2(float64(virtual))))

Expand All @@ -28,11 +28,13 @@ func (h *HistoryCommitter) computeMerkleProof(leafIndex uint64, leaves []common.
if err != nil {
return nil, err
}

var proof []common.Hash
for level := 0; level < depth; level++ {
nodeIndex := leafIndex >> level
siblingHash, exists := computeSiblingHash(nodeIndex, uint64(level), numRealLeaves, virtual, leaves, virtualHashes)
siblingHash, exists, err := h.computeSiblingHash(nodeIndex, uint64(level), numRealLeaves, virtual, leaves, virtualHashes)
if err != nil {
return nil, err
}
if exists {
proof = append(proof, siblingHash)
}
Expand All @@ -41,46 +43,56 @@ func (h *HistoryCommitter) computeMerkleProof(leafIndex uint64, leaves []common.
}

// Computes the hash of a node's sibling at a given index and level.
func computeSiblingHash(
func (h *HistoryCommitter) computeSiblingHash(
nodeIndex uint64,
level uint64,
N uint64,
virtual uint64,
hLeaves []common.Hash,
hNHashes []common.Hash,
) (common.Hash, bool) {
) (common.Hash, bool, error) {
siblingIndex := nodeIndex ^ 1
numNodes := (virtual + (1 << level) - 1) / (1 << level) // Equivalent to ceil(virtual / (2 ** level))
// Essentially ceil(virtual / (2 ** level))
numNodes := (virtual + (1 << level) - 1) / (1 << level)
if siblingIndex >= numNodes {
// No sibling exists, so use a zero hash.
return common.Hash{}, false
return common.Hash{}, false, nil
} else if siblingIndex >= paddingStartIndexAtLevel(N, level) {
return hNHashes[level], true
return hNHashes[level], true, nil
} else {
siblingHash := computeNodeHash(siblingIndex, level, N, hLeaves, hNHashes)
return siblingHash, true
siblingHash, err := h.computeNodeHash(siblingIndex, level, N, hLeaves, hNHashes)
if err != nil {
return emptyHash, false, err
}
return siblingHash, true, nil
}
}

// Recursively computes the hash of a node at a given index and level.
func computeNodeHash(
func (h *HistoryCommitter) computeNodeHash(
nodeIndex uint64, level uint64, numRealLeaves uint64, leaves []common.Hash, virtualHashes []common.Hash,
) common.Hash {
) (common.Hash, error) {
if level == 0 {
if nodeIndex >= numRealLeaves {
// Node is in padding (the virtual segment of the tree).
return virtualHashes[0]
return virtualHashes[0], nil
} else {
return leaves[nodeIndex]
return leaves[nodeIndex], nil
}
} else {
if nodeIndex >= paddingStartIndexAtLevel(numRealLeaves, level) {
return virtualHashes[level]
return virtualHashes[level], nil
} else {
leftChild := computeNodeHash(2*nodeIndex, level-1, numRealLeaves, leaves, virtualHashes)
rightChild := computeNodeHash(2*nodeIndex+1, level-1, numRealLeaves, leaves, virtualHashes)
leftChild, err := h.computeNodeHash(2*nodeIndex, level-1, numRealLeaves, leaves, virtualHashes)
if err != nil {
return emptyHash, err
}
rightChild, err := h.computeNodeHash(2*nodeIndex+1, level-1, numRealLeaves, leaves, virtualHashes)
if err != nil {
return emptyHash, err
}
data := append(leftChild.Bytes(), rightChild.Bytes()...)
return crypto.Keccak256Hash(data)
return h.hash(data)
}
}
}
Expand Down

0 comments on commit 220ff27

Please sign in to comment.