Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(pkg/trie/triedb): implement generics, refactor nibbles, revise iterator #4221

Merged
merged 9 commits into from
Oct 15, 2024
5 changes: 5 additions & 0 deletions internal/primitives/core/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ func (h256 H256) String() string {
return fmt.Sprintf("%v", h256.Bytes())
}

// Length returns the byte length of H256
func (h256 H256) Length() int {
return 32
}

// MarshalSCALE fulfils the SCALE interface for encoding
func (h256 H256) MarshalSCALE() ([]byte, error) {
var arr [32]byte
Expand Down
2 changes: 1 addition & 1 deletion pkg/trie/triedb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ It offers functionalities for writing and reading operations and uses lazy loadi
- **Reads**: Basic functions to get data from the trie.
- **Lazy Loading**: Load data on demand.
- **Caching**: Enhances search performance.
- **Compatibility**: Works with any database implementing the `db.RWDatabase` interface and any cache implementing the `cache.TrieCache` interface.
- **Compatibility**: Works with any database implementing the `db.RWDatabase` interface and any cache implementing the `Cache` interface (wip).
- **Merkle proofs**: Create and verify merkle proofs.
- **Iterator**: Traverse the trie keys in order.

Expand Down
21 changes: 0 additions & 21 deletions pkg/trie/triedb/child_tries.go

This file was deleted.

46 changes: 26 additions & 20 deletions pkg/trie/triedb/codec/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
"fmt"
"io"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/ChainSafe/gossamer/pkg/trie/triedb/hash"
"github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles"
)

var (
Expand All @@ -25,13 +26,11 @@ var (
ErrDecodeStorageValue = errors.New("cannot decode storage value")
)

const hashLength = common.HashLength

// Decode decodes a node from a reader.
// The encoding format is documented in the README.md
// of this package, and specified in the Polkadot spec at
// https://spec.polkadot.network/chap-state#defn-node-header
func Decode(reader io.Reader) (n EncodedNode, err error) {
func Decode[H hash.Hash](reader io.Reader) (n EncodedNode, err error) {
variant, partialKeyLength, err := decodeHeader(reader)
if err != nil {
return nil, fmt.Errorf("decoding header: %w", err)
Expand All @@ -48,13 +47,13 @@ func Decode(reader io.Reader) (n EncodedNode, err error) {

switch variant {
case leafVariant, leafWithHashedValueVariant:
n, err = decodeLeaf(reader, variant, partialKey)
n, err = decodeLeaf[H](reader, variant, partialKey)
if err != nil {
return nil, fmt.Errorf("cannot decode leaf: %w", err)
}
return n, nil
case branchVariant, branchWithValueVariant, branchWithHashedValueVariant:
n, err = decodeBranch(reader, variant, partialKey)
n, err = decodeBranch[H](reader, variant, partialKey)
if err != nil {
return nil, fmt.Errorf("cannot decode branch: %w", err)
}
Expand All @@ -67,7 +66,7 @@ func Decode(reader io.Reader) (n EncodedNode, err error) {

// decodeBranch reads from a reader and decodes to a node branch.
// Note that we are not decoding the children nodes.
func decodeBranch(reader io.Reader, variant variant, partialKey []byte) (
func decodeBranch[H hash.Hash](reader io.Reader, variant variant, partialKey nibbles.Nibbles) (
node Branch, err error) {
node = Branch{
PartialKey: partialKey,
Expand All @@ -91,11 +90,11 @@ func decodeBranch(reader io.Reader, variant variant, partialKey []byte) (

node.Value = InlineValue(valueBytes)
case branchWithHashedValueVariant:
hashedValue, err := decodeHashedValue(reader)
hashedValue, err := decodeHashedValue[H](reader)
if err != nil {
return Branch{}, err
}
node.Value = HashedValue(hashedValue)
node.Value = HashedValue[H]{hashedValue}
default:
// Do nothing, branch without value
}
Expand All @@ -113,31 +112,36 @@ func decodeBranch(reader io.Reader, variant variant, partialKey []byte) (
ErrDecodeChildHash, i, err)
}

if len(hash) < hashLength {
if len(hash) < (*new(H)).Length() {
node.Children[i] = InlineNode(hash)
} else {
node.Children[i] = HashedNode(hash)
var h H
err := scale.Unmarshal(hash, &h)
timwu20 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
panic(err)
}
node.Children[i] = HashedNode[H]{h}
}
}

return node, nil
}

// decodeLeaf reads from a reader and decodes to a leaf node.
func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf, err error) {
func decodeLeaf[H hash.Hash](reader io.Reader, variant variant, partialKey nibbles.Nibbles) (node Leaf, err error) {
node = Leaf{
PartialKey: partialKey,
}

sd := scale.NewDecoder(reader)

if variant == leafWithHashedValueVariant {
hashedValue, err := decodeHashedValue(reader)
hashedValue, err := decodeHashedValue[H](sd)
if err != nil {
return Leaf{}, err
}

node.Value = HashedValue(hashedValue)
node.Value = HashedValue[H]{hashedValue}
return node, nil
}

Expand All @@ -152,15 +156,17 @@ func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf
return node, nil
}

func decodeHashedValue(reader io.Reader) ([]byte, error) {
buffer := make([]byte, hashLength)
func decodeHashedValue[H hash.Hash](reader io.Reader) (hash H, err error) {
buffer := make([]byte, (*new(H)).Length())
n, err := reader.Read(buffer)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err)
return hash, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err)
}
if n < hashLength {
return nil, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, hashLength, n)
if n < (*new(H)).Length() {
return hash, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, (*new(H)).Length(), n)
}

return buffer, nil
h := new(H)
err = scale.Unmarshal(buffer, h)
return *h, err
}
Loading
Loading