diff --git a/internal/primitives/core/hash/hash.go b/internal/primitives/core/hash/hash.go index e9235736f0..a7dc1e7f0d 100644 --- a/internal/primitives/core/hash/hash.go +++ b/internal/primitives/core/hash/hash.go @@ -25,6 +25,11 @@ func (h256 H256) String() string { return fmt.Sprintf("%v", h256.Bytes()) } +// Length returns the byte length of H256 +func (h256 H256) Length() int { + return 32 +} + // MarshalSCALE fulfils the SCALE interface for encoding func (h256 H256) MarshalSCALE() ([]byte, error) { var arr [32]byte diff --git a/pkg/trie/triedb/README.md b/pkg/trie/triedb/README.md index 682c810d3d..391fbddda7 100644 --- a/pkg/trie/triedb/README.md +++ b/pkg/trie/triedb/README.md @@ -10,7 +10,7 @@ It offers functionalities for writing and reading operations and uses lazy loadi - **Reads**: Basic functions to get data from the trie. - **Lazy Loading**: Load data on demand. - **Caching**: Enhances search performance. -- **Compatibility**: Works with any database implementing the `db.RWDatabase` interface and any cache implementing the `cache.TrieCache` interface. +- **Compatibility**: Works with any database implementing the `db.RWDatabase` interface and any cache implementing the `Cache` interface (wip). - **Merkle proofs**: Create and verify merkle proofs. - **Iterator**: Traverse the trie keys in order. 
diff --git a/pkg/trie/triedb/child_tries.go b/pkg/trie/triedb/child_tries.go deleted file mode 100644 index aafe7bc72d..0000000000 --- a/pkg/trie/triedb/child_tries.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package triedb - -import ( - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/trie" -) - -func (t *TrieDB) GetChild(keyToChild []byte) (trie.Trie, error) { - panic("not implemented yet") -} - -func (t *TrieDB) GetFromChild(keyToChild, key []byte) ([]byte, error) { - panic("not implemented yet") -} - -func (t *TrieDB) GetChildTries() map[common.Hash]trie.Trie { - panic("not implemented yet") -} diff --git a/pkg/trie/triedb/codec/decode.go b/pkg/trie/triedb/codec/decode.go index af9d477802..92ed738161 100644 --- a/pkg/trie/triedb/codec/decode.go +++ b/pkg/trie/triedb/codec/decode.go @@ -9,8 +9,9 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) var ( @@ -25,13 +26,11 @@ var ( ErrDecodeStorageValue = errors.New("cannot decode storage value") ) -const hashLength = common.HashLength - // Decode decodes a node from a reader. 
// The encoding format is documented in the README.md // of this package, and specified in the Polkadot spec at // https://spec.polkadot.network/chap-state#defn-node-header -func Decode(reader io.Reader) (n EncodedNode, err error) { +func Decode[H hash.Hash](reader io.Reader) (n EncodedNode, err error) { variant, partialKeyLength, err := decodeHeader(reader) if err != nil { return nil, fmt.Errorf("decoding header: %w", err) @@ -48,13 +47,13 @@ func Decode(reader io.Reader) (n EncodedNode, err error) { switch variant { case leafVariant, leafWithHashedValueVariant: - n, err = decodeLeaf(reader, variant, partialKey) + n, err = decodeLeaf[H](reader, variant, partialKey) if err != nil { return nil, fmt.Errorf("cannot decode leaf: %w", err) } return n, nil case branchVariant, branchWithValueVariant, branchWithHashedValueVariant: - n, err = decodeBranch(reader, variant, partialKey) + n, err = decodeBranch[H](reader, variant, partialKey) if err != nil { return nil, fmt.Errorf("cannot decode branch: %w", err) } @@ -67,7 +66,7 @@ func Decode(reader io.Reader) (n EncodedNode, err error) { // decodeBranch reads from a reader and decodes to a node branch. // Note that we are not decoding the children nodes. 
-func decodeBranch(reader io.Reader, variant variant, partialKey []byte) ( +func decodeBranch[H hash.Hash](reader io.Reader, variant variant, partialKey nibbles.Nibbles) ( node Branch, err error) { node = Branch{ PartialKey: partialKey, @@ -91,11 +90,11 @@ func decodeBranch(reader io.Reader, variant variant, partialKey []byte) ( node.Value = InlineValue(valueBytes) case branchWithHashedValueVariant: - hashedValue, err := decodeHashedValue(reader) + hashedValue, err := decodeHashedValue[H](reader) if err != nil { return Branch{}, err } - node.Value = HashedValue(hashedValue) + node.Value = HashedValue[H]{hashedValue} default: // Do nothing, branch without value } @@ -113,10 +112,15 @@ func decodeBranch(reader io.Reader, variant variant, partialKey []byte) ( ErrDecodeChildHash, i, err) } - if len(hash) < hashLength { + if len(hash) < (*new(H)).Length() { node.Children[i] = InlineNode(hash) } else { - node.Children[i] = HashedNode(hash) + var h H + err := scale.Unmarshal(hash, &h) + if err != nil { + panic(err) + } + node.Children[i] = HashedNode[H]{h} } } @@ -124,7 +128,7 @@ func decodeBranch(reader io.Reader, variant variant, partialKey []byte) ( } // decodeLeaf reads from a reader and decodes to a leaf node. 
-func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf, err error) { +func decodeLeaf[H hash.Hash](reader io.Reader, variant variant, partialKey nibbles.Nibbles) (node Leaf, err error) { node = Leaf{ PartialKey: partialKey, } @@ -132,12 +136,12 @@ func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf sd := scale.NewDecoder(reader) if variant == leafWithHashedValueVariant { - hashedValue, err := decodeHashedValue(reader) + hashedValue, err := decodeHashedValue[H](sd) if err != nil { return Leaf{}, err } - node.Value = HashedValue(hashedValue) + node.Value = HashedValue[H]{hashedValue} return node, nil } @@ -152,15 +156,17 @@ func decodeLeaf(reader io.Reader, variant variant, partialKey []byte) (node Leaf return node, nil } -func decodeHashedValue(reader io.Reader) ([]byte, error) { - buffer := make([]byte, hashLength) +func decodeHashedValue[H hash.Hash](reader io.Reader) (hash H, err error) { + buffer := make([]byte, (*new(H)).Length()) n, err := reader.Read(buffer) if err != nil { - return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + return hash, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) } - if n < hashLength { - return nil, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, hashLength, n) + if n < (*new(H)).Length() { + return hash, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, (*new(H)).Length(), n) } - return buffer, nil + h := new(H) + err = scale.Unmarshal(buffer, h) + return *h, err } diff --git a/pkg/trie/triedb/codec/decode_test.go b/pkg/trie/triedb/codec/decode_test.go index 21c025a219..eacb09ed9c 100644 --- a/pkg/trie/triedb/codec/decode_test.go +++ b/pkg/trie/triedb/codec/decode_test.go @@ -8,8 +8,10 @@ import ( "io" "testing" - "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/scale" + 
"github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,8 +29,7 @@ func scaleEncodeByteSlice(t *testing.T, b []byte) (encoded []byte) { func Test_Decode(t *testing.T) { t.Parallel() - hashedValue, err := common.Blake2bHash([]byte("test")) - assert.NoError(t, err) + hashedValue := runtime.BlakeTwo256{}.Hash([]byte("test")) testCases := map[string]struct { reader io.Reader @@ -66,7 +67,7 @@ func Test_Decode(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3), }, nil)), n: Leaf{ - PartialKey: []byte{9}, + PartialKey: nibbles.NewNibbles([]byte{9}, 1), Value: InlineValue([]byte{1, 2, 3}), }, }, @@ -86,18 +87,18 @@ func Test_Decode(t *testing.T) { {0b0000_0000, 0b0000_0000}, // no children bitmap }, nil)), n: Branch{ - PartialKey: []byte{9}, + PartialKey: nibbles.NewNibbles([]byte{0x09}, 1), }, }, "leaf_with_hashed_value_success": { reader: bytes.NewReader(bytes.Join([][]byte{ {leafWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data - hashedValue.ToBytes(), + hashedValue.Bytes(), }, nil)), n: Leaf{ - PartialKey: []byte{9}, - Value: HashedValue(hashedValue), + PartialKey: nibbles.NewNibbles([]byte{9}, 1), + Value: HashedValue[hash.H256]{hashedValue}, }, }, "leaf_with_hashed_value_fail_too_short": { @@ -114,11 +115,11 @@ func Test_Decode(t *testing.T) { {branchWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data {0b0000_0000, 0b0000_0000}, // no children bitmap - hashedValue.ToBytes(), + hashedValue.Bytes(), }, nil)), n: Branch{ - PartialKey: []byte{9}, - Value: HashedValue(hashedValue), + PartialKey: nibbles.NewNibbles([]byte{9}, 1), + Value: HashedValue[hash.H256]{hashedValue}, }, }, "branch_with_hashed_value_fail_too_short": { @@ -138,7 +139,7 @@ func Test_Decode(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - n, err := Decode(testCase.reader) + n, err := Decode[hash.H256](testCase.reader) assert.ErrorIs(t, err, testCase.errWrapped) 
if err != nil { @@ -152,13 +153,13 @@ func Test_Decode(t *testing.T) { func Test_decodeBranch(t *testing.T) { t.Parallel() - childHash := common.EmptyHash - scaleEncodedChildHash := scaleEncodeByteSlice(t, childHash.ToBytes()) + var childHash hash.H256 = runtime.BlakeTwo256{}.Hash([]byte{0}) + scaleEncodedChildHash := scaleEncodeByteSlice(t, childHash.Bytes()) testCases := map[string]struct { reader io.Reader nodeVariant variant - partialKey []byte + partialKey nibbles.Nibbles branch Branch errWrapped error errMessage string @@ -177,7 +178,7 @@ func Test_decodeBranch(t *testing.T) { // missing children scale encoded data }), nodeVariant: branchVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), errWrapped: ErrDecodeChildHash, errMessage: "cannot decode child hash: at index 10: decoding uint: reading byte: EOF", }, @@ -189,13 +190,13 @@ func Test_decodeBranch(t *testing.T) { }, nil), ), nodeVariant: branchVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), branch: Branch{ - PartialKey: []byte{1}, + PartialKey: nibbles.NewNibbles([]byte{1}), Children: [ChildrenCapacity]MerkleValue{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - HashedNode(childHash), + HashedNode[hash.H256]{childHash}, }, }, }, @@ -207,7 +208,7 @@ func Test_decodeBranch(t *testing.T) { }, nil), ), nodeVariant: branchWithValueVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", }, @@ -218,14 +219,14 @@ func Test_decodeBranch(t *testing.T) { scaleEncodedChildHash, }, nil)), nodeVariant: branchWithValueVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), branch: Branch{ - PartialKey: []byte{1}, + PartialKey: nibbles.NewNibbles([]byte{1}), Value: InlineValue([]byte{7, 8, 9}), Children: [ChildrenCapacity]MerkleValue{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - 
HashedNode(childHash), + HashedNode[hash.H256]{childHash}, }, }, }, @@ -236,9 +237,9 @@ func Test_decodeBranch(t *testing.T) { {0}, // garbage inlined node }, nil)), nodeVariant: branchWithValueVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), branch: Branch{ - PartialKey: []byte{1}, + PartialKey: nibbles.NewNibbles([]byte{1}), Value: InlineValue([]byte{1}), Children: [ChildrenCapacity]MerkleValue{ InlineNode{}, @@ -269,9 +270,9 @@ func Test_decodeBranch(t *testing.T) { }, nil)), }, nil)), nodeVariant: branchVariant, - partialKey: []byte{1}, + partialKey: nibbles.NewNibbles([]byte{1}), branch: Branch{ - PartialKey: []byte{1}, + PartialKey: nibbles.NewNibbles([]byte{1}), Children: [ChildrenCapacity]MerkleValue{ InlineNode( bytes.Join([][]byte{ @@ -304,7 +305,7 @@ func Test_decodeBranch(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - branch, err := decodeBranch(testCase.reader, + branch, err := decodeBranch[hash.H256](testCase.reader, testCase.nodeVariant, testCase.partialKey) assert.ErrorIs(t, err, testCase.errWrapped) @@ -322,7 +323,7 @@ func Test_decodeLeaf(t *testing.T) { testCases := map[string]struct { reader io.Reader variant variant - partialKey []byte + partialKey nibbles.Nibbles leaf Leaf errWrapped error errMessage string @@ -332,7 +333,7 @@ func Test_decodeLeaf(t *testing.T) { {255, 255}, // bad storage value data }, nil)), variant: leafVariant, - partialKey: []byte{9}, + partialKey: nibbles.NewNibbles([]byte{9}), errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: decoding uint: unknown prefix for compact uint: 255", }, @@ -341,7 +342,7 @@ func Test_decodeLeaf(t *testing.T) { // missing storage value data }), variant: leafVariant, - partialKey: []byte{9}, + partialKey: nibbles.NewNibbles([]byte{9}), errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: decoding uint: reading byte: EOF", }, @@ -350,9 +351,9 @@ func Test_decodeLeaf(t *testing.T) { 
scaleEncodeByteSlice(t, []byte{}), // results to []byte{0} }, nil)), variant: leafVariant, - partialKey: []byte{9}, + partialKey: nibbles.NewNibbles([]byte{9}), leaf: Leaf{ - PartialKey: []byte{9}, + PartialKey: nibbles.NewNibbles([]byte{9}), Value: InlineValue([]byte{}), }, }, @@ -361,9 +362,9 @@ func Test_decodeLeaf(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // storage value data }, nil)), variant: leafVariant, - partialKey: []byte{9}, + partialKey: nibbles.NewNibbles([]byte{9}), leaf: Leaf{ - PartialKey: []byte{9}, + PartialKey: nibbles.NewNibbles([]byte{9}), Value: InlineValue([]byte{1, 2, 3, 4, 5}), }, }, @@ -374,7 +375,7 @@ func Test_decodeLeaf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - leaf, err := decodeLeaf(testCase.reader, testCase.variant, testCase.partialKey) + leaf, err := decodeLeaf[hash.H256](testCase.reader, testCase.variant, testCase.partialKey) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/pkg/trie/triedb/codec/key.go b/pkg/trie/triedb/codec/key.go index 5b0eccb7af..977ab02049 100644 --- a/pkg/trie/triedb/codec/key.go +++ b/pkg/trie/triedb/codec/key.go @@ -8,7 +8,7 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/pkg/trie/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) const maxPartialKeyLength = ^uint16(0) @@ -16,22 +16,23 @@ const maxPartialKeyLength = ^uint16(0) var ErrReaderMismatchCount = errors.New("read unexpected number of bytes from reader") // decodeKey decodes a key from a reader. 
-func decodeKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { +func decodeKey(reader io.Reader, partialKeyLength uint16) (b nibbles.Nibbles, err error) { if partialKeyLength == 0 { - return []byte{}, nil + return b, nil } key := make([]byte, partialKeyLength/2+partialKeyLength%2) n, err := reader.Read(key) if err != nil { - return nil, fmt.Errorf("reading from reader: %w", err) + return b, fmt.Errorf("reading from reader: %w", err) } else if n != len(key) { - return nil, fmt.Errorf("%w: read %d bytes instead of expected %d bytes", + return b, fmt.Errorf("%w: read %d bytes instead of expected %d bytes", ErrReaderMismatchCount, n, len(key)) } // if the partialKeyLength is an odd number means that when parsing the key // to nibbles it will contains a useless 0 in the first index, otherwise // we can use the entire nibbles - return codec.KeyLEToNibbles(key)[partialKeyLength%2:], nil + offset := uint(partialKeyLength) % 2 + return nibbles.NewNibbles(key, offset), nil } diff --git a/pkg/trie/triedb/codec/key_test.go b/pkg/trie/triedb/codec/key_test.go index e24502f3c3..27d4ecc145 100644 --- a/pkg/trie/triedb/codec/key_test.go +++ b/pkg/trie/triedb/codec/key_test.go @@ -8,6 +8,7 @@ import ( "fmt" "testing" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) @@ -57,20 +58,20 @@ func Test_decodeKey(t *testing.T) { testCases := map[string]struct { reads []readCall partialKeyLength uint16 - b []byte + b nibbles.Nibbles errWrapped error errMessage string }{ "zero_key_length": { partialKeyLength: 0, - b: []byte{}, + b: nibbles.NewNibbles(nil), }, "short_key_length": { reads: []readCall{ {buffArgCap: 3, read: []byte{1, 2, 3}, n: 3}, }, partialKeyLength: 5, - b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, + b: nibbles.NewNibbles([]byte{1, 2, 3}, 1), }, "key_read_error": { reads: []readCall{ @@ -94,14 +95,7 @@ func Test_decodeKey(t *testing.T) { {buffArgCap: 35, read: bytes.Repeat([]byte{7}, 35), n: 
35}, // key data }, partialKeyLength: 70, - b: []byte{ - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, - 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, + b: nibbles.NewNibbles(bytes.Repeat([]byte{7}, 35)), }, } diff --git a/pkg/trie/triedb/codec/node.go b/pkg/trie/triedb/codec/node.go index 9efdd51452..5ae58f3ff7 100644 --- a/pkg/trie/triedb/codec/node.go +++ b/pkg/trie/triedb/codec/node.go @@ -7,8 +7,9 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) const ChildrenCapacity = 16 @@ -23,11 +24,11 @@ type ( // InlineNode contains bytes of the encoded node data InlineNode []byte // HashedNode contains a hash used to lookup in db for encoded node data - HashedNode common.Hash + HashedNode[H any] struct{ Hash H } ) -func (InlineNode) IsHashed() bool { return false } -func (HashedNode) IsHashed() bool { return true } +func (InlineNode) IsHashed() bool { return false } +func (HashedNode[H]) IsHashed() bool { return true } // EncodedValue is a helper enum to differentiate between inline and hashed values type EncodedValue interface { @@ -39,7 +40,9 @@ type ( // InlineValue contains bytes for the value in this node InlineValue []byte // HashedValue contains a hash used to lookup in db for real value - HashedValue common.Hash + HashedValue[H hash.Hash] struct { + Hash H + } ) func (InlineValue) IsHashed() bool { return false } @@ -52,13 +55,9 @@ func (v InlineValue) Write(writer io.Writer) error { return nil } -func (HashedValue) IsHashed() bool { return true } -func (v HashedValue) Write(writer io.Writer) error { - if len(v) != 
common.HashLength { - panic("invalid hash length") - } - - _, err := writer.Write(v[:]) +func (HashedValue[H]) IsHashed() bool { return true } +func (v HashedValue[H]) Write(writer io.Writer) error { + _, err := writer.Write(v.Hash.Bytes()) if err != nil { return fmt.Errorf("writing hashed storage value: %w", err) } @@ -67,7 +66,7 @@ func (v HashedValue) Write(writer io.Writer) error { // EncodedNode is the object representation of a encoded node type EncodedNode interface { - GetPartialKey() []byte + GetPartialKey() *nibbles.Nibbles GetValue() EncodedValue } @@ -76,23 +75,23 @@ type ( Empty struct{} // Leaf always contains values Leaf struct { - PartialKey []byte + PartialKey nibbles.Nibbles Value EncodedValue } // Branch could has or not has values Branch struct { - PartialKey []byte + PartialKey nibbles.Nibbles Children [ChildrenCapacity]MerkleValue Value EncodedValue } ) -func (Empty) GetPartialKey() []byte { return nil } -func (Empty) GetValue() EncodedValue { return nil } -func (l Leaf) GetPartialKey() []byte { return l.PartialKey } -func (l Leaf) GetValue() EncodedValue { return l.Value } -func (b Branch) GetPartialKey() []byte { return b.PartialKey } -func (b Branch) GetValue() EncodedValue { return b.Value } +func (Empty) GetPartialKey() *nibbles.Nibbles { return nil } +func (Empty) GetValue() EncodedValue { return nil } +func (l Leaf) GetPartialKey() *nibbles.Nibbles { return &l.PartialKey } +func (l Leaf) GetValue() EncodedValue { return l.Value } +func (b Branch) GetPartialKey() *nibbles.Nibbles { return &b.PartialKey } +func (b Branch) GetValue() EncodedValue { return b.Value } // NodeKind is an enum to represent the different types of nodes (Leaf, Branch, etc.) 
type NodeKind int @@ -105,9 +104,8 @@ const ( BranchWithHashedValue ) -func EncodeHeader(partialKey []byte, kind NodeKind, writer io.Writer) (err error) { - partialKeyLength := len(partialKey) - if partialKeyLength > int(maxPartialKeyLength) { +func EncodeHeader(partialKey []byte, partialKeyLength uint, kind NodeKind, writer io.Writer) (err error) { + if partialKeyLength > uint(maxPartialKeyLength) { panic(fmt.Sprintf("partial key length is too big: %d", partialKeyLength)) } @@ -131,38 +129,45 @@ func EncodeHeader(partialKey []byte, kind NodeKind, writer io.Writer) (err error buffer[0] = nodeVariant.bits partialKeyLengthMask := nodeVariant.partialKeyLengthHeaderMask() - if partialKeyLength < int(partialKeyLengthMask) { + if partialKeyLength < uint(partialKeyLengthMask) { // Partial key length fits in header byte buffer[0] |= byte(partialKeyLength) _, err = writer.Write(buffer) - return err - } - - // Partial key length does not fit in header byte only - buffer[0] |= partialKeyLengthMask - partialKeyLength -= int(partialKeyLengthMask) - _, err = writer.Write(buffer) - if err != nil { - return err - } - - for { - buffer[0] = 255 - if partialKeyLength < 255 { - buffer[0] = byte(partialKeyLength) + if err != nil { + return err } - + } else { + // Partial key length does not fit in header byte only + buffer[0] |= partialKeyLengthMask + partialKeyLength -= uint(partialKeyLengthMask) _, err = writer.Write(buffer) if err != nil { return err } - partialKeyLength -= int(buffer[0]) + for { + buffer[0] = 255 + if partialKeyLength < 255 { + buffer[0] = byte(partialKeyLength) + } + + _, err = writer.Write(buffer) + if err != nil { + return err + } - if buffer[0] < 255 { - break + partialKeyLength -= uint(buffer[0]) + + if buffer[0] < 255 { + break + } } } + _, err = writer.Write(partialKey) + if err != nil { + return fmt.Errorf("cannot write LE key to buffer: %w", err) + } + return nil } diff --git a/pkg/trie/triedb/hash/hash.go b/pkg/trie/triedb/hash/hash.go new file mode 
100644 index 0000000000..072cf7713b --- /dev/null +++ b/pkg/trie/triedb/hash/hash.go @@ -0,0 +1,19 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package hash + +// Hash type +type Hash interface { + comparable + // Bytes returns a byte slice representation of Hash + Bytes() []byte + // Length return the byte length of the hash + Length() int +} + +// Hasher is an interface around hashing +type Hasher[H Hash] interface { + // Produce the hash of some byte slice. + Hash(s []byte) H +} diff --git a/pkg/trie/triedb/in_memory_to_triedb_migration_test.go b/pkg/trie/triedb/in_memory_to_triedb_migration_test.go deleted file mode 100644 index 14159f7ca9..0000000000 --- a/pkg/trie/triedb/in_memory_to_triedb_migration_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2024 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package triedb - -import ( - "testing" - - "github.com/ChainSafe/gossamer/internal/database" - "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/inmemory" - "github.com/stretchr/testify/assert" -) - -func newTestDB(t assert.TestingT) database.Table { - db, err := database.NewPebble("", true) - assert.NoError(t, err) - return database.NewTable(db, "trie") -} - -func TestWriteTrieDB_Migration(t *testing.T) { - inmemoryTrieDB := newTestDB(t) - inMemoryTrie := inmemory.NewEmptyTrie() - inMemoryTrie.SetVersion(trie.V1) - - inmemoryDB := NewMemoryDB(make([]byte, 1)) - trieDB := NewTrieDB(trie.EmptyHash, inmemoryDB) - - entries := map[string][]byte{ - "no": []byte("noValue"), - "noot": []byte("nootValue"), - "not": []byte("notValue"), - "a": []byte("aValue"), - "b": []byte("bValue"), - "test": []byte("testValue"), - "dimartiro": []byte("dimartiroValue"), - } - - for k, v := range entries { - inMemoryTrie.Put([]byte(k), v) - trieDB.Put([]byte(k), v) - } - - err := inMemoryTrie.WriteDirty(inmemoryTrieDB) - assert.NoError(t, err) - - t.Run("read_same_from_both", func(t 
*testing.T) { - for k := range entries { - valueFromInMemoryTrie := inMemoryTrie.Get([]byte(k)) - assert.NotNil(t, valueFromInMemoryTrie) - - valueFromTrieDB := trieDB.Get([]byte(k)) - assert.NotNil(t, valueFromTrieDB) - assert.Equal(t, valueFromInMemoryTrie, valueFromTrieDB) - } - }) -} - -func TestReadTrieDB_Migration(t *testing.T) { - db := newTestDB(t) - inMemoryTrie := inmemory.NewEmptyTrie() - inMemoryTrie.SetVersion(trie.V1) - - // Use at least 1 value with more than 32 bytes to test trie V1 - entries := map[string][]byte{ - "no": make([]byte, 10), - "noot": make([]byte, 20), - "not": make([]byte, 30), - "notable": make([]byte, 40), - "notification": make([]byte, 50), - "test": make([]byte, 60), - "dimartiro": make([]byte, 70), - } - - for k, v := range entries { - inMemoryTrie.Put([]byte(k), v) - } - - err := inMemoryTrie.WriteDirty(db) - assert.NoError(t, err) - - root, err := inMemoryTrie.Hash() - assert.NoError(t, err) - trieDB := NewTrieDB(root, db) - - t.Run("read_successful_from_db_created_using_v1_trie", func(t *testing.T) { - for k, v := range entries { - value := trieDB.Get([]byte(k)) - assert.NotNil(t, value) - assert.Equal(t, v, value) - } - - assert.Equal(t, root, trieDB.MustHash()) - }) - t.Run("next_key_are_the_same", func(t *testing.T) { - key := []byte("no") - - for key != nil { - expected := inMemoryTrie.NextKey(key) - actual := trieDB.NextKey(key) - assert.Equal(t, expected, actual) - - key = actual - } - }) - - t.Run("get_keys_with_prefix_are_the_same", func(t *testing.T) { - key := []byte("no") - - expected := inMemoryTrie.GetKeysWithPrefix(key) - actual := trieDB.GetKeysWithPrefix(key) - - assert.Equal(t, expected, actual) - }) - - t.Run("entries_are_the_same", func(t *testing.T) { - expected := inMemoryTrie.Entries() - actual := trieDB.Entries() - - assert.Equal(t, expected, actual) - }) -} diff --git a/pkg/trie/triedb/iterator.go b/pkg/trie/triedb/iterator.go index 193bbb3e3a..95d3ff44ce 100644 --- a/pkg/trie/triedb/iterator.go +++ 
b/pkg/trie/triedb/iterator.go @@ -3,41 +3,471 @@ package triedb -// Entries returns all the key-value pairs in the trie as a map of keys to values -// where the keys are encoded in Little Endian. -func (t *TrieDB) Entries() (keyValueMap map[string][]byte) { - entries := make(map[string][]byte) +import ( + "bytes" + "fmt" - iter := NewTrieDBIterator(t) - for entry := iter.NextEntry(); entry != nil; entry = iter.NextEntry() { - entries[string(entry.Key)] = entry.Value + "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" +) + +type ( + status interface { + isStatus() + } + statusEntering struct{} + statusAt struct{} + statusAtChild uint + statusExiting struct{} + statusAftExiting struct{} +) + +func (statusEntering) isStatus() {} +func (statusAt) isStatus() {} +func (statusAtChild) isStatus() {} +func (statusExiting) isStatus() {} +func (statusAftExiting) isStatus() {} + +type crumb[H hash.Hash] struct { + hash *H + node codec.EncodedNode + status +} + +func (c *crumb[H]) step(fwd bool) { + switch status := c.status.(type) { + case statusEntering: + switch c.node.(type) { + case codec.Branch: + c.status = statusAt{} + default: + c.status = statusExiting{} + } + case statusAt: + switch c.node.(type) { + case codec.Branch: + if fwd { + c.status = statusAtChild(0) + } else { + c.status = statusAtChild(15) + } + default: + c.status = statusExiting{} + } + case statusAtChild: + switch c.node.(type) { + case codec.Branch: + if fwd && status < 15 { + c.status = status + 1 + } else if !fwd && status > 15 { + c.status = status - 1 + } else { + c.status = statusExiting{} + } + } + case statusExiting: + c.status = statusAftExiting{} + default: + c.status = statusExiting{} + } +} + +type extractedKey struct { + Key []byte + Padding *byte + Value codec.EncodedValue +} + +type rawItem[H any] struct { + nibbles.NibbleSlice + hash *H + codec.EncodedNode +} + +// 
Extracts the key from the result of a raw item retrieval. +// +// Given a raw item, it extracts the key information, including the key bytes, an optional +// extra nibble (prefix padding), and the node value. +func (ri rawItem[H]) extractKey() *extractedKey { + prefix := ri.NibbleSlice + node := ri.EncodedNode + + var value codec.EncodedValue + switch node := node.(type) { + case codec.Leaf: + prefix.AppendPartial(node.PartialKey.RightPartial()) + value = node.Value + case codec.Branch: + prefix.AppendPartial(node.PartialKey.RightPartial()) + if node.Value == nil { + return nil + } + value = node.Value + default: + return nil + } + + p := prefix.Prefix() + return &extractedKey{ + Key: p.Key, + Padding: p.Padded, + Value: value, + } +} + +type rawIterator[H hash.Hash, Hasher hash.Hasher[H]] struct { + // Forward trail of nodes to visit. + trail []crumb[H] + // Forward iteration key nibbles of the current node. + keyNibbles nibbles.NibbleSlice + db *TrieDB[H, Hasher] +} + +// Create a new iterator. +func newRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], +) (*rawIterator[H, Hasher], error) { + rootNode, rootHash, err := db.getNodeOrLookup( + codec.HashedNode[H]{Hash: db.rootHash}, + nibbles.Prefix{}, + true, + ) + if err != nil { + return nil, err + } + + r := rawIterator[H, Hasher]{ + db: db, + } + r.descend(rootNode, rootHash) + return &r, nil +} + +// Create a new iterator, but limited to a given prefix. +func newPrefixedRawIterator[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, +) (*rawIterator[H, Hasher], error) { + iter, err := newRawIterator(db) + if err != nil { + return nil, err + } + err = iter.prefix(prefix, true) + if err != nil { + return nil, err + } + return iter, nil +} + +// Create a new iterator, but limited to a given prefix. +// It then do a seek operation from prefixed context (using seek lose +// prefix context by default). 
+func newPrefixedRawIteratorThenSeek[H hash.Hash, Hasher hash.Hasher[H]]( + db *TrieDB[H, Hasher], prefix []byte, seek []byte, +) (*rawIterator[H, Hasher], error) { + iter, err := newRawIterator(db) + if err != nil { + return nil, err + } + err = iter.prefixThenSeek(prefix, seek) + if err != nil { + return nil, err + } + return iter, nil +} + +// Descend into a node. +func (ri *rawIterator[H, Hasher]) descend(node codec.EncodedNode, nodeHash *H) { + ri.trail = append(ri.trail, crumb[H]{ + hash: nodeHash, + status: statusEntering{}, + node: node, + }) +} + +// Seek a node position at key for iterator. +// Returns true if the cursor is at or after the key, but still shares +// a common prefix with the key, return false if the key do not +// share its prefix with the node. +// This indicates if there is still nodes to iterate over in the case +// where we limit iteration to key as a prefix. +func (ri *rawIterator[H, Hasher]) seek(keyBytes []byte, fwd bool) (bool, error) { + ri.trail = nil + ri.keyNibbles.Clear() + key := nibbles.NewNibbles(keyBytes) + + node, nodeHash, err := ri.db.getNodeOrLookup( + codec.HashedNode[H]{Hash: ri.db.rootHash}, nibbles.Prefix{}, true, + ) + if err != nil { + return false, err } + partial := key + var fullKeyNibbles uint + for { + var ( + nextNode codec.EncodedNode + nextNodeHash *H + ) + + ri.descend(node, nodeHash) + crumb := &ri.trail[len(ri.trail)-1] + + switch node := crumb.node.(type) { + case codec.Leaf: + if (fwd && node.PartialKey.Compare(partial) == -1) || + (!fwd && node.PartialKey.Compare(partial) == 1) { + crumb.status = statusAftExiting{} + return false, nil + } + return node.PartialKey.StartsWith(partial), nil + case codec.Branch: + pk := node.PartialKey + if !partial.StartsWith(pk) { + if (fwd && pk.Compare(partial) == -1) || + (!fwd && pk.Compare(partial) == 1) { + crumb.status = statusAftExiting{} + return false, nil + } + return pk.StartsWith(partial), nil + } + + fullKeyNibbles += pk.Len() + partial = 
partial.Mid(pk.Len()) + + if partial.Len() == 0 { + return true, nil + } + + i := partial.At(0) + crumb.status = statusAtChild(i) + ri.keyNibbles.AppendPartial(pk.RightPartial()) + ri.keyNibbles.Push(i) - return entries + if child := node.Children[i]; child != nil { + fullKeyNibbles += 1 + partial = partial.Mid(1) + + prefix := key.Back(fullKeyNibbles) + var err error + nextNode, nextNodeHash, err = ri.db.getNodeOrLookup(child, prefix.Left(), true) + if err != nil { + return false, err + } + } else { + return false, nil + } + case codec.Empty: + if !(partial.Len() == 0) { + crumb.status = statusExiting{} + return false, nil + } + return true, nil + } + + node = nextNode + nodeHash = nextNodeHash + } +} + +// Advance the iterator into a prefix, no value out of the prefix will be accessed +// or returned after this operation. +func (ri *rawIterator[H, Hasher]) prefix(prefix []byte, fwd bool) error { + found, err := ri.seek(prefix, fwd) + if err != nil { + return err + } + if found { + if len(ri.trail) > 0 { + popped := ri.trail[len(ri.trail)-1] + ri.trail = nil + ri.trail = append(ri.trail, popped) + } + } else { + ri.trail = nil + } + return nil } -// NextKey returns the next key in the trie in lexicographic order. -// It returns nil if no next key is found. -func (t *TrieDB) NextKey(key []byte) []byte { - iter := NewTrieDBIterator(t) +// Advance the iterator into a prefix, no value out of the prefix will be accessed +// or returned after this operation. +func (ri *rawIterator[H, Hasher]) prefixThenSeek(prefix []byte, seek []byte) error { + if len(prefix) == 0 { + // Theres no prefix, so just seek. 
+		_, err := ri.seek(seek, true)
+		if err != nil {
+			return err
+		}
+	}
-
-	// TODO: Seek will potentially skip a lot of keys, we need to find a way to
-	// optimise it, maybe creating a lookupFor
-	iter.Seek(key)
-	return iter.NextKey()
+	if len(seek) == 0 || bytes.Compare(seek, prefix) <= 0 {
+		// Either we're not supposed to seek anywhere,
+		// or we're supposed to seek *before* the prefix,
+		// so just directly go to the prefix.
+		return ri.prefix(prefix, true)
+	}
+
+	if !bytes.HasPrefix(seek, prefix) {
+		// We're supposed to seek *after* the prefix,
+		// so just return an empty iterator.
+		ri.trail = nil
+		return nil
+	}
+
+	found, err := ri.seek(prefix, true)
+	if err != nil {
+		return err
+	}
+	if !found {
+		// The database doesn't have a key with such a prefix.
+		ri.trail = nil
+		return nil
+	}
+
+	// Now seek forward again
+	_, err = ri.seek(seek, true)
+	if err != nil {
+		return err
+	}
+
+	prefixLen := uint(len(prefix)) * nibbles.NibblesPerByte
+	var length uint
+	// Find the first trail node whose accumulated key length exceeds the prefix.
+	for i := 0; i < len(ri.trail); i++ {
+		switch node := ri.trail[i].node.(type) {
+		case codec.Empty:
+		case codec.Leaf:
+			length += node.PartialKey.Len()
+		case codec.Branch:
+			length++
+			length += node.PartialKey.Len()
+		}
+		if length > prefixLen {
+			ri.trail = ri.trail[i:]
+			return nil
+		}
+	}
+
+	ri.trail = nil
+	return nil
+}
+
+// Fetches the next raw item.
+//
+// Must be called with the same db as when the iterator was created.
+//
+// Specify fwd to indicate the direction of the iteration (true for forward).
+func (ri *rawIterator[H, Hasher]) nextRawItem(fwd bool) (*rawItem[H], error) { + for { + if len(ri.trail) == 0 { + return nil, nil + } + crumb := &ri.trail[len(ri.trail)-1] + switch status := crumb.status.(type) { + case statusEntering: + crumb.step(fwd) + if fwd { + return &rawItem[H]{ri.keyNibbles, crumb.hash, crumb.node}, nil + } + case statusAftExiting: + ri.trail = ri.trail[:len(ri.trail)-1] + if len(ri.trail) > 0 { + crumb := &ri.trail[len(ri.trail)-1] + crumb.step(fwd) + } + case statusExiting: + switch node := crumb.node.(type) { + case codec.Empty, codec.Leaf: + case codec.Branch: + ri.keyNibbles.DropLasts(node.PartialKey.Len() + 1) + default: + panic("unreachable") + } + crumb := &ri.trail[len(ri.trail)-1] + crumb.step(fwd) + if !fwd { + return &rawItem[H]{ri.keyNibbles, crumb.hash, crumb.node}, nil + } + case statusAt: + branch, ok := crumb.node.(codec.Branch) + if !ok { + panic("unsupported") + } + partial := branch.PartialKey + ri.keyNibbles.AppendPartial(partial.RightPartial()) + if fwd { + ri.keyNibbles.Push(0) + } else { + ri.keyNibbles.Push(15) + } + crumb.step(fwd) + case statusAtChild: + i := status + branch, ok := crumb.node.(codec.Branch) + if !ok { + panic("unsupported") + } + children := branch.Children + child := children[i] + if child != nil { + ri.keyNibbles.Pop() + ri.keyNibbles.Push(uint8(i)) //nolint:gosec + + node, nodeHash, err := ri.db.getNodeOrLookup(children[i], ri.keyNibbles.Prefix(), true) + if err != nil { + crumb.step(fwd) + return nil, err + } + ri.descend(node, nodeHash) + } else { + crumb.step(fwd) + } + default: + panic(fmt.Errorf("unreachable: %T", status)) + } + } } -// GetKeysWithPrefix returns all keys in little Endian -// format from nodes in the trie that have the given little -// Endian formatted prefix in their key. -func (t *TrieDB) GetKeysWithPrefix(prefix []byte) (keysLE [][]byte) { - iter := NewPrefixedTrieDBIterator(t, prefix) +// Fetches the next trie item. 
+// +// Must be called with the same db as when the iterator was created. +func (ri *rawIterator[H, Hasher]) NextItem() (*TrieItem, error) { + for { + rawItem, err := ri.nextRawItem(true) + if err != nil { + return nil, err + } + if rawItem == nil { + return nil, nil + } + extracted := rawItem.extractKey() + if extracted == nil { + continue + } + key := extracted.Key + maybeExtraNibble := extracted.Padding + value := extracted.Value - keys := make([][]byte, 0) + if maybeExtraNibble != nil { + return nil, fmt.Errorf("ValueAtIncompleteKey: %v %v", key, *maybeExtraNibble) + } - for key := iter.NextKey(); key != nil; key = iter.NextKey() { - keys = append(keys, key) + switch value := value.(type) { + case codec.HashedValue[H]: + val, err := ri.db.fetchValue(value.Hash, nibbles.Prefix{Key: key}) + if err != nil { + return nil, err + } + return &TrieItem{key, val}, nil + case codec.InlineValue: + return &TrieItem{key, value}, nil + default: + panic("unreachable") + } } +} - return keys +type TrieItem struct { + Key []byte + Value []byte } diff --git a/pkg/trie/triedb/iterator_test.go b/pkg/trie/triedb/iterator_test.go new file mode 100644 index 0000000000..c16c437d2e --- /dev/null +++ b/pkg/trie/triedb/iterator_test.go @@ -0,0 +1,150 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "testing" + + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/stretchr/testify/assert" +) + +func Test_rawIterator(t *testing.T) { + entries := map[string][]byte{ + "no": make([]byte, 1), + "noot": make([]byte, 2), + "not": make([]byte, 3), + "notable": make([]byte, 4), + "notification": make([]byte, 5), + "test": make([]byte, 6), + "dimartiro": make([]byte, 7), + "bigvalue": make([]byte, 33), + "bigbigvalue": make([]byte, 66), + } + + db := NewMemoryDB[hash.H256, 
runtime.BlakeTwo256](EmptyNode) + trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](db) + trieDB.SetVersion(trie.V1) + + for k, v := range entries { + err := trieDB.Put([]byte(k), v) + assert.NoError(t, err) + } + assert.NoError(t, trieDB.commit()) + + t.Run("iterate_over_all_raw_items", func(t *testing.T) { + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) + + i := 0 + for { + item, err := iter.nextRawItem(true) + assert.NoError(t, err) + if item == nil { + break + } + i++ + } + assert.Equal(t, 13, i) + }) + + t.Run("iterate_over_all_entries", func(t *testing.T) { + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) + + i := 0 + for { + item, err := iter.NextItem() + assert.NoError(t, err) + if item == nil { + break + } + assert.Contains(t, entries, string(item.Key)) + assert.Equal(t, item.Value, entries[string(item.Key)]) + i++ + } + assert.Equal(t, len(entries), i) + }) + + t.Run("seek", func(t *testing.T) { + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) + + found, err := iter.seek([]byte("no"), true) + assert.NoError(t, err) + assert.True(t, found) + + item, err := iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, "no", string(item.Key)) + + item, err = iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, "noot", string(item.Key)) + }) + + t.Run("seek_leaf", func(t *testing.T) { + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) + + found, err := iter.seek([]byte("dimartiro"), true) + assert.NoError(t, err) + assert.True(t, found) + + item, err := iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, "dimartiro", string(item.Key)) + }) + + t.Run("iterate_over_all_prefixed_entries", func(t *testing.T) { + iter, err := newPrefixedRawIterator(trieDB, []byte("no")) + assert.NoError(t, err) + + i := 0 + for { + item, err := iter.NextItem() + assert.NoError(t, err) + if item == nil { + break + } + assert.Contains(t, 
entries, string(item.Key)) + assert.Equal(t, item.Value, entries[string(item.Key)]) + i++ + } + assert.Equal(t, 5, i) + }) + + t.Run("prefixed_raw_iterator", func(t *testing.T) { + iter, err := newPrefixedRawIterator(trieDB, []byte("noot")) + assert.NoError(t, err) + + item, err := iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, item) + assert.Equal(t, "noot", string(item.Key)) + }) + + t.Run("iterate_over_all_prefixed_entries_then_seek", func(t *testing.T) { + iter, err := newPrefixedRawIteratorThenSeek(trieDB, []byte("no"), []byte("noot")) + assert.NoError(t, err) + + i := 0 + for { + item, err := iter.NextItem() + assert.NoError(t, err) + if item == nil { + break + } + assert.Contains(t, entries, string(item.Key)) + assert.Equal(t, item.Value, entries[string(item.Key)]) + i++ + } + assert.Equal(t, 4, i) + }) +} diff --git a/pkg/trie/triedb/lookup.go b/pkg/trie/triedb/lookup.go index 84cab5d467..833efe43a3 100644 --- a/pkg/trie/triedb/lookup.go +++ b/pkg/trie/triedb/lookup.go @@ -6,25 +6,30 @@ package triedb import ( "bytes" - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/trie/cache" "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) -type TrieLookup struct { +type TrieLookup[H hash.Hash, Hasher hash.Hasher[H]] struct { // db to query from db db.DBGetter // hash to start at - hash common.Hash - // cache to speed up the db lookups - cache cache.TrieCache - // Optional recorder for recording trie accesses - recorder *Recorder + hash H + // optional cache to speed up the db lookups + cache Cache + // optional recorder for recording trie accesses + recorder TrieRecorder } -func NewTrieLookup(db db.DBGetter, hash common.Hash, cache cache.TrieCache, recorder *Recorder) TrieLookup { - return TrieLookup{ +func NewTrieLookup[H hash.Hash, Hasher hash.Hasher[H]]( + db 
db.DBGetter, + hash H, + cache Cache, + recorder TrieRecorder, +) TrieLookup[H, Hasher] { + return TrieLookup[H, Hasher]{ db: db, hash: hash, cache: cache, @@ -32,38 +37,30 @@ func NewTrieLookup(db db.DBGetter, hash common.Hash, cache cache.TrieCache, reco } } -func (l *TrieLookup) lookupNode(keyNibbles []byte) (codec.EncodedNode, error) { +func (l *TrieLookup[H, Hasher]) lookupNode( + nibbleKey nibbles.Nibbles, fullKey []byte, +) (codec.EncodedNode, error) { // Start from root node and going downwards - partialKey := keyNibbles - hash := l.hash[:] + partialKey := nibbleKey.Clone() + hash := l.hash + var keyNibbles uint // Iterates through non inlined nodes for { // Get node from DB - var nodeData []byte - if l.cache != nil { - nodeData = l.cache.GetNode(hash) + prefixedKey := append(nibbleKey.Mid(keyNibbles).Left().JoinedBytes(), hash.Bytes()...) + nodeData, err := l.db.Get(prefixedKey) + if err != nil { + return nil, ErrIncompleteDB } - if nodeData == nil { - var err error - nodeData, err = l.db.Get(hash) - if err != nil { - return nil, ErrIncompleteDB - } - - if l.cache != nil { - l.cache.SetNode(hash, nodeData) - } - - l.recordAccess(encodedNodeAccess{hash: common.BytesToHash(hash), encodedNode: nodeData}) - } + l.recordAccess(EncodedNodeAccess[H]{Hash: hash, EncodedNode: nodeData}) InlinedChildrenIterator: for { // Decode node reader := bytes.NewReader(nodeData) - decodedNode, err := codec.Decode(reader) + decodedNode, err := codec.Decode[H](reader) if err != nil { return nil, err } @@ -75,11 +72,11 @@ func (l *TrieLookup) lookupNode(keyNibbles []byte) (codec.EncodedNode, error) { return nil, nil //nolint:nilnil case codec.Leaf: // We are in the node we were looking for - if bytes.Equal(partialKey, n.PartialKey) { + if partialKey.Equal(n.PartialKey) { return n, nil } - l.recordAccess(nonExistingNodeAccess{fullKey: keyNibbles}) + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) return nil, nil //nolint:nilnil case codec.Branch: @@ -88,40 +85,41 @@ func 
(l *TrieLookup) lookupNode(keyNibbles []byte) (codec.EncodedNode, error) { // This is unusual but could happen if for some reason one // branch has a hashed child node that points to a node that // doesn't share the prefix we are expecting - if !bytes.HasPrefix(partialKey, nodePartialKey) { - l.recordAccess(nonExistingNodeAccess{fullKey: keyNibbles}) + if !partialKey.StartsWith(nodePartialKey) { + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) return nil, nil //nolint:nilnil } // We are in the node we were looking for - if bytes.Equal(partialKey, nodePartialKey) { + if partialKey.Equal(n.PartialKey) { if n.Value != nil { return n, nil } - l.recordAccess(nonExistingNodeAccess{fullKey: keyNibbles}) + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) return nil, nil //nolint:nilnil } // This is not the node we were looking for but it might be in // one of its children - childIdx := int(partialKey[len(nodePartialKey)]) + childIdx := int(partialKey.At(nodePartialKey.Len())) nextNode = n.Children[childIdx] if nextNode == nil { - l.recordAccess(nonExistingNodeAccess{fullKey: keyNibbles}) + l.recordAccess(NonExistingNodeAccess{FullKey: fullKey}) return nil, nil //nolint:nilnil } // Advance the partial key consuming the part we already checked - partialKey = partialKey[len(nodePartialKey)+1:] + partialKey = partialKey.Mid(nodePartialKey.Len() + 1) + keyNibbles += nodePartialKey.Len() + 1 } // Next node could be inlined or hashed (pointer to a node) // https://spec.polkadot.network/chap-state#defn-merkle-value switch merkleValue := nextNode.(type) { - case codec.HashedNode: + case codec.HashedNode[H]: // If it's hashed we set the hash to look for it in next loop - hash = merkleValue[:] + hash = merkleValue.Hash break InlinedChildrenIterator case codec.InlineNode: // If it is inlined we just need to decode it in the next loop @@ -131,14 +129,10 @@ func (l *TrieLookup) lookupNode(keyNibbles []byte) (codec.EncodedNode, error) { } } -func (l *TrieLookup) 
lookupValue(keyNibbles []byte) (value []byte, err error) { - if l.cache != nil { - if value = l.cache.GetValue(keyNibbles); value != nil { - return value, nil - } - } - - node, err := l.lookupNode(keyNibbles) +func (l *TrieLookup[H, Hasher]) lookupValue( + fullKey []byte, keyNibbles nibbles.Nibbles, +) (value []byte, err error) { + node, err := l.lookupNode(keyNibbles, fullKey) if err != nil { return nil, err } @@ -149,15 +143,10 @@ func (l *TrieLookup) lookupValue(keyNibbles []byte) (value []byte, err error) { } if nodeValue := node.GetValue(); nodeValue != nil { - value, err = l.fetchValue(node.GetPartialKey(), keyNibbles, nodeValue) + value, err = l.fetchValue(keyNibbles.OriginalDataPrefix(), fullKey, nodeValue) if err != nil { return nil, err } - - if l.cache != nil { - l.cache.SetValue(keyNibbles, value) - } - return value, nil } @@ -166,29 +155,22 @@ func (l *TrieLookup) lookupValue(keyNibbles []byte) (value []byte, err error) { // fetchValue gets the value from the node, if it is inlined we can return it // directly. 
But if it is hashed (V1) we have to look up for its value in the DB -func (l *TrieLookup) fetchValue(prefix []byte, fullKey []byte, value codec.EncodedValue) ([]byte, error) { +func (l *TrieLookup[H, Hasher]) fetchValue( + prefix nibbles.Prefix, fullKey []byte, value codec.EncodedValue, +) ([]byte, error) { switch v := value.(type) { case codec.InlineValue: - l.recordAccess(inlineValueAccess{fullKey: fullKey}) + l.recordAccess(InlineValueAccess{FullKey: fullKey}) return v, nil - case codec.HashedValue: - prefixedKey := bytes.Join([][]byte{prefix, v[:]}, nil) - if l.cache != nil { - if value := l.cache.GetValue(prefixedKey); value != nil { - return value, nil - } - } + case codec.HashedValue[H]: + prefixedKey := bytes.Join([][]byte{prefix.JoinedBytes(), v.Hash.Bytes()}, nil) nodeData, err := l.db.Get(prefixedKey) if err != nil { return nil, ErrIncompleteDB } - if l.cache != nil { - l.cache.SetValue(prefixedKey, nodeData) - } - - l.recordAccess(valueAccess{hash: common.Hash(v), fullKey: fullKey, value: nodeData}) + l.recordAccess(ValueAccess[H]{Hash: v.Hash, FullKey: fullKey, Value: nodeData}) return nodeData, nil default: @@ -196,8 +178,8 @@ func (l *TrieLookup) fetchValue(prefix []byte, fullKey []byte, value codec.Encod } } -func (l *TrieLookup) recordAccess(access trieAccess) { +func (l *TrieLookup[H, Hasher]) recordAccess(access TrieAccess) { if l.recorder != nil { - l.recorder.record(access) + l.recorder.Record(access) } } diff --git a/pkg/trie/triedb/lookup_test.go b/pkg/trie/triedb/lookup_test.go index c56a1e8feb..bf57721148 100644 --- a/pkg/trie/triedb/lookup_test.go +++ b/pkg/trie/triedb/lookup_test.go @@ -6,17 +6,32 @@ package triedb import ( "testing" - "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/stretchr/testify/assert" ) func TestTrieDB_Lookup(t *testing.T) { 
t.Run("root_not_exists_in_db", func(t *testing.T) { db := newTestDB(t) - lookup := NewTrieLookup(db, trie.EmptyHash, nil, nil) + empty := runtime.BlakeTwo256{}.Hash([]byte{0}) + lookup := NewTrieLookup[hash.H256, runtime.BlakeTwo256](db, empty, nil, nil) - value, err := lookup.lookupValue([]byte("test")) + value, err := lookup.lookupValue([]byte("test"), nibbles.NewNibbles([]byte("test"))) assert.Nil(t, value) assert.ErrorIs(t, err, ErrIncompleteDB) }) } + +// TODO: restore after implementing node level caching +// func Test_valueHash_CachedValue(t *testing.T) { +// var vh *valueHash[hash.H256] +// assert.Equal(t, NonExistingCachedValue{}, vh.CachedValue()) + +// vh = &valueHash[hash.H256]{ +// Value: []byte("someValue"), +// Hash: hash.NewRandomH256(), +// } +// assert.Equal(t, ExistingCachedValue[hash.H256]{vh.Hash, vh.Value}, vh.CachedValue()) +// } diff --git a/pkg/trie/triedb/mem_test.go b/pkg/trie/triedb/mem_test.go index bec511298c..54ff2c4475 100644 --- a/pkg/trie/triedb/mem_test.go +++ b/pkg/trie/triedb/mem_test.go @@ -6,10 +6,13 @@ package triedb import ( "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" inmemory_cache "github.com/ChainSafe/gossamer/pkg/trie/cache/inmemory" inmemory_trie "github.com/ChainSafe/gossamer/pkg/trie/inmemory" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Benchmark_ValueCache(b *testing.B) { @@ -38,7 +41,7 @@ func Benchmark_ValueCache(b *testing.B) { assert.NoError(b, err) b.Run("get_value_without_cache", func(b *testing.B) { - trieDB := NewTrieDB(root, db) + trieDB := NewTrieDB[hash.H256, runtime.BlakeTwo256](hash.H256(root.ToBytes()), db) b.ResetTimer() for i := 0; i < b.N; i++ { // Use the deepest key to ensure the trie is traversed fully @@ -48,7 +51,8 @@ func Benchmark_ValueCache(b *testing.B) { b.Run("get_value_with_cache", func(b *testing.B) { cache := 
inmemory_cache.NewTrieInMemoryCache() - trieDB := NewTrieDB(root, db, WithCache(cache)) + trieDB := NewTrieDB[hash.H256, runtime.BlakeTwo256]( + hash.H256(root.ToBytes()), db, WithCache[hash.H256, runtime.BlakeTwo256](cache)) b.ResetTimer() for i := 0; i < b.N; i++ { // Use the deepest key to ensure the trie is traversed fully @@ -83,12 +87,13 @@ func Benchmark_NodesCache(b *testing.B) { assert.NoError(b, err) b.Run("iterate_all_entries_without_cache", func(b *testing.B) { - trieDB := NewTrieDB(root, db) + trieDB := NewTrieDB[hash.H256, runtime.BlakeTwo256](hash.H256(root.ToBytes()), db) b.ResetTimer() for i := 0; i < b.N; i++ { // Iterate through all keys - iter := NewTrieDBIterator(trieDB) - for entry := iter.NextEntry(); entry != nil; entry = iter.NextEntry() { + iter, err := newRawIterator(trieDB) + require.NoError(b, err) + for entry, err := iter.NextItem(); entry != nil && err == nil; entry, err = iter.NextItem() { } } }) @@ -98,12 +103,14 @@ func Benchmark_NodesCache(b *testing.B) { // cache the decoded node instead and avoid decoding it every time. 
b.Run("iterate_all_entries_with_cache", func(b *testing.B) { cache := inmemory_cache.NewTrieInMemoryCache() - trieDB := NewTrieDB(root, db, WithCache(cache)) + trieDB := NewTrieDB[hash.H256, runtime.BlakeTwo256]( + hash.H256(root.ToBytes()), db, WithCache[hash.H256, runtime.BlakeTwo256](cache)) b.ResetTimer() for i := 0; i < b.N; i++ { // Iterate through all keys - iter := NewTrieDBIterator(trieDB) - for entry := iter.NextEntry(); entry != nil; entry = iter.NextEntry() { + iter, err := newRawIterator(trieDB) + require.NoError(b, err) + for entry, err := iter.NextItem(); entry != nil && err == nil; entry, err = iter.NextItem() { } } }) diff --git a/pkg/trie/triedb/nibbles/leftnibbles.go b/pkg/trie/triedb/nibbles/leftnibbles.go new file mode 100644 index 0000000000..8e0bb5ef44 --- /dev/null +++ b/pkg/trie/triedb/nibbles/leftnibbles.go @@ -0,0 +1,99 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package nibbles + +import ( + "bytes" + "cmp" +) + +// A representation of a nibble slice which is left-aligned. The regular [Nibbles] is +// right-aligned, meaning it does not support efficient truncation from the right side. +// +// This is meant to be an immutable struct. No operations actually change it. +type LeftNibbles struct { + bytes []byte + len uint +} + +// Constructs a byte-aligned nibble slice from a byte slice. +func NewLeftNibbles(bytes []byte) LeftNibbles { + return LeftNibbles{ + bytes: bytes, + len: uint(len(bytes)) * NibblesPerByte, + } +} + +// Returns the length of the slice in nibbles. +func (ln LeftNibbles) Len() uint { + return ln.len +} + +func leftNibbleAt(v1 []byte, ix uint) uint8 { + return atLeft(uint8(ix%NibblesPerByte), v1[ix/NibblesPerByte]) //nolint:gosec +} + +// Get the nibble at a nibble index padding with a 0 nibble. Returns nil if the index is +// out of bounds. 
+func (ln LeftNibbles) At(index uint) *uint8 { + if index < ln.len { + at := leftNibbleAt(ln.bytes, index) + return &at + } + return nil +} + +// Returns a new slice truncated from the right side to the given length. If the given length +// is greater than that of this slice, the function just returns a copy. +func (ln LeftNibbles) Truncate(len uint) LeftNibbles { + if ln.len < len { + len = ln.len + } + return LeftNibbles{bytes: ln.bytes, len: len} +} + +// Returns whether the given slice is a prefix of this one. +func (ln LeftNibbles) StartsWith(prefix LeftNibbles) bool { + return ln.Truncate(prefix.Len()).compare(prefix) == 0 +} + +// Returns whether another regular (right-aligned) nibble slice is contained in this one at +// the given offset. +func (ln LeftNibbles) Contains(partial Nibbles, offset uint) bool { + for i := uint(0); i < partial.Len(); i++ { + lnAt := ln.At(offset + i) + partialAt := partial.At(i) + if *lnAt == partialAt { + continue + } + return false + } + return true +} + +func (ln LeftNibbles) compare(other LeftNibbles) int { + commonLen := ln.Len() + if other.Len() < commonLen { + commonLen = other.Len() + } + commonByteLen := commonLen / NibblesPerByte + + // Quickly compare the common prefix of the byte slices. + c := bytes.Compare(ln.bytes[:commonByteLen], other.bytes[:commonByteLen]) + if c != 0 { + return c + } + + // Compare nibble-by-nibble (either 0 or 1 nibbles) any after the common byte prefix. + for i := commonByteLen * NibblesPerByte; i < commonLen; i++ { + a := *ln.At(i) + b := *other.At(i) + if c := cmp.Compare(a, b); c != 0 { + return c + } + } + + // If common nibble prefix is the same, finally compare lengths. 
+ return cmp.Compare(ln.Len(), other.Len()) +} diff --git a/pkg/trie/triedb/nibbles/leftnibbles_test.go b/pkg/trie/triedb/nibbles/leftnibbles_test.go new file mode 100644 index 0000000000..abc681e06a --- /dev/null +++ b/pkg/trie/triedb/nibbles/leftnibbles_test.go @@ -0,0 +1,31 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package nibbles + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_LeftNibbles_compare(t *testing.T) { + n := NewLeftNibbles([]byte("gossamer")) + m := NewLeftNibbles([]byte("gossamerGossamer")) + + assert.Equal(t, -1, n.compare(m)) + assert.Equal(t, 1, m.compare(n)) + assert.Equal(t, 0, n.compare(m.Truncate(16))) + + truncated := n.Truncate(1) + assert.Equal(t, -1, truncated.compare(n)) + assert.Equal(t, 1, n.compare(truncated)) + assert.Equal(t, -1, truncated.compare(m)) +} + +func Test_LeftNibbles_StartsWith(t *testing.T) { + a := NewLeftNibbles([]byte("polkadot")) + b := NewLeftNibbles([]byte("go")) + b.len = 1 + assert.Equal(t, false, a.StartsWith(b)) +} diff --git a/pkg/trie/triedb/nibbles/nibbles.go b/pkg/trie/triedb/nibbles/nibbles.go new file mode 100644 index 0000000000..e603a041e2 --- /dev/null +++ b/pkg/trie/triedb/nibbles/nibbles.go @@ -0,0 +1,362 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package nibbles + +import ( + "slices" +) + +// Single nibble length in bit. +const BitsPerNibble uint8 = 4 + +// Number of nibble per byte. +const NibblesPerByte uint = 2 + +// Nibble (half a byte). +const PaddingBitmask uint8 = 0x0F + +// Nibble-orientated view onto byte-slice, allowing nibble-precision offsets. +// +// This is meant to be an immutable struct. No operations actually change it. 
+type Nibbles struct { + data []uint8 + offset uint +} + +// Construct new [Nibbles] from data and offset +func NewNibbles(data []byte, offset ...uint) Nibbles { + var off uint + if len(offset) > 0 { + off = offset[0] + } + return Nibbles{ + data: data, + offset: off, + } +} + +// Construct new [Nibbles] from [NodeKey] +func NewNibblesFromNodeKey(from NodeKey) Nibbles { + return NewNibbles(from.Data, from.Offset) +} + +// Get the nibble at position i. +func (n Nibbles) At(i uint) uint8 { + ix := (n.offset + i) / NibblesPerByte + pad := uint8((n.offset + i) % NibblesPerByte) //nolint:gosec + return atLeft(pad, n.data[ix]) +} + +func atLeft(ix, b uint8) uint8 { + if ix == 1 { + return b & 0x0F + } else { + return b >> BitsPerNibble + } +} + +// Return object which represents a view on to this slice (further) offset by i nibbles. +func (n Nibbles) Mid(i uint) Nibbles { + return Nibbles{ + data: slices.Clone(n.data), + offset: n.offset + i, + } +} + +// Mask a byte, keeping left nibble. +func PadLeft(b uint8) uint8 { + return b &^ PaddingBitmask +} + +// Mask a byte, keeping right byte. +func PadRight(b uint8) uint8 { + return b & PaddingBitmask +} + +// A trie node prefix, it is the nibble path from the trie root +// to the trie node. +// For a node containing no partial key value it is the full key. +// For a value node or node containing a partial key, it is the full key minus its node partial +// nibbles (the node key can be split into prefix and node partial). +// Therefore it is always the leftmost portion of the node key, so its internal representation +// is a non expanded byte slice followed by a last padded byte representation. +// The padded byte is an optional padded value. 
+type Prefix struct { + Key []byte + Padded *byte +} + +func (p Prefix) JoinedBytes() []byte { + if p.Padded != nil { + return append(p.Key, *p.Padded) + } + return p.Key +} + +// Return left portion of [Nibbles], if the slice +// originates from a full key it will be the Prefix of +// the node. +func (n Nibbles) Left() Prefix { + split := n.offset / NibblesPerByte + ix := uint8(n.offset % NibblesPerByte) //nolint:gosec + if ix == 0 { + return Prefix{Key: n.data[:split]} + } + padded := PadLeft(n.data[split]) + return Prefix{Key: slices.Clone(n.data[:split]), Padded: &padded} +} + +func (n Nibbles) Len() uint { + return uint(len(n.data))*NibblesPerByte - n.offset +} + +// Advance the view on the slice by i nibbles. +func (n *Nibbles) Advance(i uint) { + if n.Len() < i { + panic("not enough nibbles to advance") + } + n.offset += i +} + +// Move back to a previously valid fix offset position. +func (n Nibbles) Back(i uint) Nibbles { + return Nibbles{data: n.data, offset: i} +} + +func (n Nibbles) Equal(them Nibbles) bool { + return n.Len() == them.Len() && n.StartsWith(them) +} + +func (n Nibbles) StartsWith(them Nibbles) bool { + return n.CommonPrefix(them) == them.Len() +} + +// Calculate the number of common nibble between two left aligned bytes. +func leftCommon(a uint8, b uint8) uint { + if a == b { + return 2 + } else if PadLeft(a) == PadLeft(b) { + return 1 + } else { + return 0 + } +} + +// Count the biggest common depth between two left aligned packed nibble slice. +func biggestDepth(v1 []uint8, v2 []uint8) uint { + upperBound := len(v1) + if len(v2) < upperBound { + upperBound = len(v2) + } + for a := 0; a < upperBound; a++ { + if v1[a] != v2[a] { + return uint(a)*NibblesPerByte + leftCommon(v1[a], v2[a]) //nolint:gosec + } + } + return uint(upperBound) * NibblesPerByte //nolint:gosec +} + +// How many of the same nibbles at the beginning do we match with them? 
+func (n Nibbles) CommonPrefix(them Nibbles) uint {
+	selfAlign := n.offset % NibblesPerByte
+	themAlign := them.offset % NibblesPerByte
+	if selfAlign == themAlign {
+		selfStart := n.offset / NibblesPerByte
+		themStart := them.offset / NibblesPerByte
+		var first uint = 0
+		if selfAlign != 0 {
+			if PadRight(n.data[selfStart]) != PadRight(them.data[themStart]) {
+				// warning only for radix 16
+				return 0
+			}
+			selfStart += 1
+			themStart += 1
+			first += 1
+		}
+		return biggestDepth(n.data[selfStart:], them.data[themStart:]) + first
+	} else {
+		s := n.Len()
+		if them.Len() < s {
+			s = them.Len()
+		}
+		var i uint
+		for i < s {
+			if n.At(i) != them.At(i) {
+				break
+			}
+			i++
+		}
+		return i
+	}
+}
+
+// Helper function to create a [NodeKey].
+func (n Nibbles) NodeKey() NodeKey {
+	split := n.offset / NibblesPerByte
+	offset := n.offset % NibblesPerByte
+	return NodeKey{offset, n.data[split:]}
+}
+
+// Helper function to create a [NodeKey] for a given number of nibbles.
+// Warning: this method can be slow (when the number of nibbles does not align
+// with the original padding).
+func (n Nibbles) NodeKeyRange(nb uint) NodeKey {
+	if nb >= n.Len() {
+		return n.NodeKey()
+	}
+	if (n.offset+nb)%NibblesPerByte == 0 {
+		// aligned
+		start := n.offset / NibblesPerByte
+		end := (n.offset + nb) / NibblesPerByte
+		return NodeKey{
+			Offset: n.offset % NibblesPerByte,
+			Data:   n.data[start:end],
+		}
+	}
+	// unaligned
+	start := n.offset / NibblesPerByte
+	end := (n.offset + nb) / NibblesPerByte
+	ea := n.data[start : end+1]
+	eaOffset := n.offset % NibblesPerByte
+	nOffset := NumberPadding(nb)
+	result := NodeKey{
+		Offset: eaOffset,
+		Data:   ea,
+	}
+	result.ShiftKey(nOffset)
+	result.Data = result.Data[:len(result.Data)-1]
+	return result
+}
+
+// Calculate the number of padding nibbles needed for a nibble array of length i.
+func NumberPadding(i uint) uint {
+	return i % NibblesPerByte
+}
+
+// Representation of a nibble slice (right aligned).
+// It contains a right aligned padded first byte (first pair element is the number of nibbles +// (0 to max nb nibble - 1), second pair element is the padded nibble), and a slice over +// the remaining bytes. +type Partial struct { + First uint8 + PaddedNibble uint8 + Data []byte +} + +// Return [Partial] representation of this slice: +// first encoded byte and following slice. +func (n Nibbles) RightPartial() Partial { + split := n.offset / NibblesPerByte + nb := uint8(n.Len() % NibblesPerByte) //nolint:gosec + if nb > 0 { + return Partial{ + First: nb, + PaddedNibble: (n.data[split]), + Data: n.data[split+1:], + } + } + return Partial{ + First: 0, + PaddedNibble: 0, + Data: n.data[split:], + } +} + +// Return an iterator over [Partial] bytes representation. +func (n Nibbles) Right() []uint8 { + p := n.RightPartial() + var ret []uint8 + if p.First > 0 { + ret = append(ret, PadRight(p.PaddedNibble)) + } + for ix := 0; ix < len(p.Data); ix++ { + ret = append(ret, p.Data[ix]) + } + return ret +} + +// Push uint8 nibble value at a given index into an existing byte. +func PushAtLeft(ix uint8, v uint8, into uint8) uint8 { + var right uint8 + if ix == 1 { + right = v + } else { + right = v << BitsPerNibble + } + return into | right +} + +func (nb Nibbles) Clone() Nibbles { + return Nibbles{ + data: slices.Clone(nb.data), + offset: nb.offset, + } +} + +// Get [Prefix] representation of the inner data. +// +// This means the entire inner data will be returned as [Prefix], ignoring any offset. 
+func (nb Nibbles) OriginalDataPrefix() Prefix {
+	return Prefix{
+		Key: nb.data, // raw view, offset deliberately ignored
+	}
+}
+
+func (nb Nibbles) Compare(other Nibbles) int {
+	s := nb.Len()
+	if other.Len() < s {
+		s = other.Len()
+	}
+
+	for i := uint(0); i < s; i++ { // lexicographic, nibble by nibble
+		nbAt := nb.At(i)
+		otherAt := other.At(i)
+		if nbAt < otherAt {
+			return -1
+		} else if nbAt > otherAt {
+			return 1
+		}
+	}
+
+	if nb.Len() < other.Len() { // equal prefix: shorter sorts first
+		return -1
+	} else if nb.Len() > other.Len() {
+		return 1
+	} else {
+		return 0
+	}
+}
+
+// Partial node key type: offset and value.
+// Offset is applied on first byte of array (bytes are right aligned).
+type NodeKey struct {
+	Offset uint
+	Data   []byte
+}
+
+// Shifts right aligned key to add a given left offset.
+// Resulting in possibly padding at both left and right
+// (example usage when combining two keys).
+func (nk *NodeKey) ShiftKey(offset uint) bool {
+	oldOffset := nk.Offset
+	nk.Offset = offset
+	if oldOffset > offset {
+		// shift left
+		kl := len(nk.Data)
+		for i := 0; i < kl-1; i++ {
+			nk.Data[i] = nk.Data[i]<<4 | nk.Data[i+1]>>4 // pull next high nibble in
+		}
+		nk.Data[kl-1] = nk.Data[kl-1] << 4
+		return true
+	} else if oldOffset < offset {
+		// shift right
+		nk.Data = append(nk.Data, 0) // room for the spilled nibble
+		for i := len(nk.Data) - 1; i >= 1; i-- {
+			nk.Data[i] = nk.Data[i-1]<<4 | nk.Data[i]>>4
+		}
+		nk.Data[0] = nk.Data[0] >> 4
+		return true
+	}
+	return false // offsets equal: nothing to do
+}
diff --git a/pkg/trie/triedb/nibbles/nibbles_test.go b/pkg/trie/triedb/nibbles/nibbles_test.go
new file mode 100644
index 0000000000..9bbd4434b2
--- /dev/null
+++ b/pkg/trie/triedb/nibbles/nibbles_test.go
@@ -0,0 +1,92 @@
+// Copyright 2024 ChainSafe Systems (ON)
+// SPDX-License-Identifier: LGPL-3.0-only
+
+package nibbles
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNibbles(t *testing.T) {
+	data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
+	nibbles := NewNibbles(data)
+	for i := 0; i < (2 * len(data)); i++ {
+		assert.Equal(t, uint8(i), nibbles.At(uint(i)))
+	}
+}
+
+func TestNibbles_MidAndLeft(t *testing.T) {
+	n := NewNibbles([]byte{0x01, 0x23, 0x45})
+	m := n.Mid(2)
+	for i := 0; i < 4; i++ {
+		assert.Equal(t, m.At(uint(i)), uint8(i+2))
+	}
+	assert.Equal(t, Prefix{Key: []byte{0x01}}, m.Left())
+	m = n.Mid(3)
+	for i := 0; i < 3; i++ {
+		assert.Equal(t, m.At(uint(i)), uint8(i+3))
+	}
+	padded := uint8(0x23 &^ 0x0F)
+	assert.Equal(t, Prefix{Key: []byte{0x01}, Padded: &padded}, m.Left())
+}
+
+func TestNibbles_Right(t *testing.T) {
+	data := []uint8{1, 2, 3, 4, 5, 234, 78, 99}
+	nibbles := NewNibbles(data)
+	assert.Equal(t, data, nibbles.Right())
+
+	nibbles = NewNibbles(data, 3)
+	assert.Equal(t, data[1:], nibbles.Right())
+}
+
+func TestNibbles_CommonPrefix(t *testing.T) {
+	n := NewNibbles([]byte{0x01, 0x23, 0x45})
+
+	other := []byte{0x01, 0x23, 0x01, 0x23, 0x45, 0x67}
+	m := NewNibbles(other)
+
+	assert.Equal(t, uint(4), n.CommonPrefix(m))
+	assert.Equal(t, uint(4), m.CommonPrefix(n))
+	assert.Equal(t, uint(3), n.Mid(1).CommonPrefix(m.Mid(1)))
+	assert.Equal(t, uint(0), n.Mid(1).CommonPrefix(m.Mid(2)))
+	assert.Equal(t, uint(6), n.CommonPrefix(m.Mid(4)))
+	assert.False(t, n.StartsWith(m.Mid(4)))
+	assert.True(t, m.Mid(4).StartsWith(n))
+}
+
+func TestNibbles_Compare(t *testing.T) {
+	n := NewNibbles([]byte{1, 35})
+	m := NewNibbles([]byte{1})
+
+	assert.Equal(t, -1, m.Compare(n))
+	assert.Equal(t, 1, n.Compare(m))
+
+	n = NewNibbles([]byte{1, 35})
+	m = NewNibbles([]byte{1, 35})
+
+	assert.Equal(t, 0, m.Compare(n))
+
+	n = NewNibbles([]byte{1, 35})
+	m = NewNibbles([]byte{3, 35})
+	assert.Equal(t, -1, n.Compare(m))
+	assert.Equal(t, 1, m.Compare(n))
+}
+
+func TestNibbles_Advance(t *testing.T) {
+	n := NewNibbles([]byte{1, 35})
+	n.Advance(1)
+	n.Advance(1)
+	n.Advance(1)
+	n.Advance(1)
+	require.Panics(t, func() { n.Advance(1) })
+
+	n = NewNibbles([]byte{1, 35})
+	require.Panics(t, func() { n.Advance(5) })
+
+	n = NewNibbles([]byte{1, 35})
+	n.Advance(4)
+	require.Panics(t, func() { n.Advance(1) })
+}
diff --git a/pkg/trie/triedb/nibbles/nibbleslice.go b/pkg/trie/triedb/nibbles/nibbleslice.go
new file mode 100644
index 0000000000..b9c35f2425
--- /dev/null
+++ b/pkg/trie/triedb/nibbles/nibbleslice.go
@@ -0,0 +1,161 @@
+// Copyright 2024 ChainSafe Systems (ON)
+// SPDX-License-Identifier: LGPL-3.0-only
+
+package nibbles
+
+import "slices"
+
+// NOTE: This is to facilitate easier truncation, and to reference odd number
+// lengths omitting the last nibble of the last byte.
+type NibbleSlice struct {
+	inner []byte
+	len   uint
+}
+
+// Construct a new [NibbleSlice].
+func NewNibbleSlice() NibbleSlice {
+	return NibbleSlice{
+		inner: make([]byte, 0),
+	}
+}
+
+// Construct a new [NibbleSlice] from [Nibbles].
+func NewNibbleSliceFromNibbles(s Nibbles) NibbleSlice {
+	v := NewNibbleSlice()
+	for i := uint(0); i < s.Len(); i++ {
+		v.Push(s.At(i))
+	}
+	return v
+}
+
+// Returns true if [NibbleSlice] has zero length.
+func (n NibbleSlice) IsEmpty() bool {
+	return n.len == 0
+}
+
+// Try to get the nibble at the given offset.
+func (n NibbleSlice) At(idx uint) uint8 {
+	ix := idx / NibblesPerByte
+	pad := idx % NibblesPerByte
+	return atLeft(uint8(pad), n.inner[ix]) //nolint:gosec
+}
+
+// Push a nibble onto the [NibbleSlice]. Ignores the high 4 bits.
+func (n *NibbleSlice) Push(nibble uint8) {
+	i := n.len % NibblesPerByte
+	if i == 0 { // current last byte is full: start a new one
+		n.inner = append(n.inner, PushAtLeft(0, nibble, 0))
+	} else {
+		output := n.inner[len(n.inner)-1]
+		n.inner[len(n.inner)-1] = PushAtLeft(uint8(i), nibble, output) //nolint:gosec
+	}
+	n.len++
+}
+
+// Try to pop a nibble off the NibbleSlice. Fails if len == 0.
+func (n *NibbleSlice) Pop() *uint8 {
+	if n.IsEmpty() {
+		return nil
+	}
+	b := n.inner[len(n.inner)-1]
+	n.inner = n.inner[:len(n.inner)-1]
+	n.len -= 1
+	iNew := n.len % NibblesPerByte
+	if iNew != 0 { // last byte still half-used: keep it, padded
+		n.inner = append(n.inner, PadLeft(b))
+	}
+	popped := atLeft(uint8(iNew), b)
+	return &popped
+}
+
+// Append a [Partial]. Can be slow (alignment of partial).
+func (n *NibbleSlice) AppendPartial(p Partial) {
+	if p.First == 1 { // leading dangling nibble of the partial
+		n.Push(atLeft(1, p.PaddedNibble))
+	}
+	pad := uint(len(n.inner))*NibblesPerByte - n.len
+	if pad == 0 { // byte-aligned: bytes can be copied through
+		n.inner = append(n.inner, p.Data...)
+	} else {
+		kend := uint(len(n.inner)) - 1
+		if len(p.Data) > 0 {
+			n.inner[kend] = PadLeft(n.inner[kend])
+			n.inner[kend] |= p.Data[0] >> 4
+			for i := 0; i < len(p.Data)-1; i++ {
+				n.inner = append(n.inner, p.Data[i]<<4|p.Data[i+1]>>4)
+			}
+			n.inner = append(n.inner, p.Data[len(p.Data)-1]<<4)
+		}
+	}
+	n.len += uint(len(p.Data)) * NibblesPerByte
+}
+
+// Utility function for chaining two optional appending
+// of [NibbleSlice] and/or a byte.
+// Can be slow.
+func (n *NibbleSlice) AppendOptionalSliceAndNibble(oSlice *Nibbles, oIndex *uint8) uint {
+	var res uint
+	if oSlice != nil {
+		n.AppendPartial(oSlice.RightPartial())
+		res += oSlice.Len()
+	}
+	if oIndex != nil {
+		n.Push(*oIndex)
+		res += 1
+	}
+	return res
+}
+
+// Get Prefix representation of this [NibbleSlice].
+func (n NibbleSlice) Prefix() Prefix {
+	split := n.len / NibblesPerByte
+	pos := uint8(n.len % NibblesPerByte) //nolint:gosec
+	if pos == 0 {
+		return Prefix{
+			Key: n.inner[:split],
+		}
+	} else {
+		padded := PadLeft(n.inner[split])
+		return Prefix{
+			Key:    n.inner[:split],
+			Padded: &padded,
+		}
+	}
+}
+
+func (n *NibbleSlice) Clear() {
+	n.inner = make([]byte, 0)
+	n.len = 0
+}
+
+// Remove the last num nibbles in a faster way than popping num times.
+func (n *NibbleSlice) DropLasts(num uint) {
+	if num == 0 {
+		return
+	}
+	if num >= n.len {
+		n.Clear()
+		return
+	}
+	end := n.len - num
+	endIndex := end / NibblesPerByte
+	if end%NibblesPerByte != 0 {
+		endIndex++ // keep the byte holding the dangling nibble
+	}
+	// Truncate directly to the byte count needed for `end` nibbles.
+	// (The previous pop-one-byte-per-iteration loop incremented the wrong
+	// variable and only terminated because len(n.inner) shrank each pass.)
+	n.inner = n.inner[:endIndex]
+	n.len = end
+	pos := n.len % NibblesPerByte
+	if pos != 0 {
+		// Zero the now-unused low nibble of the last byte.
+		kl := len(n.inner) - 1
+		n.inner[kl] = PadLeft(n.inner[kl])
+	}
+}
+
+func (n NibbleSlice) Clone() NibbleSlice {
+	return NibbleSlice{
+		inner: slices.Clone(n.inner),
+		len:   n.len,
+	}
+}
diff --git a/pkg/trie/triedb/nibbles/nibbleslice_test.go b/pkg/trie/triedb/nibbles/nibbleslice_test.go
new file mode 100644
index 0000000000..fb5c7c0421
--- /dev/null
+++ b/pkg/trie/triedb/nibbles/nibbleslice_test.go
@@ -0,0 +1,54 @@
+// Copyright 2024 ChainSafe Systems (ON)
+// SPDX-License-Identifier: LGPL-3.0-only
+
+package nibbles
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNibbleSlice_Push(t *testing.T) {
+	v := NewNibbleSlice()
+	for i := uint(0); i < NibblesPerByte*3; i++ {
+		iu8 := uint8(i % NibblesPerByte)
+		v.Push(iu8)
+		assert.Equal(t, i, v.len-1)
+		assert.Equal(t, v.At(i), iu8)
+	}
+
+	for i := int(NibblesPerByte*3) - 1; i >= 0; i-- {
+		iu8 := uint8(uint(i) % NibblesPerByte)
+		a := v.Pop()
+		assert.NotNil(t, a)
+		assert.Equal(t, iu8, *a)
+		assert.Equal(t, v.len, uint(i))
+	}
+}
+
+func TestNibbleSlice_AppendPartial(t *testing.T) {
+	t.Run("", func(t *testing.T) {
+		appendPartial(t, []byte{1, 2, 3}, []byte{}, Partial{First: 1, PaddedNibble: 1, Data: []byte{0x23}})
+	})
+	t.Run("", func(t *testing.T) {
+		appendPartial(t, []byte{1, 2, 3}, []byte{1}, Partial{First: 0, PaddedNibble: 0, Data: []byte{0x23}})
+	})
+	t.Run("", func(t *testing.T) {
+		appendPartial(t, []byte{0, 1, 2, 3}, []byte{0}, Partial{First: 1, PaddedNibble: 1, Data: []byte{0x23}})
+	})
+}
+
+func appendPartial(t *testing.T, res []uint8, init []uint8, partial Partial) {
+	t.Helper()
+	resv := 
NewNibbleSlice() + for _, r := range res { + resv.Push(r) + } + initv := NewNibbleSlice() + for _, r := range init { + initv.Push(r) + } + initv.AppendPartial(partial) + assert.Equal(t, resv, initv) +} diff --git a/pkg/trie/triedb/node.go b/pkg/trie/triedb/node.go index ccf9556eab..e9ad1808b4 100644 --- a/pkg/trie/triedb/node.go +++ b/pkg/trie/triedb/node.go @@ -8,128 +8,135 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" - nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) -type nodeValue interface { - getHash() common.Hash - equal(other nodeValue) bool +type nodeValue[H hash.Hash] interface { + equal(other nodeValue[H]) bool } type ( // inline is an inlined value representation - inline []byte + inline[H hash.Hash] []byte // valueRef is a reference to a value stored in the db - valueRef common.Hash + valueRef[H hash.Hash] struct { + hash H + } // newValueRef is a value that will be stored in the db - newValueRef struct { - hash common.Hash + newValueRef[H hash.Hash] struct { + hash H data []byte } ) // newEncodedValue creates an EncodedValue from a nodeValue -func newEncodedValue(value nodeValue, partial []byte, childF onChildStoreFn) (codec.EncodedValue, error) { +func newEncodedValue[H hash.Hash]( + value nodeValue[H], partial *nibbles.Nibbles, childF onChildStoreFn, +) (codec.EncodedValue, error) { switch v := value.(type) { - case inline: + case inline[H]: return codec.InlineValue(v), nil - case valueRef: - return codec.HashedValue(v), nil - case newValueRef: + case valueRef[H]: + return codec.HashedValue[H]{Hash: v.hash}, nil + case newValueRef[H]: // Store value in db - childRef, err := childF(newNodeToEncode{partialKey: partial, value: v.data}, partial, nil) + childRef, err := 
childF(newNodeToEncode{value: v.data}, partial, nil) if err != nil { return nil, err } // Check and get new new value hash switch cr := childRef.(type) { - case HashChildReference: - if common.Hash(cr) == common.EmptyHash { + case HashChildReference[H]: + empty := *new(H) + if cr.Hash == empty { panic("new external value are always added before encoding a node") } - if v.hash != common.EmptyHash { - if v.hash != common.Hash(cr) { + if v.hash != empty { + if v.hash != cr.Hash { panic("hash mismatch") } } else { - v.hash = common.Hash(cr) + v.hash = cr.Hash } default: panic("value node can never be inlined") } - return codec.HashedValue(v.hash), nil + return codec.HashedValue[H]{Hash: v.hash}, nil default: panic("unreachable") } } -func (inline) getHash() common.Hash { return common.EmptyHash } -func (n inline) equal(other nodeValue) bool { +func (n inline[H]) equal(other nodeValue[H]) bool { switch otherValue := other.(type) { - case inline: + case inline[H]: return bytes.Equal(n, otherValue) default: return false } } -func (vr valueRef) getHash() common.Hash { return common.Hash(vr) } -func (vr valueRef) equal(other nodeValue) bool { + +func (vr valueRef[H]) getHash() H { return vr.hash } +func (vr valueRef[H]) equal(other nodeValue[H]) bool { switch otherValue := other.(type) { - case valueRef: - return vr == otherValue + case valueRef[H]: + return vr.hash == otherValue.hash default: return false } } -func (vr newValueRef) getHash() common.Hash { +func (vr newValueRef[H]) getHash() H { return vr.hash } -func (vr newValueRef) equal(other nodeValue) bool { +func (vr newValueRef[H]) equal(other nodeValue[H]) bool { switch otherValue := other.(type) { - case newValueRef: + case newValueRef[H]: return vr.hash == otherValue.hash default: return false } } -func NewValue(data []byte, threshold int) nodeValue { +func NewValue[H hash.Hash](data []byte, threshold int) nodeValue[H] { if len(data) >= threshold { - return newValueRef{data: data} + return newValueRef[H]{ + hash: 
*new(H), + data: data, + } } - return inline(data) + return inline[H](data) } -func NewValueFromEncoded(encodedValue codec.EncodedValue) nodeValue { +func NewValueFromEncoded[H hash.Hash](encodedValue codec.EncodedValue) nodeValue[H] { switch v := encodedValue.(type) { case codec.InlineValue: - return inline(v) - case codec.HashedValue: - return valueRef(v) + return inline[H](v) + case codec.HashedValue[H]: + return valueRef[H]{v.Hash} } return nil } -func inMemoryFetchedValue(value nodeValue, prefix []byte, db db.DBGetter) ([]byte, error) { +func inMemoryFetchedValue[H hash.Hash](value nodeValue[H], prefix []byte, db db.DBGetter) ([]byte, error) { switch v := value.(type) { - case inline: + case inline[H]: return v, nil - case newValueRef: + case newValueRef[H]: return v.data, nil - case valueRef: - prefixedKey := bytes.Join([][]byte{prefix, v[:]}, nil) + case valueRef[H]: + prefixedKey := bytes.Join([][]byte{prefix, v.hash.Bytes()}, nil) value, err := db.Get(prefixedKey) if err != nil { return nil, err @@ -143,32 +150,38 @@ func inMemoryFetchedValue(value nodeValue, prefix []byte, db db.DBGetter) ([]byt } } +type NodeTypes[H hash.Hash] interface { + Empty | Leaf[H] | Branch[H] + Node +} type Node interface { - getPartialKey() []byte + getPartialKey() *nodeKey } +type nodeKey = nibbles.NodeKey + type ( - Empty struct{} - Leaf struct { - partialKey []byte - value nodeValue + Empty struct{} + Leaf[H hash.Hash] struct { + partialKey nodeKey + value nodeValue[H] } - Branch struct { - partialKey []byte + Branch[H hash.Hash] struct { + partialKey nodeKey children [codec.ChildrenCapacity]NodeHandle - value nodeValue + value nodeValue[H] } ) -func (Empty) getPartialKey() []byte { return nil } -func (n Leaf) getPartialKey() []byte { return n.partialKey } -func (n Branch) getPartialKey() []byte { return n.partialKey } +func (Empty) getPartialKey() *nodeKey { return nil } +func (n Leaf[H]) getPartialKey() *nodeKey { return &n.partialKey } +func (n Branch[H]) getPartialKey() 
*nodeKey { return &n.partialKey } // Create a new node from the encoded data, decoding this data into a codec.Node // and mapping that with this node type -func newNodeFromEncoded(nodeHash common.Hash, data []byte, storage nodeStorage) (Node, error) { +func newNodeFromEncoded[H hash.Hash](nodeHash H, data []byte, storage nodeStorage[H]) (Node, error) { reader := bytes.NewReader(data) - encodedNode, err := codec.Decode(reader) + encodedNode, err := codec.Decode[H](reader) if err != nil { return nil, err } @@ -177,7 +190,10 @@ func newNodeFromEncoded(nodeHash common.Hash, data []byte, storage nodeStorage) case codec.Empty: return Empty{}, nil case codec.Leaf: - return Leaf{partialKey: encoded.PartialKey, value: NewValueFromEncoded(encoded.Value)}, nil + return Leaf[H]{ + partialKey: encoded.PartialKey.NodeKey(), + value: NewValueFromEncoded[H](encoded.Value), + }, nil case codec.Branch: key := encoded.PartialKey encodedChildren := encoded.Children @@ -185,7 +201,7 @@ func newNodeFromEncoded(nodeHash common.Hash, data []byte, storage nodeStorage) child := func(i int) (NodeHandle, error) { if encodedChildren[i] != nil { - newChild, err := newFromEncodedMerkleValue(nodeHash, encodedChildren[i], storage) + newChild, err := newFromEncodedMerkleValue[H](nodeHash, encodedChildren[i], storage) if err != nil { return nil, err } @@ -203,7 +219,7 @@ func newNodeFromEncoded(nodeHash common.Hash, data []byte, storage nodeStorage) children[i] = child } - return Branch{partialKey: key, children: children, value: NewValueFromEncoded(value)}, nil + return Branch[H]{partialKey: key.NodeKey(), children: children, value: NewValueFromEncoded[H](value)}, nil default: panic("unreachable") } @@ -215,8 +231,7 @@ type nodeToEncode interface { type ( newNodeToEncode struct { - partialKey []byte - value []byte + value []byte } trieNodeToEncode struct { child NodeHandle @@ -236,44 +251,46 @@ type ChildReference interface { type ( // HashChildReference is a reference to a child node that is not 
inlined - HashChildReference common.Hash + HashChildReference[H hash.Hash] struct{ Hash H } // InlineChildReference is a reference to an inlined child node InlineChildReference []byte ) -func (h HashChildReference) getNodeData() []byte { - return h[:] +func (h HashChildReference[H]) getNodeData() []byte { + return h.Hash.Bytes() } func (i InlineChildReference) getNodeData() []byte { return i } -type onChildStoreFn = func(node nodeToEncode, partialKey []byte, childIndex *byte) (ChildReference, error) +type onChildStoreFn = func(node nodeToEncode, partialKey *nibbles.Nibbles, childIndex *byte) (ChildReference, error) const EmptyTrieBytes = byte(0) // newEncodedNode creates a new encoded node from a node and a child store function and return its bytes -func newEncodedNode(node Node, childF onChildStoreFn) (encodedNode []byte, err error) { +func newEncodedNode[H hash.Hash](node Node, childF onChildStoreFn) (encodedNode []byte, err error) { encodingBuffer := bytes.NewBuffer(nil) switch n := node.(type) { case Empty: return []byte{EmptyTrieBytes}, nil - case Leaf: - pr := n.partialKey - value, err := newEncodedValue(n.value, pr, childF) + case Leaf[H]: + partialKey := nibbles.NewNibbles(n.partialKey.Data, n.partialKey.Offset) + value, err := newEncodedValue[H](n.value, &partialKey, childF) if err != nil { return nil, err } - - err = NewEncodedLeaf(pr, value, encodingBuffer) + right := partialKey.Right() + len := partialKey.Len() + err = NewEncodedLeaf(right, len, value, encodingBuffer) if err != nil { return nil, err } - case Branch: + case Branch[H]: + partialKey := nibbles.NewNibbles(n.partialKey.Data, n.partialKey.Offset) var value codec.EncodedValue if n.value != nil { - value, err = newEncodedValue(n.value, n.partialKey, childF) + value, err = newEncodedValue[H](n.value, &partialKey, childF) if err != nil { return nil, err } @@ -286,13 +303,14 @@ func newEncodedNode(node Node, childF onChildStoreFn) (encodedNode []byte, err e } childIndex := byte(i) - children[i], 
err = childF(trieNodeToEncode{child}, n.partialKey, &childIndex) + + children[i], err = childF(trieNodeToEncode{child}, &partialKey, &childIndex) if err != nil { return nil, err } } - err := NewEncodedBranch(n.partialKey, children, value, encodingBuffer) + err := NewEncodedBranch(partialKey.Right(), partialKey.Len(), children, value, encodingBuffer) if err != nil { return nil, err } @@ -304,29 +322,22 @@ func newEncodedNode(node Node, childF onChildStoreFn) (encodedNode []byte, err e } // NewEncodedLeaf creates a new encoded leaf node and writes it to the writer -func NewEncodedLeaf(partialKey []byte, value codec.EncodedValue, writer io.Writer) error { +func NewEncodedLeaf(partialKey []byte, numberNibble uint, value codec.EncodedValue, writer io.Writer) error { // Write encoded header if value.IsHashed() { - err := codec.EncodeHeader(partialKey, codec.LeafWithHashedValue, writer) + err := codec.EncodeHeader(partialKey, numberNibble, codec.LeafWithHashedValue, writer) if err != nil { return fmt.Errorf("encoding header for leaf with hashed value: %w", err) } } else { - err := codec.EncodeHeader(partialKey, codec.LeafNode, writer) + err := codec.EncodeHeader(partialKey, numberNibble, codec.LeafNode, writer) if err != nil { return fmt.Errorf("encoding header for leaf node value: %w", err) } } - // Write partial key - keyLE := nibbles.NibblesToKeyLE(partialKey) - _, err := writer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write LE key to buffer: %w", err) - } - // Write encoded value - err = value.Write(writer) + err := value.Write(writer) if err != nil { return fmt.Errorf("writing leaf value: %w", err) } @@ -336,35 +347,29 @@ func NewEncodedLeaf(partialKey []byte, value codec.EncodedValue, writer io.Write // NewEncodedBranch creates a new encoded branch node and writes it to the writer func NewEncodedBranch( partialKey []byte, + numberNibbles uint, children [codec.ChildrenCapacity]ChildReference, value codec.EncodedValue, writer io.Writer, ) error { // 
Write encoded header if value == nil { - err := codec.EncodeHeader(partialKey, codec.BranchWithoutValue, writer) + err := codec.EncodeHeader(partialKey, numberNibbles, codec.BranchWithoutValue, writer) if err != nil { return fmt.Errorf("encoding header for branch without value: %w", err) } } else if value.IsHashed() { - err := codec.EncodeHeader(partialKey, codec.BranchWithHashedValue, writer) + err := codec.EncodeHeader(partialKey, numberNibbles, codec.BranchWithHashedValue, writer) if err != nil { return fmt.Errorf("encoding header for branch with hashed value: %w", err) } } else { - err := codec.EncodeHeader(partialKey, codec.BranchWithValue, writer) + err := codec.EncodeHeader(partialKey, numberNibbles, codec.BranchWithValue, writer) if err != nil { return fmt.Errorf("encoding header for branch with value: %w", err) } } - // Write partial key - keyLE := nibbles.NibblesToKeyLE(partialKey) - _, err := writer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write LE key to buffer: %w", err) - } - // Write bitmap var bitmap uint16 for i := range children { @@ -373,8 +378,8 @@ func NewEncodedBranch( } bitmap |= 1 << uint(i) } - childrenBitmap := common.Uint16ToBytes(bitmap) - _, err = writer.Write(childrenBitmap) + encoder := scale.NewEncoder(writer) + err := encoder.Encode(bitmap) if err != nil { return fmt.Errorf("writing branch bitmap: %w", err) } diff --git a/pkg/trie/triedb/node_storage.go b/pkg/trie/triedb/node_storage.go index c6deb613bd..f3faced47d 100644 --- a/pkg/trie/triedb/node_storage.go +++ b/pkg/trie/triedb/node_storage.go @@ -4,13 +4,12 @@ package triedb import ( - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" "github.com/gammazero/deque" ) var EmptyNode = []byte{0} -var hashedNullNode = common.MustBlake2bHash(EmptyNode) // StorageHandle is a pointer to a node contained in `NodeStorage` type storageHandle int @@ -24,21 +23,21 @@ type 
NodeHandle interface { } type ( - inMemory storageHandle - persisted common.Hash + inMemory storageHandle + persisted[H any] struct{ hash H } ) -func (inMemory) isNodeHandle() {} -func (persisted) isNodeHandle() {} +func (inMemory) isNodeHandle() {} +func (persisted[H]) isNodeHandle() {} -func newFromEncodedMerkleValue( - parentHash common.Hash, +func newFromEncodedMerkleValue[H hash.Hash]( + parentHash H, encodedNodeHandle codec.MerkleValue, - storage nodeStorage, + storage nodeStorage[H], ) (NodeHandle, error) { switch encoded := encodedNodeHandle.(type) { - case codec.HashedNode: - return persisted(encoded), nil + case codec.HashedNode[H]: + return persisted[H]{encoded.Hash}, nil case codec.InlineNode: child, err := newNodeFromEncoded(parentHash, encoded, storage) if err != nil { @@ -63,34 +62,34 @@ type ( NewStoredNode struct { node Node } - CachedStoredNode struct { + CachedStoredNode[H any] struct { node Node - hash common.Hash + hash H } ) func (n NewStoredNode) getNode() Node { return n.node } -func (n CachedStoredNode) getNode() Node { +func (n CachedStoredNode[H]) getNode() Node { return n.node } // nodeStorage is a struct that contains all the temporal nodes that are stored // in the trieDB before being written to the backed db -type nodeStorage struct { +type nodeStorage[H any] struct { nodes []StoredNode freeIndices *deque.Deque[int] } -func newNodeStorage() nodeStorage { - return nodeStorage{ +func newNodeStorage[H any]() nodeStorage[H] { + return nodeStorage[H]{ nodes: make([]StoredNode, 0), freeIndices: deque.New[int](0), } } -func (ns *nodeStorage) alloc(stored StoredNode) storageHandle { +func (ns *nodeStorage[H]) alloc(stored StoredNode) storageHandle { if ns.freeIndices.Len() > 0 { idx := ns.freeIndices.PopFront() ns.nodes[idx] = stored @@ -101,7 +100,7 @@ func (ns *nodeStorage) alloc(stored StoredNode) storageHandle { return storageHandle(len(ns.nodes) - 1) } -func (ns *nodeStorage) destroy(handle storageHandle) StoredNode { +func (ns 
*nodeStorage[H]) destroy(handle storageHandle) StoredNode { idx := int(handle) ns.freeIndices.PushBack(idx) oldNode := ns.nodes[idx] @@ -110,11 +109,11 @@ func (ns *nodeStorage) destroy(handle storageHandle) StoredNode { return oldNode } -func (ns *nodeStorage) get(handle storageHandle) Node { +func (ns *nodeStorage[H]) get(handle storageHandle) Node { switch n := ns.nodes[handle].(type) { case NewStoredNode: return n.node - case CachedStoredNode: + case CachedStoredNode[H]: return n.node default: panic("unreachable") diff --git a/pkg/trie/triedb/print.go b/pkg/trie/triedb/print.go index cba5d31506..dfc5e45374 100644 --- a/pkg/trie/triedb/print.go +++ b/pkg/trie/triedb/print.go @@ -5,12 +5,10 @@ package triedb import ( "fmt" - - "github.com/ChainSafe/gossamer/lib/common" ) -func (t *TrieDB) String() string { - if t.rootHash == common.EmptyHash { +func (t *TrieDB[H, Hasher]) String() string { + if t.rootHash == (*new(H)) { return "empty" } diff --git a/pkg/trie/triedb/proof/generate.go b/pkg/trie/triedb/proof/generate.go index 39d414f9e4..946fba51fb 100644 --- a/pkg/trie/triedb/proof/generate.go +++ b/pkg/trie/triedb/proof/generate.go @@ -6,9 +6,10 @@ package proof import ( "bytes" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie/triedb" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/gammazero/deque" "golang.org/x/exp/slices" ) @@ -18,12 +19,14 @@ type nodeHandle interface { } type ( - nodeHandleHash common.Hash + nodeHandleHash[H any] struct { + hash H + } nodeHandleInline []byte ) -func (nodeHandleHash) isNodeHandle() {} -func (nodeHandleInline) isNodeHandle() {} +func (nodeHandleHash[H]) isNodeHandle() {} +func (nodeHandleInline) isNodeHandle() {} type genProofStep interface { isProofStep() @@ -31,7 +34,7 @@ type genProofStep interface { type ( genProofStepDescend struct { - childPrefixLen int + 
childPrefixLen uint child nodeHandle } genProofStepFoundValue struct { @@ -46,36 +49,36 @@ func (genProofStepDescend) isProofStep() {} func (genProofStepFoundValue) isProofStep() {} func (genProofStepFoundHashedValue) isProofStep() {} -type genProofStackEntry struct { +type genProofStackEntry[H hash.Hash] struct { // prefix is the nibble path to the node in the trie - prefix []byte + prefix nibbles.LeftNibbles // node is the stacked node node codec.EncodedNode // encodedNode is the encoded node data encodedNode []byte // nodeHash of the node or nil if the node is inlined - nodeHash *common.Hash + nodeHash []byte // omitValue is a flag to know if the value should be omitted in the generated proof omitValue bool // childIndex is used for branch nodes - childIndex int + childIndex uint // children contains the child references to use in constructing the proof nodes. children triedb.ChildReferences // outputIndex is the index into the proof vector that the encoding of this entry should be placed at. 
- outputIndex *int + outputIndex *uint } -func newGenProofStackEntry( - prefix []byte, +func newGenProofStackEntry[H hash.Hash]( + prefix nibbles.LeftNibbles, nodeData []byte, - nodeHash *common.Hash, - outputIndex *int) (*genProofStackEntry, error) { - node, err := codec.Decode(bytes.NewReader(nodeData)) + nodeHash []byte, + outputIndex *uint) (*genProofStackEntry[H], error) { + node, err := codec.Decode[H](bytes.NewReader(nodeData)) if err != nil { return nil, err } - return &genProofStackEntry{ + return &genProofStackEntry[H]{ prefix: prefix, node: node, encodedNode: nodeData, @@ -87,7 +90,7 @@ func newGenProofStackEntry( }, nil } -func (e *genProofStackEntry) encodeNode() ([]byte, error) { +func (e *genProofStackEntry[H]) encodeNode() ([]byte, error) { switch n := e.node.(type) { case codec.Empty: return e.encodedNode, nil @@ -97,7 +100,7 @@ func (e *genProofStackEntry) encodeNode() ([]byte, error) { } encodingBuffer := bytes.NewBuffer(nil) - err := triedb.NewEncodedLeaf(e.node.GetPartialKey(), codec.InlineValue{}, encodingBuffer) + err := triedb.NewEncodedLeaf(n.PartialKey.Right(), n.PartialKey.Len(), codec.InlineValue{}, encodingBuffer) if err != nil { return nil, err } @@ -110,7 +113,7 @@ func (e *genProofStackEntry) encodeNode() ([]byte, error) { } e.completBranchChildren(n.Children, e.childIndex) encodingBuffer := bytes.NewBuffer(nil) - err := triedb.NewEncodedBranch(e.node.GetPartialKey(), e.children, value, encodingBuffer) + err := triedb.NewEncodedBranch(n.PartialKey.Right(), n.PartialKey.Len(), e.children, value, encodingBuffer) if err != nil { return nil, err } @@ -120,7 +123,7 @@ func (e *genProofStackEntry) encodeNode() ([]byte, error) { } } -func (e *genProofStackEntry) setChild(encodedChild []byte) { +func (e *genProofStackEntry[H]) setChild(encodedChild []byte) { var childRef triedb.ChildReference switch n := e.node.(type) { case codec.Branch: @@ -137,23 +140,23 @@ func (e *genProofStackEntry) setChild(encodedChild []byte) { e.childIndex++ } 
-func (e *genProofStackEntry) completBranchChildren( +func (e *genProofStackEntry[H]) completBranchChildren( childHandles [codec.ChildrenCapacity]codec.MerkleValue, - childIndex int, + childIndex uint, ) { for i := childIndex; i < codec.ChildrenCapacity; i++ { switch n := childHandles[i].(type) { case codec.InlineNode: e.children[i] = triedb.InlineChildReference(n) - case codec.HashedNode: - e.children[i] = triedb.HashChildReference(common.Hash(n)) + case codec.HashedNode[H]: + e.children[i] = triedb.HashChildReference[H](n) } } } -func (e *genProofStackEntry) replaceChildRef(encodedChild []byte, child codec.MerkleValue) triedb.ChildReference { +func (e *genProofStackEntry[H]) replaceChildRef(encodedChild []byte, child codec.MerkleValue) triedb.ChildReference { switch child.(type) { - case codec.HashedNode: + case codec.HashedNode[H]: return triedb.InlineChildReference(nil) case codec.InlineNode: return triedb.InlineChildReference(encodedChild) @@ -162,17 +165,17 @@ func (e *genProofStackEntry) replaceChildRef(encodedChild []byte, child codec.Me } } -// / Unwind the stack until the given key is prefixed by the entry at the top of the stack. If the -// / key is nil, unwind the stack completely. As entries are popped from the stack, they are -// / encoded into proof nodes and added to the finalized proof. -func unwindStack( - stack *deque.Deque[*genProofStackEntry], +// Unwind the stack until the given key is prefixed by the entry at the top of the stack. If the +// key is nil, unwind the stack completely. As entries are popped from the stack, they are +// encoded into proof nodes and added to the finalized proof. 
+func unwindStack[H hash.Hash]( + stack *deque.Deque[*genProofStackEntry[H]], proofNodes [][]byte, - maybeKey *[]byte, + maybeKey *nibbles.LeftNibbles, ) error { for stack.Len() > 0 { entry := stack.PopBack() - if maybeKey != nil && bytes.HasPrefix(*maybeKey, entry.prefix) { + if maybeKey != nil && maybeKey.StartsWith(entry.prefix) { stack.PushBack(entry) break } @@ -206,25 +209,25 @@ func sortAndDeduplicateKeys(keys []string) []string { return deduplicatedkeys } -func genProofMatchKeyToNode( +func genProofMatchKeyToNode[H hash.Hash]( node codec.EncodedNode, omitValue *bool, - childIndex *int, - key []byte, - prefixlen int, - recordedNodes *Iterator[triedb.Record], + childIndex *uint, + key nibbles.LeftNibbles, + prefixlen uint, + recordedNodes *Iterator[triedb.Record[H]], ) (genProofStep, error) { switch n := node.(type) { case codec.Empty: return genProofStepFoundValue{nil}, nil case codec.Leaf: - if bytes.Contains(key, n.PartialKey) && len(key) == prefixlen+len(n.PartialKey) { + if key.Contains(n.PartialKey, prefixlen) && key.Len() == prefixlen+n.PartialKey.Len() { switch v := n.Value.(type) { case codec.InlineValue: *omitValue = true value := []byte(v) return genProofStepFoundValue{&value}, nil - case codec.HashedValue: + case codec.HashedValue[H]: *omitValue = true return resolveValue(recordedNodes) } @@ -246,21 +249,21 @@ func genProofMatchKeyToNode( } } -func genProofMatchKeyToBranchNode( +func genProofMatchKeyToBranchNode[H hash.Hash]( value codec.EncodedValue, childHandles [codec.ChildrenCapacity]codec.MerkleValue, - childIndex *int, + childIndex *uint, omitValue *bool, - key []byte, - prefixlen int, - nodePartialKey []byte, - recordedNodes *Iterator[triedb.Record], + key nibbles.LeftNibbles, + prefixlen uint, + nodePartialKey nibbles.Nibbles, + recordedNodes *Iterator[triedb.Record[H]], ) (genProofStep, error) { - if !bytes.Contains(key, nodePartialKey) { + if !key.Contains(nodePartialKey, prefixlen) { return genProofStepFoundValue{nil}, nil } - if 
len(key) == prefixlen+len(nodePartialKey) { + if key.Len() == prefixlen+nodePartialKey.Len() { if value == nil { return genProofStepFoundValue{nil}, nil } @@ -270,38 +273,38 @@ func genProofMatchKeyToBranchNode( *omitValue = true value := []byte(v) return genProofStepFoundValue{&value}, nil - case codec.HashedValue: + case codec.HashedValue[H]: *omitValue = true return resolveValue(recordedNodes) } } - newIndex := int(key[prefixlen+len(nodePartialKey)]) + newIndex := *key.At(prefixlen + nodePartialKey.Len()) - if newIndex < *childIndex { + if uint(newIndex) < *childIndex { panic("newIndex out of bounds") } - *childIndex = newIndex + *childIndex = uint(newIndex) if childHandles[newIndex] != nil { var child nodeHandle switch c := childHandles[newIndex].(type) { - case codec.HashedNode: - child = nodeHandleHash(c) + case codec.HashedNode[H]: + child = nodeHandleHash[H]{c.Hash} case codec.InlineNode: child = nodeHandleInline(c) } return genProofStepDescend{ - childPrefixLen: len(nodePartialKey) + prefixlen + 1, + childPrefixLen: nodePartialKey.Len() + prefixlen + 1, child: child, }, nil } return genProofStepFoundValue{nil}, nil } -func resolveValue(recordedNodes *Iterator[triedb.Record]) (genProofStep, error) { +func resolveValue[H hash.Hash](recordedNodes *Iterator[triedb.Record[H]]) (genProofStep, error) { value := recordedNodes.Next() if value != nil { return genProofStepFoundHashedValue{value.Data}, nil diff --git a/pkg/trie/triedb/proof/generate_test.go b/pkg/trie/triedb/proof/generate_test.go index 296301d8b0..f70a934730 100644 --- a/pkg/trie/triedb/proof/generate_test.go +++ b/pkg/trie/triedb/proof/generate_test.go @@ -6,8 +6,11 @@ package proof import ( "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/triedb" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,7 
+19,7 @@ func Test_NewProof(t *testing.T) { entries []trie.Entry storageVersion trie.TrieLayout keys []string - expectedProof MerkleProof + expectedProof MerkleProof[hash.H256, runtime.BlakeTwo256] }{ "leaf": { entries: []trie.Entry{ @@ -26,7 +29,7 @@ func Test_NewProof(t *testing.T) { }, }, keys: []string{"a"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ {66, 97, 0}, // 'a' node without value }, }, @@ -42,7 +45,7 @@ func Test_NewProof(t *testing.T) { }, }, keys: []string{"ab"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ {194, 97, 64, 0, 4, 97, 12, 65, 2, 0}, }, }, @@ -74,7 +77,7 @@ func Test_NewProof(t *testing.T) { }, }, keys: []string{"go"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ { 128, 192, 0, 0, 128, 114, 166, 121, 79, 225, 146, 229, 34, 68, 211, 54, 148, 205, 192, 58, 131, 95, 46, 239, @@ -115,7 +118,7 @@ func Test_NewProof(t *testing.T) { }, }, keys: []string{"go", "polkadot"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ { 128, 192, 0, 0, 0, }, @@ -138,7 +141,7 @@ func Test_NewProof(t *testing.T) { t.Run(name, func(t *testing.T) { // Build trie inmemoryDB := NewMemoryDB(triedb.EmptyNode) - triedb := triedb.NewEmptyTrieDB(inmemoryDB) + triedb := triedb.NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.entries { triedb.Put(entry.Key, entry.Value) @@ -147,9 +150,11 @@ func Test_NewProof(t *testing.T) { root := triedb.MustHash() // Generate proof - proof, err := NewMerkleProof(inmemoryDB, testCase.storageVersion, root, testCase.keys) + proof, err := NewMerkleProof[hash.H256, runtime.BlakeTwo256]( + inmemoryDB, testCase.storageVersion, root, testCase.keys) require.NoError(t, err) - require.Equal(t, testCase.expectedProof, proof) + assert.Equal(t, len(testCase.expectedProof), len(proof)) + assert.Equal(t, testCase.expectedProof, 
proof) }) } } diff --git a/pkg/trie/triedb/proof/proof.go b/pkg/trie/triedb/proof/proof.go index 39a9962cd2..941c4cb9f4 100644 --- a/pkg/trie/triedb/proof/proof.go +++ b/pkg/trie/triedb/proof/proof.go @@ -7,32 +7,33 @@ import ( "bytes" "errors" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie" - nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/pkg/trie/triedb" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/gammazero/deque" ) -type MerkleProof [][]byte +type MerkleProof[H hash.Hash, Hasher hash.Hasher[H]] [][]byte -func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash common.Hash, keys []string) ( - proof MerkleProof, err error) { +func NewMerkleProof[H hash.Hash, Hasher hash.Hasher[H]]( + db db.RWDatabase, trieVersion trie.TrieLayout, rootHash H, keys []string) ( + proof MerkleProof[H, Hasher], err error) { // Sort and deduplicate keys keys = sortAndDeduplicateKeys(keys) // The stack of nodes through a path in the trie. // Each entry is a child node of the preceding entry. 
- stack := deque.New[*genProofStackEntry]() + stack := deque.New[*genProofStackEntry[H]]() // final proof nodes - var proofNodes MerkleProof + var proofNodes MerkleProof[H, Hasher] // Iterate over the keys and build the proof nodes for i := 0; i < len(keys); i = i + 1 { var key = []byte(keys[i]) - var keyNibbles = nibbles.KeyLEToNibbles(key) + var keyNibbles = nibbles.NewLeftNibbles(key) err := unwindStack(stack, proofNodes, &keyNibbles) if err != nil { @@ -40,8 +41,8 @@ func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash comm } // Traverse the trie recording the visited nodes - recorder := triedb.NewRecorder() - trie := triedb.NewTrieDB(rootHash, db, triedb.WithRecorder(recorder)) + recorder := triedb.NewRecorder[H]() + trie := triedb.NewTrieDB[H, Hasher](rootHash, db, triedb.WithRecorder[H, Hasher](recorder)) trie.SetVersion(trieVersion) trie.Get(key) @@ -52,7 +53,7 @@ func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash comm nextEntry := stack.At(i) nextRecord := recordedNodes.Peek() - if nextRecord == nil || !bytes.Equal(nextEntry.nodeHash[:], nextRecord.Hash[:]) { + if nextRecord == nil || !bytes.Equal(nextEntry.nodeHash[:], nextRecord.Hash.Bytes()) { break } @@ -63,12 +64,12 @@ func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash comm loop: for { var nextStep genProofStep - var entry *genProofStackEntry + var entry *genProofStackEntry[H] if stack.Len() > 0 { entry = stack.Back() } if entry == nil { - nextStep = genProofStepDescend{childPrefixLen: 0, child: nodeHandleHash(rootHash)} + nextStep = genProofStepDescend{childPrefixLen: 0, child: nodeHandleHash[H]{rootHash}} } else { var err error nextStep, err = genProofMatchKeyToNode( @@ -76,7 +77,7 @@ func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash comm &entry.omitValue, &entry.childIndex, keyNibbles, - len(entry.prefix), + entry.prefix.Len(), recordedNodes, ) @@ -87,25 +88,26 @@ func NewMerkleProof(db db.RWDatabase, 
trieVersion trie.TrieLayout, rootHash comm switch s := nextStep.(type) { case genProofStepDescend: - childPrefix := keyNibbles[:s.childPrefixLen] - var childEntry *genProofStackEntry + childPrefix := keyNibbles.Truncate(s.childPrefixLen) + var childEntry *genProofStackEntry[H] switch child := s.child.(type) { - case nodeHandleHash: + case nodeHandleHash[H]: childRecord := recordedNodes.Next() - if !bytes.Equal(childRecord.Hash[:], child[:]) { + // if !bytes.Equal(childRecord.Hash[:], child[:]) { + if childRecord.Hash != child.hash { panic("hash mismatch") } - outputIndex := len(proofNodes) + outputIndex := uint(len(proofNodes)) // Insert a placeholder into output which will be replaced when this // new entry is popped from the stack. proofNodes = append(proofNodes, []byte{}) - childEntry, err = newGenProofStackEntry( + childEntry, err = newGenProofStackEntry[H]( childPrefix, childRecord.Data, - &childRecord.Hash, + childRecord.Hash.Bytes(), &outputIndex, ) @@ -113,10 +115,10 @@ func NewMerkleProof(db db.RWDatabase, trieVersion trie.TrieLayout, rootHash comm return nil, err } case nodeHandleInline: - if len(child) > common.HashLength { + if len(child) > (*new(H)).Length() { return nil, errors.New("invalid hash length") } - childEntry, err = newGenProofStackEntry( + childEntry, err = newGenProofStackEntry[H]( childPrefix, child, nil, diff --git a/pkg/trie/triedb/proof/proof_test.go b/pkg/trie/triedb/proof/proof_test.go index 5ae6cf60ca..eded85a722 100644 --- a/pkg/trie/triedb/proof/proof_test.go +++ b/pkg/trie/triedb/proof/proof_test.go @@ -7,6 +7,8 @@ import ( "fmt" "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/triedb" "github.com/stretchr/testify/require" @@ -16,7 +18,7 @@ func Test_GenerateAndVerify(t *testing.T) { testCases := map[string]struct { entries []trie.Entry keys []string - 
expectedProof MerkleProof + expectedProof MerkleProof[hash.H256, runtime.BlakeTwo256] }{ "leaf": { entries: []trie.Entry{ @@ -26,7 +28,7 @@ func Test_GenerateAndVerify(t *testing.T) { }, }, keys: []string{"a"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ {66, 97, 0}, // 'a' node without value }, }, @@ -42,7 +44,7 @@ func Test_GenerateAndVerify(t *testing.T) { }, }, keys: []string{"ab"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ {194, 97, 64, 0, 4, 97, 12, 65, 2, 0}, }, }, @@ -74,7 +76,7 @@ func Test_GenerateAndVerify(t *testing.T) { }, }, keys: []string{"go"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ { 128, 192, 0, 0, 128, 114, 166, 121, 79, 225, 146, 229, 34, 68, 211, 54, 148, 205, 192, 58, 131, 95, 46, 239, @@ -115,7 +117,7 @@ func Test_GenerateAndVerify(t *testing.T) { }, }, keys: []string{"go", "polkadot"}, - expectedProof: MerkleProof{ + expectedProof: MerkleProof[hash.H256, runtime.BlakeTwo256]{ { 128, 192, 0, 0, 0, }, @@ -141,7 +143,7 @@ func Test_GenerateAndVerify(t *testing.T) { t.Run(fmt.Sprintf("%s_%s", name, trieVersion.String()), func(t *testing.T) { // Build trie inmemoryDB := NewMemoryDB(triedb.EmptyNode) - triedb := triedb.NewEmptyTrieDB(inmemoryDB) + triedb := triedb.NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) triedb.SetVersion(trieVersion) for _, entry := range testCase.entries { @@ -151,7 +153,7 @@ func Test_GenerateAndVerify(t *testing.T) { root := triedb.MustHash() // Generate proof - proof, err := NewMerkleProof(inmemoryDB, trieVersion, root, testCase.keys) + proof, err := NewMerkleProof[hash.H256, runtime.BlakeTwo256](inmemoryDB, trieVersion, root, testCase.keys) require.NoError(t, err) require.Equal(t, testCase.expectedProof, proof) @@ -163,7 +165,7 @@ func Test_GenerateAndVerify(t *testing.T) { value: triedb.Get([]byte(key)), } } - err = proof.Verify(trieVersion, root, items) + 
err = proof.Verify(trieVersion, root.Bytes(), items) require.NoError(t, err) }) diff --git a/pkg/trie/triedb/proof/util_test.go b/pkg/trie/triedb/proof/util_test.go index 7ebb81a18c..2cb0c8e94b 100644 --- a/pkg/trie/triedb/proof/util_test.go +++ b/pkg/trie/triedb/proof/util_test.go @@ -7,22 +7,22 @@ import ( "bytes" "github.com/ChainSafe/gossamer/internal/database" - "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie/db" ) // MemoryDB is an in-memory implementation of the Database interface backed by a // map. It uses blake2b as hashing algorithm type MemoryDB struct { - data map[common.Hash][]byte - hashedNullNode common.Hash + data map[string][]byte + hashedNullNode []byte nullNodeData []byte } func memoryDBFromNullNode(nullKey, nullNodeData []byte) *MemoryDB { return &MemoryDB{ - data: make(map[common.Hash][]byte), - hashedNullNode: common.MustBlake2bHash(nullKey), + data: make(map[string][]byte), + hashedNullNode: runtime.BlakeTwo256{}.Hash(nullKey).Bytes(), nullNodeData: nullNodeData, } } @@ -31,20 +31,20 @@ func NewMemoryDB(data []byte) *MemoryDB { return memoryDBFromNullNode(data, data) } -func (db *MemoryDB) emplace(key common.Hash, value []byte) { +func (db *MemoryDB) emplace(key []byte, value []byte) { if bytes.Equal(value, db.nullNodeData) { return } - db.data[key] = value + db.data[string(key)] = value } func (db *MemoryDB) Get(key []byte) ([]byte, error) { - dbKey := common.NewHash(key) - if dbKey == db.hashedNullNode { + dbKey := key + if bytes.Equal(dbKey, db.hashedNullNode) { return db.nullNodeData, nil } - if value, has := db.data[dbKey]; has { + if value, has := db.data[string(dbKey)]; has { return value, nil } @@ -52,14 +52,14 @@ func (db *MemoryDB) Get(key []byte) ([]byte, error) { } func (db *MemoryDB) Put(key []byte, value []byte) error { - dbKey := common.NewHash(key) + dbKey := key db.emplace(dbKey, value) return nil } func (db *MemoryDB) Del(key 
[]byte) error { - dbKey := common.NewHash(key) - delete(db.data, dbKey) + dbKey := key + delete(db.data, string(dbKey)) return nil } diff --git a/pkg/trie/triedb/proof/verify.go b/pkg/trie/triedb/proof/verify.go index b3f3e0fe53..7f2f5bd693 100644 --- a/pkg/trie/triedb/proof/verify.go +++ b/pkg/trie/triedb/proof/verify.go @@ -8,12 +8,11 @@ import ( "errors" "fmt" - nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" - - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/triedb" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/gammazero/deque" "golang.org/x/exp/slices" ) @@ -29,7 +28,7 @@ type verifyProofStep interface { type ( verifyProofStepDescend struct { - childPrefix []byte + childPrefix nibbles.LeftNibbles } verifyProofStepUnwindStackStep struct{} ) @@ -37,29 +36,30 @@ type ( func (verifyProofStepDescend) isProofStep() {} func (verifyProofStepUnwindStackStep) isProofStep() {} -type verifyProofStackEntry struct { +type verifyProofStackEntry[H hash.Hash, Hasher hash.Hasher[H]] struct { trieVersion trie.TrieLayout - prefix []byte + prefix nibbles.LeftNibbles node codec.EncodedNode value codec.EncodedValue isInline bool childIndex int children [codec.ChildrenCapacity]triedb.ChildReference - nextValueHash common.Hash + nextValueHash H + emptyHash H } -func newVerifyProofStackEntry( +func newVerifyProofStackEntry[H hash.Hash, Hasher hash.Hasher[H]]( trieVersion trie.TrieLayout, nodeData []byte, - prefix []byte, + prefix nibbles.LeftNibbles, isInline bool, -) (*verifyProofStackEntry, error) { - node, err := codec.Decode(bytes.NewReader(nodeData)) +) (*verifyProofStackEntry[H, Hasher], error) { + node, err := codec.Decode[H](bytes.NewReader(nodeData)) if err != nil { return nil, err } - return &verifyProofStackEntry{ + return &verifyProofStackEntry[H, Hasher]{ 
trieVersion: trieVersion, node: node, value: node.GetValue(), @@ -67,24 +67,25 @@ func newVerifyProofStackEntry( isInline: isInline, childIndex: 0, children: [codec.ChildrenCapacity]triedb.ChildReference{}, - nextValueHash: common.EmptyHash, + nextValueHash: (*new(Hasher)).Hash([]byte{0}), + emptyHash: (*new(Hasher)).Hash([]byte{0}), }, nil } -func (e *verifyProofStackEntry) getValue() codec.EncodedValue { - if e.nextValueHash != common.EmptyHash { - return codec.HashedValue(e.nextValueHash) +func (e *verifyProofStackEntry[H, Hasher]) getValue() codec.EncodedValue { + if e.nextValueHash != e.emptyHash { + return codec.HashedValue[H]{Hash: e.nextValueHash} } return e.value } -func (e *verifyProofStackEntry) encodeNode() ([]byte, error) { +func (e *verifyProofStackEntry[H, Hasher]) encodeNode() ([]byte, error) { switch n := e.node.(type) { case codec.Empty: return []byte{triedb.EmptyTrieBytes}, nil case codec.Leaf: encodingBuffer := bytes.NewBuffer(nil) - err := triedb.NewEncodedLeaf(e.node.GetPartialKey(), e.getValue(), encodingBuffer) + err := triedb.NewEncodedLeaf(n.PartialKey.Right(), n.PartialKey.Len(), e.getValue(), encodingBuffer) if err != nil { return nil, err } @@ -100,15 +101,15 @@ func (e *verifyProofStackEntry) encodeNode() ([]byte, error) { switch c := child.(type) { case codec.InlineNode: children[childIndex] = triedb.InlineChildReference(c) - case codec.HashedNode: - children[childIndex] = triedb.HashChildReference(common.Hash(c)) + case codec.HashedNode[H]: + children[childIndex] = triedb.HashChildReference[H](c) } } childIndex++ } encodingBuffer := bytes.NewBuffer(nil) - err := triedb.NewEncodedBranch(e.node.GetPartialKey(), children, e.getValue(), encodingBuffer) + err := triedb.NewEncodedBranch(n.PartialKey.Right(), n.PartialKey.Len(), children, e.getValue(), encodingBuffer) if err != nil { return nil, err } @@ -118,17 +119,18 @@ func (e *verifyProofStackEntry) encodeNode() ([]byte, error) { } } -func (e *verifyProofStackEntry) 
advanceItem(itemsIter *Iterator[proofItem]) (verifyProofStep, error) { +func (e *verifyProofStackEntry[H, Hasher]) advanceItem(itemsIter *Iterator[proofItem]) (verifyProofStep, error) { for { item := itemsIter.Peek() if item == nil { return verifyProofStepUnwindStackStep{}, nil } - keyNibbles := nibbles.KeyLEToNibbles(item.key) + // keyNibbles := nibbles.KeyLEToNibbles(item.key) + keyNibbles := nibbles.NewLeftNibbles(item.key) value := item.value - if bytes.HasPrefix(keyNibbles, e.prefix) { - valueMatch := matchKeyToNode(keyNibbles, len(e.prefix), e.node) + if keyNibbles.StartsWith(e.prefix) { + valueMatch := matchKeyToNode[H](keyNibbles, e.prefix.Len(), e.node) switch m := valueMatch.(type) { case matchesLeaf: if value != nil { @@ -159,24 +161,24 @@ func (e *verifyProofStackEntry) advanceItem(itemsIter *Iterator[proofItem]) (ver } } -func (e *verifyProofStackEntry) advanceChildIndex( - childPrefix []byte, +func (e *verifyProofStackEntry[H, Hasher]) advanceChildIndex( + childPrefix nibbles.LeftNibbles, proofIter *Iterator[[]byte], -) (*verifyProofStackEntry, error) { +) (*verifyProofStackEntry[H, Hasher], error) { switch n := e.node.(type) { case codec.Branch: - if len(childPrefix) <= 0 { + if childPrefix.Len() <= 0 { panic("child prefix should be greater than 0") } - childIndex := childPrefix[len(childPrefix)-1] + childIndex := *childPrefix.At(childPrefix.Len() - 1) for e.childIndex < int(childIndex) { child := n.Children[e.childIndex] if child != nil { switch c := child.(type) { case codec.InlineNode: e.children[e.childIndex] = triedb.InlineChildReference(c) - case codec.HashedNode: - e.children[e.childIndex] = triedb.HashChildReference(common.Hash(c)) + case codec.HashedNode[H]: + e.children[e.childIndex] = triedb.HashChildReference[H](c) } } e.childIndex++ @@ -188,11 +190,11 @@ func (e *verifyProofStackEntry) advanceChildIndex( } } -func (e *verifyProofStackEntry) makeChildEntry( +func (e *verifyProofStackEntry[H, Hasher]) makeChildEntry( proofIter 
*Iterator[[]byte], child codec.MerkleValue, - childPrefix []byte, -) (*verifyProofStackEntry, error) { + childPrefix nibbles.LeftNibbles, +) (*verifyProofStackEntry[H, Hasher], error) { switch c := child.(type) { case codec.InlineNode: if len(c) == 0 { @@ -200,24 +202,25 @@ func (e *verifyProofStackEntry) makeChildEntry( if nodeData == nil { return nil, ErrIncompleteProof } - return newVerifyProofStackEntry(e.trieVersion, *nodeData, childPrefix, false) + return newVerifyProofStackEntry[H, Hasher](e.trieVersion, *nodeData, childPrefix, false) } - return newVerifyProofStackEntry(e.trieVersion, c, childPrefix, true) - case codec.HashedNode: - if len(c) != common.HashLength { - return nil, fmt.Errorf("invalid hash length: %x", c) + return newVerifyProofStackEntry[H, Hasher](e.trieVersion, c, childPrefix, true) + case codec.HashedNode[H]: + if len(c.Hash.Bytes()) != (*new(H)).Length() { + return nil, fmt.Errorf("invalid hash length: %x", c.Hash.Bytes()) } - return nil, fmt.Errorf("extraneous hash reference: %x", c) + return nil, fmt.Errorf("extraneous hash reference: %x", c.Hash.Bytes()) default: panic("unreachable") } } -func (e *verifyProofStackEntry) setValue(value []byte) { +func (e *verifyProofStackEntry[H, Hasher]) setValue(value []byte) { if len(value) <= e.trieVersion.MaxInlineValue() { e.value = codec.InlineValue(value) } else { - hashedValue := common.MustBlake2bHash(value) + // hashedValue := common.MustBlake2bHash(value).ToBytes() + hashedValue := (*(new(Hasher))).Hash(value) e.nextValueHash = hashedValue e.value = nil } @@ -238,7 +241,7 @@ type ( notOmitted struct{} // The key may match below a child of this node. Parameter is the prefix of the child node. 
isChild struct { - childPrefix []byte + childPrefix nibbles.LeftNibbles } ) @@ -248,14 +251,15 @@ func (notFound) isValueMatch() {} func (notOmitted) isValueMatch() {} func (isChild) isValueMatch() {} -func matchKeyToNode(keyNibbles []byte, prefixLen int, node codec.EncodedNode) valueMatch { +func matchKeyToNode[H hash.Hash](keyNibbles nibbles.LeftNibbles, prefixLen uint, node codec.EncodedNode) valueMatch { switch n := node.(type) { case codec.Empty: return notFound{} case codec.Leaf: - if bytes.Contains(keyNibbles, n.PartialKey) && len(keyNibbles) == prefixLen+len(n.PartialKey) { + if keyNibbles.Contains(n.PartialKey, prefixLen) && keyNibbles.Len() == prefixLen+n.PartialKey.Len() { + // if bytes.Contains(keyNibbles.Data(), n.PartialKey.Data()) && len(keyNibbles) == prefixLen+len(n.PartialKey) { switch v := n.Value.(type) { - case codec.HashedValue: + case codec.HashedValue[H]: return notOmitted{} case codec.InlineValue: if len(v) == 0 { @@ -266,8 +270,8 @@ func matchKeyToNode(keyNibbles []byte, prefixLen int, node codec.EncodedNode) va } return notFound{} case codec.Branch: - if bytes.Contains(keyNibbles, n.PartialKey) { - return matchKeyToBranchNode(keyNibbles, prefixLen+len(n.PartialKey), n.Children, n.Value) + if keyNibbles.Contains(n.PartialKey, prefixLen) { + return matchKeyToBranchNode(keyNibbles, prefixLen+n.PartialKey.Len(), n.Children, n.Value) } else { return notFound{} } @@ -277,20 +281,20 @@ func matchKeyToNode(keyNibbles []byte, prefixLen int, node codec.EncodedNode) va } func matchKeyToBranchNode( - key []byte, - prefixPlusPartialLen int, + key nibbles.LeftNibbles, + prefixPlusPartialLen uint, children [codec.ChildrenCapacity]codec.MerkleValue, value codec.EncodedValue, ) valueMatch { - if len(key) == prefixPlusPartialLen { + if key.Len() == prefixPlusPartialLen { if value == nil { return matchesBranch{} } return notOmitted{} } - index := key[prefixPlusPartialLen] + index := *key.At(prefixPlusPartialLen) if children[index] != nil { - return 
isChild{childPrefix: key[:prefixPlusPartialLen+1]} + return isChild{childPrefix: key.Truncate(prefixPlusPartialLen + 1)} } return notFound{} @@ -301,9 +305,9 @@ type proofItem struct { value []byte } -func (proof MerkleProof) Verify( +func (proof MerkleProof[H, Hasher]) Verify( trieVersion trie.TrieLayout, - root common.Hash, + root []byte, items []proofItem, ) error { // sort items @@ -331,14 +335,14 @@ func (proof MerkleProof) Verify( // A stack of child references to fill in omitted branch children for later trie nodes in the // proof. - stack := deque.New[verifyProofStackEntry]() + stack := deque.New[verifyProofStackEntry[H, Hasher]]() rootNode := proofIter.Next() if rootNode == nil { return ErrIncompleteProof } - lastEntry, err := newVerifyProofStackEntry(trieVersion, *rootNode, []byte{}, false) + lastEntry, err := newVerifyProofStackEntry[H, Hasher](trieVersion, *rootNode, nibbles.NewLeftNibbles(nil), false) if err != nil { return err } @@ -367,13 +371,13 @@ loop: var childRef triedb.ChildReference if isInline { - if len(nodeData) > common.HashLength { + if len(nodeData) > (*new(H)).Length() { return fmt.Errorf("invalid child reference: %x", nodeData) } childRef = triedb.InlineChildReference(nodeData) } else { - hash := common.MustBlake2bHash(nodeData) - childRef = triedb.HashChildReference(hash) + hash := (*new(Hasher)).Hash(nodeData) + childRef = triedb.HashChildReference[H]{Hash: hash} } if stack.Len() > 0 { @@ -386,10 +390,10 @@ loop: if nextProof != nil { return ErrExtraneusNode } - var computedRoot common.Hash + var computedRoot []byte switch c := childRef.(type) { - case triedb.HashChildReference: - computedRoot = common.Hash(c) + case triedb.HashChildReference[H]: + computedRoot = c.Hash.Bytes() case triedb.InlineChildReference: panic("unreachable") } diff --git a/pkg/trie/triedb/recorder.go b/pkg/trie/triedb/recorder.go index 9b15eee0a0..8041bb9c89 100644 --- a/pkg/trie/triedb/recorder.go +++ b/pkg/trie/triedb/recorder.go @@ -4,58 +4,96 @@ package 
triedb import ( - "github.com/ChainSafe/gossamer/lib/common" "github.com/tidwall/btree" ) -type trieAccess interface { +type TrieAccess interface { isTrieAccess() } type ( - encodedNodeAccess struct { - hash common.Hash - encodedNode []byte + EncodedNodeAccess[H any] struct { + Hash H + EncodedNode []byte } - valueAccess struct { - hash common.Hash - value []byte - fullKey []byte + ValueAccess[H any] struct { + Hash H + Value []byte + FullKey []byte } - inlineValueAccess struct { - fullKey []byte + InlineValueAccess struct { + FullKey []byte } - hashAccess struct { - fullKey []byte + HashAccess struct { + FullKey []byte } - nonExistingNodeAccess struct { - fullKey []byte + NonExistingNodeAccess struct { + FullKey []byte } ) -func (encodedNodeAccess) isTrieAccess() {} -func (valueAccess) isTrieAccess() {} -func (inlineValueAccess) isTrieAccess() {} -func (hashAccess) isTrieAccess() {} -func (nonExistingNodeAccess) isTrieAccess() {} +func (EncodedNodeAccess[H]) isTrieAccess() {} +func (ValueAccess[H]) isTrieAccess() {} +func (InlineValueAccess) isTrieAccess() {} +func (HashAccess) isTrieAccess() {} +func (NonExistingNodeAccess) isTrieAccess() {} + +// A trie recorder that can be used to record all kind of [TrieAccess]. +// +// To build a trie proof a recorder is required that records all trie accesses. These recorded trie +// accesses can then be used to create the proof. +type TrieRecorder interface { + // Record the given [TrieAccess]. + // + // Depending on the [TrieAccess] a call to [TrieRecorder.TrieNodesRecordedForKey] afterwards + // must return the correct recorded state. + Record(access TrieAccess) + + // Check if we have recorded any trie nodes for the given key. + // + // Returns [RecordedForKey] to express the state of the recorded trie nodes. + TrieNodesRecordedForKey(key []byte) RecordedForKey +} type RecordedForKey int const ( + // We recorded all trie nodes up to the value for a storage key. 
+ // + // This should be returned when the recorder has seen the following [TrieAccess]: + // + // - [ValueAccess]: If we see this [TrieAccess], it means we have recorded all the + // trie nodes up to the value. + // - [NonExistingNodeAccess]: If we see this [TrieAccess], it means we have recorded all + // the necessary trie nodes to prove that the value doesn't exist in the trie. RecordedValue RecordedForKey = iota + // We recorded all trie nodes up to the value hash for a storage key. + // + // If we have a [RecordedValue], it means that we also have the hash of this value. + // This also means that if we first have recorded the hash of a value and then also record the + // value, the access should be upgraded to [RecordedValue]. + // + // This should be returned when the recorder has seen the following [TrieAccess]: + // + // - [HashAccess]: If we see this [TrieAccess], it means we have recorded all trie + // nodes to have the hash of the value. RecordedHash + // We haven't recorded any trie nodes yet for a storage key. + // + // This means we have not seen any [TrieAccess] referencing the searched key. 
+ RecordedNone ) -type RecordedNodesIterator struct { - nodes []Record +type RecordedNodesIterator[H any] struct { + nodes []Record[H] index int } -func NewRecordedNodesIterator(nodes []Record) *RecordedNodesIterator { - return &RecordedNodesIterator{nodes: nodes, index: -1} +func NewRecordedNodesIterator[H any](nodes []Record[H]) *RecordedNodesIterator[H] { + return &RecordedNodesIterator[H]{nodes: nodes, index: -1} } -func (r *RecordedNodesIterator) Next() *Record { +func (r *RecordedNodesIterator[H]) Next() *Record[H] { if r.index < len(r.nodes)-1 { r.index++ return &r.nodes[r.index] @@ -63,52 +101,58 @@ func (r *RecordedNodesIterator) Next() *Record { return nil } -func (r *RecordedNodesIterator) Peek() *Record { +func (r *RecordedNodesIterator[H]) Peek() *Record[H] { if r.index+1 < len(r.nodes)-1 { return &r.nodes[r.index+1] } return nil } -type Record struct { - Hash common.Hash +type Record[H any] struct { + Hash H Data []byte } -type Recorder struct { - nodes []Record +type Recorder[H any] struct { + nodes []Record[H] recordedKeys btree.Map[string, RecordedForKey] } -func NewRecorder() *Recorder { - return &Recorder{ - nodes: []Record{}, +func NewRecorder[H any]() *Recorder[H] { + return &Recorder[H]{ + nodes: []Record[H]{}, recordedKeys: *btree.NewMap[string, RecordedForKey](0), } } -func (r *Recorder) record(access trieAccess) { +func (r *Recorder[H]) Record(access TrieAccess) { switch a := access.(type) { - case encodedNodeAccess: - r.nodes = append(r.nodes, Record{Hash: a.hash, Data: a.encodedNode}) - case valueAccess: - r.nodes = append(r.nodes, Record{Hash: a.hash, Data: a.value}) - r.recordedKeys.Set(string(a.fullKey), RecordedValue) - case inlineValueAccess: - r.recordedKeys.Set(string(a.fullKey), RecordedValue) - case hashAccess: - if _, ok := r.recordedKeys.Get(string(a.fullKey)); !ok { - r.recordedKeys.Set(string(a.fullKey), RecordedHash) + case EncodedNodeAccess[H]: + r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.EncodedNode}) + case 
ValueAccess[H]: + r.nodes = append(r.nodes, Record[H]{Hash: a.Hash, Data: a.Value}) + r.recordedKeys.Set(string(a.FullKey), RecordedValue) + case InlineValueAccess: + r.recordedKeys.Set(string(a.FullKey), RecordedValue) + case HashAccess: + if _, ok := r.recordedKeys.Get(string(a.FullKey)); !ok { + r.recordedKeys.Set(string(a.FullKey), RecordedHash) } - case nonExistingNodeAccess: + case NonExistingNodeAccess: // We handle the non existing value/hash like having recorded the value - r.recordedKeys.Set(string(a.fullKey), RecordedValue) + r.recordedKeys.Set(string(a.FullKey), RecordedValue) } } -func (r *Recorder) Drain() []Record { +func (r *Recorder[H]) Drain() []Record[H] { r.recordedKeys.Clear() nodes := r.nodes - r.nodes = []Record{} + r.nodes = []Record[H]{} return nodes } + +func (r *Recorder[H]) TrieNodesRecordedForKey(key []byte) RecordedForKey { + panic("unimpl") +} + +var _ TrieRecorder = &Recorder[string]{} diff --git a/pkg/trie/triedb/recorder_test.go b/pkg/trie/triedb/recorder_test.go index 151e050691..76974f8632 100644 --- a/pkg/trie/triedb/recorder_test.go +++ b/pkg/trie/triedb/recorder_test.go @@ -6,15 +6,17 @@ package triedb import ( "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/stretchr/testify/require" ) // Tests results are based on // https://github.com/dimartiro/substrate-trie-test/blob/master/src/substrate_trie_test.rs func TestRecorder(t *testing.T) { - inmemoryDB := NewMemoryDB(EmptyNode) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) - triedb := NewEmptyTrieDB(inmemoryDB) + triedb := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) triedb.Put([]byte("pol"), []byte("polvalue")) triedb.Put([]byte("polka"), []byte("polkavalue")) @@ -27,8 +29,9 @@ func TestRecorder(t *testing.T) { require.NotNil(t, root) t.Run("Record_pol_access_should_record_2_node", func(t *testing.T) { - recorder := NewRecorder() - trie := 
NewTrieDB(root, inmemoryDB, WithRecorder(recorder)) + recorder := NewRecorder[hash.H256]() + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256]( + root, inmemoryDB, WithRecorder[hash.H256, runtime.BlakeTwo256](recorder)) trie.Get([]byte("pol")) @@ -63,8 +66,9 @@ func TestRecorder(t *testing.T) { }) t.Run("Record_go_access_should_record_2_nodes", func(t *testing.T) { - recorder := NewRecorder() - trie := NewTrieDB(root, inmemoryDB, WithRecorder(recorder)) + recorder := NewRecorder[hash.H256]() + trie := NewTrieDB[hash.H256, runtime.BlakeTwo256]( + root, inmemoryDB, WithRecorder[hash.H256, runtime.BlakeTwo256](recorder)) trie.Get([]byte("go")) diff --git a/pkg/trie/triedb/triedb.go b/pkg/trie/triedb/triedb.go index b7d01de767..e7bad49275 100644 --- a/pkg/trie/triedb/triedb.go +++ b/pkg/trie/triedb/triedb.go @@ -7,16 +7,16 @@ import ( "bytes" "errors" "fmt" + "slices" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie" - nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" "github.com/ChainSafe/gossamer/pkg/trie/db" "github.com/ChainSafe/gossamer/internal/database" "github.com/ChainSafe/gossamer/internal/log" - "github.com/ChainSafe/gossamer/pkg/trie/cache" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" ) var ErrIncompleteDB = errors.New("incomplete database") @@ -25,24 +25,26 @@ var ( logger = log.NewFromGlobal(log.AddContext("pkg", "triedb")) ) -type TrieDBOpts func(*TrieDB) +type TrieDBOpts[H hash.Hash, Hasher hash.Hasher[H]] func(*TrieDB[H, Hasher]) -var WithCache = func(c cache.TrieCache) TrieDBOpts { - return func(t *TrieDB) { +// Define cache interface for now to reduce size of changes +type Cache interface{} + +func WithCache[H hash.Hash, Hasher hash.Hasher[H]](c Cache) TrieDBOpts[H, Hasher] { + return func(t *TrieDB[H, Hasher]) { t.cache = c } } - -var WithRecorder = func(r *Recorder) TrieDBOpts { - return 
func(t *TrieDB) { +func WithRecorder[H hash.Hash, Hasher hash.Hasher[H]](r TrieRecorder) TrieDBOpts[H, Hasher] { + return func(t *TrieDB[H, Hasher]) { t.recorder = r } } // TrieDB is a DB-backed patricia merkle trie implementation // using lazy loading to fetch nodes -type TrieDB struct { - rootHash common.Hash +type TrieDB[H hash.Hash, Hasher hash.Hasher[H]] struct { + rootHash H db db.RWDatabase version trie.TrieLayout // rootHandle is an in-memory-trie-like representation of the node @@ -50,30 +52,33 @@ type TrieDB struct { rootHandle NodeHandle // Storage is an in memory storage for nodes that we need to use during this // trieDB session (before nodes are committed to db) - storage nodeStorage + storage nodeStorage[H] // deathRow is a set of nodes that we want to delete from db // uses string since it's comparable []byte deathRow map[string]interface{} // Optional cache to speed up the db lookups - cache cache.TrieCache + cache Cache // Optional recorder for recording trie accesses - recorder *Recorder + recorder TrieRecorder } -func NewEmptyTrieDB(db db.RWDatabase, opts ...TrieDBOpts) *TrieDB { - root := hashedNullNode - return NewTrieDB(root, db) +func NewEmptyTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( + db db.RWDatabase, opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, Hasher] { + hasher := *new(Hasher) + root := hasher.Hash([]byte{0}) + return NewTrieDB[H, Hasher](root, db) } // NewTrieDB creates a new TrieDB using the given root and db -func NewTrieDB(rootHash common.Hash, db db.RWDatabase, opts ...TrieDBOpts) *TrieDB { - rootHandle := persisted(rootHash) +func NewTrieDB[H hash.Hash, Hasher hash.Hasher[H]]( + rootHash H, db db.RWDatabase, opts ...TrieDBOpts[H, Hasher]) *TrieDB[H, Hasher] { + rootHandle := persisted[H]{rootHash} - trieDB := &TrieDB{ + trieDB := &TrieDB[H, Hasher]{ rootHash: rootHash, version: trie.V0, db: db, - storage: newNodeStorage(), + storage: newNodeStorage[H](), rootHandle: rootHandle, deathRow: make(map[string]interface{}), } @@ -85,7 
+90,7 @@ func NewTrieDB(rootHash common.Hash, db db.RWDatabase, opts ...TrieDBOpts) *Trie return trieDB } -func (t *TrieDB) SetVersion(v trie.TrieLayout) { +func (t *TrieDB[H, Hasher]) SetVersion(v trie.TrieLayout) { if v < t.version { panic("cannot regress trie version") } @@ -94,10 +99,11 @@ func (t *TrieDB) SetVersion(v trie.TrieLayout) { } // Hash returns the hashed root of the trie. -func (t *TrieDB) Hash() (common.Hash, error) { +func (t *TrieDB[H, Hasher]) Hash() (H, error) { err := t.commit() if err != nil { - return common.EmptyHash, err + root := (*new(Hasher)).Hash([]byte{0}) + return root, err } // This is trivial since it is a read only trie, but will change when we // support writes @@ -106,7 +112,7 @@ func (t *TrieDB) Hash() (common.Hash, error) { // MustHash returns the hashed root of the trie. // It panics if it fails to hash the root node. -func (t *TrieDB) MustHash() common.Hash { +func (t *TrieDB[H, Hasher]) MustHash() H { h, err := t.Hash() if err != nil { panic(err) @@ -118,10 +124,8 @@ func (t *TrieDB) MustHash() common.Hash { // Get returns the value in the node of the trie // which matches its key with the key given. // Note the key argument is given in little Endian format. 
-func (t *TrieDB) Get(key []byte) []byte { - keyNibbles := nibbles.KeyLEToNibbles(key) - - val, err := t.lookup(keyNibbles, keyNibbles, t.rootHandle) +func (t *TrieDB[H, Hasher]) Get(key []byte) []byte { + val, err := t.lookup(key, nibbles.NewNibbles(slices.Clone(key)), t.rootHandle) if err != nil { return nil } @@ -129,15 +133,15 @@ func (t *TrieDB) Get(key []byte) []byte { return val } -func (t *TrieDB) lookup(fullKey []byte, partialKey []byte, handle NodeHandle) ([]byte, error) { +func (t *TrieDB[H, Hasher]) lookup(fullKey []byte, partialKey nibbles.Nibbles, handle NodeHandle) ([]byte, error) { prefix := fullKey for { - var partialIdx int + var partialIdx uint switch node := handle.(type) { - case persisted: - lookup := NewTrieLookup(t.db, common.Hash(node), t.cache, t.recorder) - val, err := lookup.lookupValue(fullKey) + case persisted[H]: + lookup := NewTrieLookup[H, Hasher](t.db, node.hash, t.cache, t.recorder) + val, err := lookup.lookupValue(fullKey, partialKey) if err != nil { return nil, err } @@ -146,90 +150,98 @@ func (t *TrieDB) lookup(fullKey []byte, partialKey []byte, handle NodeHandle) ([ switch n := t.storage.get(storageHandle(node)).(type) { case Empty: return nil, nil - case Leaf: - if bytes.Equal(n.partialKey, partialKey) { + case Leaf[H]: + if nibbles.NewNibblesFromNodeKey(n.partialKey).Equal(partialKey) { return inMemoryFetchedValue(n.value, prefix, t.db) } else { return nil, nil } - case Branch: - if bytes.Equal(n.partialKey, partialKey) { + case Branch[H]: + slice := nibbles.NewNibblesFromNodeKey(n.partialKey) + if slice.Equal(partialKey) { return inMemoryFetchedValue(n.value, prefix, t.db) - } else if bytes.HasPrefix(partialKey, n.partialKey) { - idx := partialKey[len(n.partialKey)] + } else if partialKey.StartsWith(slice) { + idx := partialKey.At(slice.Len()) child := n.children[idx] if child != nil { - partialIdx = 1 + len(n.partialKey) + partialIdx = slice.Len() + 1 handle = child + } else { + return nil, nil } } else { return nil, nil } 
} } - partialKey = partialKey[partialIdx:] + partialKey = partialKey.Mid(partialIdx) } } -// Internal methods -func (t *TrieDB) getRootNode() (codec.EncodedNode, error) { - encodedNode, err := t.db.Get(t.rootHash[:]) - if err != nil { - return nil, err +func (t *TrieDB[H, Hasher]) getNodeOrLookup( + nodeHandle codec.MerkleValue, partialKey nibbles.Prefix, recordAccess bool, +) (codec.EncodedNode, *H, error) { + var nodeHash *H + var nodeData []byte + switch nodeHandle := nodeHandle.(type) { + case codec.HashedNode[H]: + prefixedKey := append(partialKey.JoinedBytes(), nodeHandle.Hash.Bytes()...) + var err error + nodeData, err = t.db.Get(prefixedKey) + if err != nil { + return nil, nil, err + } + if len(nodeData) == 0 { + if partialKey.Key == nil && partialKey.Padded == nil { + return nil, nil, fmt.Errorf("invalid state root: %v", nodeHandle.Hash) + } + return nil, nil, fmt.Errorf("incomplete database: %v", nodeHandle.Hash) + } + nodeHash = &nodeHandle.Hash + case codec.InlineNode: + nodeHash = nil + nodeData = nodeHandle } - t.recordAccess(encodedNodeAccess{hash: t.rootHash, encodedNode: encodedNode}) + reader := bytes.NewReader(nodeData) + decoded, err := codec.Decode[H](reader) + if err != nil { + return nil, nil, err + } - reader := bytes.NewReader(encodedNode) - return codec.Decode(reader) + if recordAccess { + t.recordAccess(EncodedNodeAccess[H]{Hash: t.rootHash, EncodedNode: nodeData}) + } + return decoded, nodeHash, nil } -// Internal methods - -func (t *TrieDB) getNodeAt(key []byte) (codec.EncodedNode, error) { - lookup := NewTrieLookup(t.db, t.rootHash, t.cache, t.recorder) - node, err := lookup.lookupNode(nibbles.KeyLEToNibbles(key)) +func (t *TrieDB[H, Hasher]) fetchValue(hash H, prefix nibbles.Prefix) ([]byte, error) { + prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) 
+ value, err := t.db.Get(prefixedKey) if err != nil { return nil, err } - - return node, nil -} - -func (t *TrieDB) getNode( - merkleValue codec.MerkleValue, -) (node codec.EncodedNode, err error) { - switch n := merkleValue.(type) { - case codec.InlineNode: - reader := bytes.NewReader(n) - return codec.Decode(reader) - case codec.HashedNode: - encodedNode, err := t.db.Get(n[:]) - if err != nil { - return nil, err - } - t.recordAccess(encodedNodeAccess{hash: t.rootHash, encodedNode: encodedNode}) - - reader := bytes.NewReader(encodedNode) - return codec.Decode(reader) - default: // should never happen - panic("unreachable") + if value == nil { + return nil, fmt.Errorf("incomplete database: %v", hash) } + t.recordAccess(ValueAccess[H]{Hash: t.rootHash, Value: value, FullKey: prefix.Key}) + return value, nil } // Remove removes the given key from the trie -func (t *TrieDB) remove(keyNibbles []byte) error { - var oldValue nodeValue +func (t *TrieDB[H, Hasher]) remove(keyNibbles nibbles.Nibbles) error { + var oldValue nodeValue[H] rootHandle := t.rootHandle - removeResult, err := t.removeAt(rootHandle, keyNibbles, &oldValue) + removeResult, err := t.removeAt(rootHandle, &keyNibbles, &oldValue) if err != nil { return err } if removeResult != nil { t.rootHandle = inMemory(removeResult.handle) } else { - t.rootHandle = persisted(hashedNullNode) + hashedNullNode := (*new(Hasher)).Hash([]byte{0}) + t.rootHandle = persisted[H]{hashedNullNode} t.rootHash = hashedNullNode } @@ -237,16 +249,15 @@ func (t *TrieDB) remove(keyNibbles []byte) error { } // Delete deletes the given key from the trie -func (t *TrieDB) Delete(key []byte) error { - keyNibbles := nibbles.KeyLEToNibbles(key) - return t.remove(keyNibbles) +func (t *TrieDB[H, Hasher]) Delete(key []byte) error { + return t.remove(nibbles.NewNibbles(key)) } // insert inserts the node and update the rootHandle -func (t *TrieDB) insert(keyNibbles, value []byte) error { - var oldValue nodeValue +func (t *TrieDB[H, Hasher]) 
insert(keyNibbles nibbles.Nibbles, value []byte) error { + var oldValue nodeValue[H] rootHandle := t.rootHandle - newHandle, _, err := t.insertAt(rootHandle, keyNibbles, value, &oldValue) + newHandle, _, err := t.insertAt(rootHandle, &keyNibbles, value, &oldValue) if err != nil { return err } @@ -256,31 +267,30 @@ func (t *TrieDB) insert(keyNibbles, value []byte) error { } // Put inserts the given key / value pair into the trie -func (t *TrieDB) Put(key, value []byte) error { - keyNibbles := nibbles.KeyLEToNibbles(key) - return t.insert(keyNibbles, value) +func (t *TrieDB[H, Hasher]) Put(key, value []byte) error { + return t.insert(nibbles.NewNibbles(key), value) } // insertAt inserts the given key / value pair into the node referenced by the // node handle `handle` -func (t *TrieDB) insertAt( +func (t *TrieDB[H, Hasher]) insertAt( handle NodeHandle, - keyNibbles, + keyNibbles *nibbles.Nibbles, value []byte, - oldValue *nodeValue, + oldValue *nodeValue[H], ) (strgHandle storageHandle, changed bool, err error) { switch h := handle.(type) { case inMemory: strgHandle = storageHandle(h) - case persisted: - strgHandle, err = t.lookupNode(common.Hash(h)) + case persisted[H]: + strgHandle, err = t.lookupNode(h.hash, keyNibbles.Left()) if err != nil { return -1, false, err } } stored := t.storage.destroy(strgHandle) - result, err := t.inspect(stored, keyNibbles, func(node Node, keyNibbles []byte) (action, error) { + result, err := t.inspect(stored, keyNibbles, func(node Node, keyNibbles *nibbles.Nibbles) (action, error) { return t.insertInspector(node, keyNibbles, value, oldValue) }) if err != nil { @@ -299,24 +309,24 @@ type RemoveAtResult struct { changed bool } -func (t *TrieDB) removeAt( +func (t *TrieDB[H, Hasher]) removeAt( handle NodeHandle, - keyNibbles []byte, - oldValue *nodeValue, + keyNibbles *nibbles.Nibbles, + oldValue *nodeValue[H], ) (*RemoveAtResult, error) { var stored StoredNode switch h := handle.(type) { case inMemory: stored = 
t.storage.destroy(storageHandle(h)) - case persisted: - handle, err := t.lookupNode(common.Hash(h)) + case persisted[H]: + handle, err := t.lookupNode(h.hash, keyNibbles.Left()) if err != nil { return nil, err } stored = t.storage.destroy(handle) } - result, err := t.inspect(stored, keyNibbles, func(node Node, keyNibbles []byte) (action, error) { + result, err := t.inspect(stored, keyNibbles, func(node Node, keyNibbles *nibbles.Nibbles) (action, error) { return t.removeInspector(node, keyNibbles, oldValue) }) if err != nil { @@ -340,11 +350,13 @@ type InspectResult struct { // inspect inspects the given node `stored` and calls the `inspector` function // then returns the new node and a boolean indicating if the node has changed -func (t *TrieDB) inspect( +func (t *TrieDB[H, Hasher]) inspect( stored StoredNode, - key []byte, - inspector func(Node, []byte) (action, error), + key *nibbles.Nibbles, + inspector func(Node, *nibbles.Nibbles) (action, error), ) (*InspectResult, error) { + // shallow copy since key will change offset through inspector + currentKey := *key switch n := stored.(type) { case NewStoredNode: res, err := inspector(n.node, key) @@ -361,19 +373,21 @@ func (t *TrieDB) inspect( default: panic("unreachable") } - case CachedStoredNode: + case CachedStoredNode[H]: res, err := inspector(n.node, key) if err != nil { return nil, err } switch a := res.(type) { case restoreNode: - return &InspectResult{CachedStoredNode{a.node, n.hash}, false}, nil + return &InspectResult{CachedStoredNode[H]{a.node, n.hash}, false}, nil case replaceNode: - t.deathRow[string(n.hash.ToBytes())] = nil + prefixedKey := append(currentKey.Left().JoinedBytes(), n.hash.Bytes()...) + t.deathRow[string(prefixedKey)] = nil return &InspectResult{NewStoredNode(a), true}, nil case deleteNode: - t.deathRow[string(n.hash.ToBytes())] = nil + prefixedKey := append(currentKey.Left().JoinedBytes(), n.hash.Bytes()...) 
+ t.deathRow[string(prefixedKey)] = nil return nil, nil default: panic("unreachable") @@ -385,7 +399,7 @@ func (t *TrieDB) inspect( // fix is a helper function to reorganise the nodes after deleting a branch. // For example, if the node we are deleting is the only child for a branch node, we can transform that branch in a leaf -func (t *TrieDB) fix(branch Branch) (Node, error) { +func (t *TrieDB[H, Hasher]) fix(branch Branch[H], key *nibbles.Nibbles) (Node, error) { usedIndex := make([]byte, 0) for i := 0; i < codec.ChildrenCapacity; i++ { @@ -403,7 +417,7 @@ func (t *TrieDB) fix(branch Branch) (Node, error) { } // Make it a leaf - return Leaf{branch.partialKey, branch.value}, nil + return Leaf[H]{branch.partialKey, branch.value}, nil } else if len(usedIndex) == 1 && branch.value == nil { // Only one onward node. use child instead idx := usedIndex[0] @@ -411,12 +425,48 @@ func (t *TrieDB) fix(branch Branch) (Node, error) { child := branch.children[idx] branch.children[idx] = nil + key2 := key.Clone() + key2.Advance(uint(len(branch.partialKey.Data))* + nibbles.NibblesPerByte - branch.partialKey.Offset) + + var ( + start []byte + allocStart []byte + prefixEnd *byte + ) + prefix := key2.Left() + switch prefix.Padded { + case nil: + start = prefix.Key + allocStart = nil + pushed := nibbles.PushAtLeft(0, idx, 0) + prefixEnd = &pushed + default: + so := prefix.Key + so = append(so, nibbles.PadLeft(*prefix.Padded)|idx) + start = prefix.Key + allocStart = so + prefixEnd = nil + } + var childPrefix nibbles.Prefix + if allocStart != nil { + childPrefix = nibbles.Prefix{ + Key: allocStart, + Padded: prefixEnd, + } + } else { + childPrefix = nibbles.Prefix{ + Key: start, + Padded: prefixEnd, + } + } + var stored StoredNode switch n := child.(type) { case inMemory: stored = t.storage.destroy(storageHandle(n)) - case persisted: - handle, err := t.lookupNode(common.Hash(n)) + case persisted[H]: + handle, err := t.lookupNode(n.hash, childPrefix) if err != nil { return nil, 
fmt.Errorf("looking up node: %w", err) } @@ -427,18 +477,21 @@ func (t *TrieDB) fix(branch Branch) (Node, error) { switch n := stored.(type) { case NewStoredNode: childNode = n.node - case CachedStoredNode: - t.deathRow[string(n.hash.ToBytes())] = nil + case CachedStoredNode[H]: + prefixedKey := append(childPrefix.JoinedBytes(), n.hash.Bytes()...) + t.deathRow[string(prefixedKey)] = nil childNode = n.node } - combinedKey := bytes.Join([][]byte{branch.partialKey, {idx}, childNode.getPartialKey()}, nil) - switch n := childNode.(type) { - case Leaf: - return Leaf{combinedKey, n.value}, nil - case Branch: - return Branch{combinedKey, n.children, n.value}, nil + case Leaf[H]: + combinedKey := combineKey(branch.partialKey, nodeKey{Offset: nibbles.NibblesPerByte - 1, Data: []byte{idx}}) + combinedKey = combineKey(combinedKey, n.partialKey) + return Leaf[H]{combinedKey, n.value}, nil + case Branch[H]: + combinedKey := combineKey(branch.partialKey, nodeKey{Offset: nibbles.NibblesPerByte - 1, Data: []byte{idx}}) + combinedKey = combineKey(combinedKey, n.partialKey) + return Branch[H]{combinedKey, n.children, n.value}, nil default: panic("unreachable") } @@ -448,53 +501,84 @@ func (t *TrieDB) fix(branch Branch) (Node, error) { } } +func combineKey(start nodeKey, end nodeKey) nodeKey { + if !(start.Offset < nibbles.NibblesPerByte) { + panic("invalid start offset") + } + if !(end.Offset < nibbles.NibblesPerByte) { + panic("invalid end offset") + } + finalOffset := (start.Offset + end.Offset) % nibbles.NibblesPerByte + _ = start.ShiftKey(finalOffset) + var st uint + if end.Offset > 0 { + sl := len(start.Data) + start.Data[sl-1] |= nibbles.PadRight(end.Data[0]) + st = 1 + } else { + st = 0 + } + for i := st; i < uint(len(end.Data)); i++ { + start.Data = append(start.Data, end.Data[i]) + } + return start +} + // removeInspector removes the key node from the given node `stored` -func (t *TrieDB) removeInspector(stored Node, keyNibbles []byte, oldValue *nodeValue) (action, error) { 
- partial := keyNibbles +func (t *TrieDB[H, Hasher]) removeInspector( + stored Node, keyNibbles *nibbles.Nibbles, oldValue *nodeValue[H], +) (action, error) { + partial := keyNibbles.Clone() switch n := stored.(type) { case Empty: return deleteNode{}, nil - case Leaf: - if bytes.Equal(n.partialKey, partial) { - + case Leaf[H]: + existingKey := nibbles.NewNibblesFromNodeKey(n.partialKey) + if existingKey.Equal(partial) { // This is the node we are looking for so we delete it - t.replaceOldValue(oldValue, n.value, partial) + keyVal := keyNibbles.Clone() + keyVal.Advance(existingKey.Len()) + t.replaceOldValue(oldValue, n.value, keyVal.Left()) return deleteNode{}, nil } // Wrong partial, so we return the node as is return restoreNode{n}, nil - case Branch: - if len(partial) == 0 { + case Branch[H]: + if partial.Len() == 0 { if n.value == nil { // Nothing to delete since the branch doesn't contains a value return restoreNode{n}, nil } // The branch contains the value so we delete it - t.replaceOldValue(oldValue, n.value, partial) - newNode, err := t.fix(Branch{n.partialKey, n.children, nil}) + t.replaceOldValue(oldValue, n.value, keyNibbles.Left()) + newNode, err := t.fix(Branch[H]{n.partialKey, n.children, nil}, keyNibbles) if err != nil { return nil, err } return replaceNode{newNode}, nil } - common := nibbles.CommonPrefix(n.partialKey, partial) - existingLength := len(n.partialKey) + existingKey := nibbles.NewNibblesFromNodeKey(n.partialKey) - if common == existingLength && common == len(partial) { + common := existingKey.CommonPrefix(partial) + existingLength := existingKey.Len() + + if common == existingLength && common == partial.Len() { // Replace value if n.value != nil { - t.replaceOldValue(oldValue, n.value, partial) - newNode, err := t.fix(Branch{n.partialKey, n.children, nil}) + keyVal := keyNibbles.Clone() + keyVal.Advance(existingLength) + t.replaceOldValue(oldValue, n.value, keyVal.Left()) + newNode, err := t.fix(Branch[H]{n.partialKey, n.children, nil}, 
keyNibbles) return replaceNode{newNode}, err } - return restoreNode{Branch{n.partialKey, n.children, nil}}, nil + return restoreNode{Branch[H]{n.partialKey, n.children, nil}}, nil } else if common < existingLength { return restoreNode{n}, nil } // Check children - idx := partial[common] + idx := partial.At(common) // take child and replace it to nil child := n.children[idx] n.children[idx] = nil @@ -502,8 +586,10 @@ func (t *TrieDB) removeInspector(stored Node, keyNibbles []byte, oldValue *nodeV if child == nil { return restoreNode{n}, nil } + prefix := keyNibbles + keyNibbles.Advance(common + 1) - removeAtResult, err := t.removeAt(child, partial[len(n.partialKey)+1:], oldValue) + removeAtResult, err := t.removeAt(child, keyNibbles, oldValue) if err != nil { return nil, err } @@ -516,7 +602,7 @@ func (t *TrieDB) removeInspector(stored Node, keyNibbles []byte, oldValue *nodeV return restoreNode{n}, nil } - newNode, err := t.fix(n) + newNode, err := t.fix(n, prefix) if err != nil { return nil, err } @@ -527,45 +613,50 @@ func (t *TrieDB) removeInspector(stored Node, keyNibbles []byte, oldValue *nodeV } // insertInspector inserts the new key / value pair into the given node `stored` -func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, oldValue *nodeValue) (action, error) { - partial := keyNibbles +func (t *TrieDB[H, Hasher]) insertInspector( + stored Node, keyNibbles *nibbles.Nibbles, value []byte, oldValue *nodeValue[H], +) (action, error) { + partial := keyNibbles.Clone() switch n := stored.(type) { case Empty: // If the node is empty we have to replace it with a leaf node with the // new value - value := NewValue(value, t.version.MaxInlineValue()) - return replaceNode{node: Leaf{partialKey: partial, value: value}}, nil - case Leaf: - existingKey := n.partialKey - common := nibbles.CommonPrefix(partial, existingKey) - - if common == len(existingKey) && common == len(partial) { + value := NewValue[H](value, t.version.MaxInlineValue()) + pnk 
:= partial.NodeKey() + return replaceNode{node: Leaf[H]{partialKey: pnk, value: value}}, nil + case Leaf[H]: + existingKey := nibbles.NewNibblesFromNodeKey(n.partialKey) + common := existingKey.CommonPrefix(partial) + + if common == existingKey.Len() && common == partial.Len() { // We are trying to insert a value in the same leaf so we just need // to replace the value - value := NewValue(value, t.version.MaxInlineValue()) + value := NewValue[H](value, t.version.MaxInlineValue()) unchanged := n.value.equal(value) - t.replaceOldValue(oldValue, n.value, partial) - leaf := Leaf{partialKey: n.partialKey, value: value} + keyVal := keyNibbles.Clone() + keyVal.Advance(existingKey.Len()) + t.replaceOldValue(oldValue, n.value, keyVal.Left()) + leaf := Leaf[H]{partialKey: n.partialKey, value: value} if unchanged { // If the value didn't change we can restore this leaf previously // taken from storage return restoreNode{leaf}, nil } return replaceNode{leaf}, nil - } else if common < len(existingKey) { + } else if common < existingKey.Len() { // If the common prefix is less than this leaf's key then we need to // create a branch node. 
Then add this leaf and the new value to the // branch var children [codec.ChildrenCapacity]NodeHandle - idx := existingKey[common] + idx := existingKey.At(common) // Modify the existing leaf partial key and add it as a child - newLeaf := Leaf{existingKey[common+1:], n.value} + newLeaf := Leaf[H]{existingKey.Mid(common + 1).NodeKey(), n.value} children[idx] = inMemory(t.storage.alloc(NewStoredNode{node: newLeaf})) - branch := Branch{ - partialKey: partial[:common], + branch := Branch[H]{ + partialKey: partial.NodeKeyRange(common), children: children, value: nil, } @@ -579,8 +670,8 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o } else { // we have a common prefix but the new key is longer than the existing // then we turn this leaf into a branch and add the new leaf as a child - var branch Node = Branch{ - partialKey: n.partialKey, + var branch Node = Branch[H]{ + partialKey: existingKey.NodeKey(), children: [codec.ChildrenCapacity]NodeHandle{}, value: n.value, } @@ -593,62 +684,64 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o branch = action.getNode() return replaceNode{branch}, nil } - case Branch: - existingKey := n.partialKey - common := nibbles.CommonPrefix(partial, existingKey) + case Branch[H]: + existingKey := nibbles.NewNibblesFromNodeKey(n.partialKey) + common := partial.CommonPrefix(existingKey) - if common == len(existingKey) && common == len(partial) { + if common == existingKey.Len() && common == partial.Len() { // We are trying to insert a value in the same branch so we just need // to replace the value - value := NewValue(value, t.version.MaxInlineValue()) + value := NewValue[H](value, t.version.MaxInlineValue()) var unchanged bool if n.value != nil { unchanged = n.value.equal(value) } - branch := Branch{existingKey, n.children, value} + branch := Branch[H]{existingKey.NodeKey(), n.children, value} - t.replaceOldValue(oldValue, n.value, partial) + keyVal := keyNibbles.Clone() + 
keyVal.Advance(existingKey.Len()) + t.replaceOldValue(oldValue, n.value, keyVal.Left()) if unchanged { // If the value didn't change we can restore this leaf previously // taken from storage return restoreNode{branch}, nil } return replaceNode{branch}, nil - } else if common < len(existingKey) { + } else if common < existingKey.Len() { // If the common prefix is less than this branch's key then we need to // create a branch node in between. // Then add this branch and the new value to the new branch // So we take this branch and we add it as a child of the new one - branchPartial := existingKey[common+1:] - lowerBranch := Branch{branchPartial, n.children, n.value} + branchPartial := existingKey.Mid(common + 1).NodeKey() + lowerBranch := Branch[H]{branchPartial, n.children, n.value} allocStorage := t.storage.alloc(NewStoredNode{node: lowerBranch}) children := [codec.ChildrenCapacity]NodeHandle{} - ix := existingKey[common] + ix := existingKey.At(common) children[ix] = inMemory(allocStorage) - value := NewValue(value, t.version.MaxInlineValue()) + value := NewValue[H](value, t.version.MaxInlineValue()) - if len(partial)-common == 0 { + if partial.Len()-common == 0 { // The value should be part of the branch return replaceNode{ - Branch{ - existingKey[:common], + Branch[H]{ + existingKey.NodeKeyRange(common), children, value, }, }, nil } else { // Value is in a leaf under the branch so we have to create it - storedLeaf := Leaf{partial[common+1:], value} + storedLeaf := Leaf[H]{partial.Mid(common + 1).NodeKey(), value} leaf := t.storage.alloc(NewStoredNode{node: storedLeaf}) - ix = partial[common] + ix = partial.At(common) children[ix] = inMemory(leaf) return replaceNode{ - Branch{ - existingKey[:common], + Branch[H]{ + existingKey.NodeKeyRange(common), children, nil, }, @@ -656,10 +749,11 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o } } else { // append after common == existing_key and partial > common - idx := partial[common] - 
keyNibbles = keyNibbles[common+1:] + idx := partial.At(common) + keyNibbles.Advance(common + 1) child := n.children[idx] if child != nil { + n.children[idx] = nil // We have to add the new value to the child newChild, changed, err := t.insertAt(child, keyNibbles, value, oldValue) if err != nil { @@ -668,8 +762,8 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o n.children[idx] = inMemory(newChild) if !changed { // Our branch is untouched so we can restore it - branch := Branch{ - existingKey, + branch := Branch[H]{ + existingKey.NodeKey(), n.children, n.value, } @@ -678,12 +772,12 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o } } else { // Original has nothing here so we have to create a new leaf - value := NewValue(value, t.version.MaxInlineValue()) - leaf := t.storage.alloc(NewStoredNode{node: Leaf{keyNibbles, value}}) + value := NewValue[H](value, t.version.MaxInlineValue()) + leaf := t.storage.alloc(NewStoredNode{node: Leaf[H]{keyNibbles.NodeKey(), value}}) n.children[idx] = inMemory(leaf) } - return replaceNode{Branch{ - existingKey, + return replaceNode{Branch[H]{ + existingKey.NodeKey(), n.children, n.value, }}, nil @@ -693,16 +787,22 @@ func (t *TrieDB) insertInspector(stored Node, keyNibbles []byte, value []byte, o } } -func (t *TrieDB) replaceOldValue( - oldValue *nodeValue, - storedValue nodeValue, - prefix []byte, +func (t *TrieDB[H, Hasher]) replaceOldValue( + oldValue *nodeValue[H], + storedValue nodeValue[H], + prefix nibbles.Prefix, ) { switch oldv := storedValue.(type) { - case valueRef, newValueRef: + case valueRef[H]: hash := oldv.getHash() - if hash != common.EmptyHash { - prefixedKey := append(prefix, oldv.getHash().ToBytes()...) + if hash != (*new(H)) { + prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) 
+ t.deathRow[string(prefixedKey)] = nil + } + case newValueRef[H]: + hash := oldv.getHash() + if hash != (*new(H)) { + prefixedKey := append(prefix.JoinedBytes(), hash.Bytes()...) t.deathRow[string(prefixedKey)] = nil } } @@ -711,27 +811,28 @@ func (t *TrieDB) replaceOldValue( // lookup node in DB and add it in storage, return storage handle // TODO: implement cache to improve performance -func (t *TrieDB) lookupNode(hash common.Hash) (storageHandle, error) { - encodedNode, err := t.db.Get(hash[:]) +func (t *TrieDB[H, Hasher]) lookupNode(hash H, key nibbles.Prefix) (storageHandle, error) { + prefixedKey := append(key.JoinedBytes(), hash.Bytes()...) + encodedNode, err := t.db.Get(prefixedKey) if err != nil { return -1, ErrIncompleteDB } - t.recordAccess(encodedNodeAccess{hash: t.rootHash, encodedNode: encodedNode}) + t.recordAccess(EncodedNodeAccess[H]{Hash: t.rootHash, EncodedNode: encodedNode}) - node, err := newNodeFromEncoded(hash, encodedNode, t.storage) + node, err := newNodeFromEncoded[H](hash, encodedNode, t.storage) if err != nil { return -1, err } - return t.storage.alloc(CachedStoredNode{ + return t.storage.alloc(CachedStoredNode[H]{ node: node, hash: hash, }), nil } // commit writes all trie changes to the underlying db -func (t *TrieDB) commit() error { +func (t *TrieDB[H, Hasher]) commit() error { logger.Debug("Committing trie changes to db") logger.Debugf("%d nodes to remove from db", len(t.deathRow)) @@ -754,7 +855,7 @@ func (t *TrieDB) commit() error { var handle storageHandle switch h := t.rootHandle.(type) { - case persisted: + case persisted[H]: return nil // nothing to commit since the root is already in db case inMemory: handle = storageHandle(h) @@ -763,36 +864,36 @@ func (t *TrieDB) commit() error { switch stored := t.storage.destroy(handle).(type) { case NewStoredNode: // Reconstructs the full key for root node - var k []byte + var fullKey *nibbles.NibbleSlice + if pk := stored.getNode().getPartialKey(); pk != nil { + fk := 
nibbles.NewNibblesFromNodeKey(*pk) + ns := nibbles.NewNibbleSliceFromNibbles(fk) + fullKey = &ns + } - encodedNode, err := newEncodedNode( - stored.node, - func(node nodeToEncode, partialKey []byte, childIndex *byte) (ChildReference, error) { - k = append(k, partialKey...) - mov := len(partialKey) - if childIndex != nil { - k = append(k, *childIndex) - mov += 1 - } + var k nibbles.NibbleSlice + encodedNode, err := newEncodedNode[H]( + stored.node, + func(node nodeToEncode, partialKey *nibbles.Nibbles, childIndex *byte) (ChildReference, error) { + mov := k.AppendOptionalSliceAndNibble(partialKey, childIndex) switch n := node.(type) { case newNodeToEncode: - hash := common.MustBlake2bHash(n.value) - prefixedKey := append(n.partialKey, hash.ToBytes()...) + hash := (*new(Hasher)).Hash(n.value) + prefixedKey := append(k.Prefix().JoinedBytes(), hash.Bytes()...) err := dbBatch.Put(prefixedKey, n.value) if err != nil { return nil, err } - - k = k[:mov] - return HashChildReference(hash), nil + k.DropLasts(mov) + return HashChildReference[H]{hash}, nil case trieNodeToEncode: - result, err := t.commitChild(dbBatch, n.child, k) + result, err := t.commitChild(dbBatch, n.child, &k) if err != nil { return nil, err } - k = k[:mov] + k.DropLasts(mov) return result, nil default: panic("unreachable") @@ -804,21 +905,24 @@ func (t *TrieDB) commit() error { return err } - hash := common.MustBlake2bHash(encodedNode) - err = dbBatch.Put(hash[:], encodedNode) + hash := (*new(Hasher)).Hash(encodedNode) + err = dbBatch.Put(hash.Bytes(), encodedNode) if err != nil { return err } t.rootHash = hash - t.rootHandle = persisted(t.rootHash) + t.rootHandle = persisted[H]{t.rootHash} + + // TODO: use fullKey when caching these nodes + _ = fullKey // Flush all db changes return dbBatch.Flush() - case CachedStoredNode: + case CachedStoredNode[H]: t.rootHash = stored.hash t.rootHandle = inMemory( - t.storage.alloc(CachedStoredNode{stored.node, stored.hash}), + 
t.storage.alloc(CachedStoredNode[H]{stored.node, stored.hash}), ) return nil default: @@ -827,72 +931,73 @@ func (t *TrieDB) commit() error { } // Commit a node by hashing it and writing it to the db. -func (t *TrieDB) commitChild( +func (t *TrieDB[H, Hasher]) commitChild( dbBatch database.Batch, child NodeHandle, - prefixKey []byte, + prefixKey *nibbles.NibbleSlice, ) (ChildReference, error) { switch nh := child.(type) { - case persisted: + case persisted[H]: // Already persisted we have to do nothing - return HashChildReference(nh), nil + return HashChildReference[H]{nh.hash}, nil case inMemory: stored := t.storage.destroy(storageHandle(nh)) switch storedNode := stored.(type) { - case CachedStoredNode: - return HashChildReference(storedNode.hash), nil + case CachedStoredNode[H]: + return HashChildReference[H]{storedNode.hash}, nil case NewStoredNode: - // We have to store the node in the DB - commitChildFunc := func(node nodeToEncode, partialKey []byte, childIndex *byte) (ChildReference, error) { - prefixKey = append(prefixKey, partialKey...) - mov := len(partialKey) - if childIndex != nil { - prefixKey = append(prefixKey, *childIndex) - mov += 1 - } + var fullKey *nibbles.NibbleSlice + prefix := prefixKey.Clone() + if partial := stored.getNode().getPartialKey(); partial != nil { + fk := nibbles.NewNibblesFromNodeKey(*partial) + prefix.AppendPartial(fk.RightPartial()) + } + fullKey = &prefix + // TODO: caching uses fullKey + _ = fullKey + // We have to store the node in the DB + commitChildFunc := func(node nodeToEncode, partialKey *nibbles.Nibbles, childIndex *byte) (ChildReference, error) { + mov := prefixKey.AppendOptionalSliceAndNibble(partialKey, childIndex) switch n := node.(type) { case newNodeToEncode: - hash := common.MustBlake2bHash(n.value) - prefixedKey := append(n.partialKey, hash.ToBytes()...) + hash := (*new(Hasher)).Hash(n.value) + prefixedKey := append(prefixKey.Prefix().JoinedBytes(), hash.Bytes()...) 
err := dbBatch.Put(prefixedKey, n.value) if err != nil { panic("inserting in db") } - if t.cache != nil { - t.cache.SetValue(n.partialKey, n.value) - } - - prefixKey = prefixKey[:mov] - return HashChildReference(hash), nil + prefixKey.DropLasts(mov) + return HashChildReference[H]{hash}, nil case trieNodeToEncode: result, err := t.commitChild(dbBatch, n.child, prefixKey) if err != nil { return nil, err } - prefixKey = prefixKey[:mov] + prefixKey.DropLasts(mov) return result, nil default: panic("unreachable") } } - encoded, err := newEncodedNode(storedNode.node, commitChildFunc) + encoded, err := newEncodedNode[H](storedNode.node, commitChildFunc) if err != nil { panic("encoding node") } // Not inlined node - if len(encoded) >= common.HashLength { - hash := common.MustBlake2bHash(encoded) - err := dbBatch.Put(hash[:], encoded) + if len(encoded) >= (*new(H)).Length() { + hash := (*new(Hasher)).Hash(encoded) + prefixedKey := append(prefixKey.Prefix().JoinedBytes(), hash.Bytes()...) + err := dbBatch.Put(prefixedKey, encoded) if err != nil { return nil, err } - return HashChildReference(hash), nil + return HashChildReference[H]{hash}, nil } else { return InlineChildReference(encoded), nil } @@ -904,18 +1009,12 @@ func (t *TrieDB) commitChild( } } -func (t *TrieDB) Iter() trie.TrieIterator { - return NewTrieDBIterator(t) -} - -func (t *TrieDB) PrefixedIter(prefix []byte) trie.TrieIterator { - return NewPrefixedTrieDBIterator(t, prefix) -} - -func (t *TrieDB) recordAccess(access trieAccess) { +func (t *TrieDB[H, Hasher]) recordAccess(access TrieAccess) { if t.recorder != nil { - t.recorder.record(access) + t.recorder.Record(access) } } -var _ trie.TrieRead = (*TrieDB)(nil) +func (t *TrieDB[H, Hasher]) GetHash(key []byte) (*H, error) { + panic("unimpl") +} diff --git a/pkg/trie/triedb/triedb_iterator.go b/pkg/trie/triedb/triedb_iterator.go deleted file mode 100644 index 9172f03e1d..0000000000 --- a/pkg/trie/triedb/triedb_iterator.go +++ /dev/null @@ -1,142 +0,0 @@ -// 
Copyright 2024 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package triedb - -import ( - "bytes" - - "github.com/ChainSafe/gossamer/pkg/trie" - nibbles "github.com/ChainSafe/gossamer/pkg/trie/codec" - "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" -) - -type iteratorState struct { - parentFullKey []byte // key of the parent node of the actual node - node codec.EncodedNode // actual node -} - -// fullKeyNibbles return the full key of the node contained in this state -// child is the child where the node is stored in the parent node -func (s *iteratorState) fullKeyNibbles(child *int) []byte { - fullKey := bytes.Join([][]byte{s.parentFullKey, s.node.GetPartialKey()}, nil) - if child != nil { - return bytes.Join([][]byte{fullKey, {byte(*child)}}, nil) - } - - return nibbles.NibblesToKeyLE(fullKey) -} - -type TrieDBIterator struct { - db *TrieDB // trie to iterate over - nodeStack []*iteratorState // Pending nodes to visit -} - -func NewTrieDBIterator(trie *TrieDB) *TrieDBIterator { - rootNode, err := trie.getRootNode() - if err != nil { - panic("trying to create trie iterator with incomplete trie DB") - } - return &TrieDBIterator{ - db: trie, - nodeStack: []*iteratorState{ - { - node: rootNode, - }, - }, - } -} - -func NewPrefixedTrieDBIterator(trie *TrieDB, prefix []byte) *TrieDBIterator { - nodeAtPrefix, err := trie.getNodeAt(prefix) - if err != nil { - panic("trying to create trie iterator with incomplete trie DB") - } - - return &TrieDBIterator{ - db: trie, - nodeStack: []*iteratorState{ - { - parentFullKey: prefix[:len(nodeAtPrefix.GetPartialKey())-1], - node: nodeAtPrefix, - }, - }, - } -} - -// nextToVisit sets the next node to visit in the iterator -func (i *TrieDBIterator) nextToVisit(parentKey []byte, node codec.EncodedNode) { - i.nodeStack = append(i.nodeStack, &iteratorState{ - parentFullKey: parentKey, - node: node, - }) -} - -// nextState pops the next node to visit from the stack -// warn: this function does not check if 
the node stack is empty -// this check should be made by the caller -func (i *TrieDBIterator) nextState() *iteratorState { - currentState := i.nodeStack[len(i.nodeStack)-1] - i.nodeStack = i.nodeStack[:len(i.nodeStack)-1] - return currentState -} - -func (i *TrieDBIterator) NextEntry() *trie.Entry { - for len(i.nodeStack) > 0 { - currentState := i.nextState() - currentNode := currentState.node - - switch n := currentNode.(type) { - case codec.Leaf: - key := currentState.fullKeyNibbles(nil) - value := i.db.Get(key) - return &trie.Entry{Key: key, Value: value} - case codec.Branch: - // Reverse iterate over children because we are using a LIFO stack - // and we want to visit the leftmost child first - for idx := len(n.Children) - 1; idx >= 0; idx-- { - child := n.Children[idx] - if child != nil { - childNode, err := i.db.getNode(child) - if err != nil { - panic(err) - } - i.nextToVisit(currentState.fullKeyNibbles(&idx), childNode) - } - } - if n.GetValue() != nil { - key := currentState.fullKeyNibbles(nil) - value := i.db.Get(key) - return &trie.Entry{Key: key, Value: value} - } - } - } - - return nil -} - -// NextKey performs a depth-first search on the trie and returns the next key -// based on the current state of the iterator. 
-func (i *TrieDBIterator) NextKey() []byte { - entry := i.NextEntry() - if entry != nil { - return entry.Key - } - return nil -} - -func (i *TrieDBIterator) NextKeyFunc(predicate func(nextKey []byte) bool) (nextKey []byte) { - for entry := i.NextEntry(); entry != nil; entry = i.NextEntry() { - if predicate(entry.Key) { - return entry.Key - } - } - return nil -} - -func (i *TrieDBIterator) Seek(targetKey []byte) { - for key := i.NextKey(); bytes.Compare(key, targetKey) < 0; key = i.NextKey() { - } -} - -var _ trie.TrieIterator = (*TrieDBIterator)(nil) diff --git a/pkg/trie/triedb/triedb_iterator_test.go b/pkg/trie/triedb/triedb_iterator_test.go index aa89caca9e..fc050bcc05 100644 --- a/pkg/trie/triedb/triedb_iterator_test.go +++ b/pkg/trie/triedb/triedb_iterator_test.go @@ -6,6 +6,8 @@ package triedb import ( "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/inmemory" "github.com/stretchr/testify/assert" @@ -29,36 +31,60 @@ func TestIterator(t *testing.T) { for k, v := range entries { inMemoryTrie.Put([]byte(k), v) } - err := inMemoryTrie.WriteDirty(db) assert.NoError(t, err) root, err := inMemoryTrie.Hash() assert.NoError(t, err) - trieDB := NewTrieDB(root, db) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trieDB := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) + + for k, v := range entries { + err := trieDB.Put([]byte(k), v) + assert.NoError(t, err) + } + assert.NoError(t, trieDB.commit()) + + // check that the root hashes are the same + assert.Equal(t, root.ToBytes(), trieDB.rootHash.Bytes()) + t.Run("iterate_over_all_entries", func(t *testing.T) { - iter := NewTrieDBIterator(trieDB) + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) expected := inMemoryTrie.NextKey([]byte{}) i := 0 - for key := iter.NextKey(); key != nil; key = iter.NextKey() { - 
assert.Equal(t, expected, key) + for { + item, err := iter.NextItem() + assert.NoError(t, err) + if item == nil { + break + } + assert.Equal(t, expected, item.Key) expected = inMemoryTrie.NextKey(expected) i++ } - assert.Equal(t, len(entries), i) }) - t.Run("iterate_from_given_key", func(t *testing.T) { - iter := NewTrieDBIterator(trieDB) + t.Run("iterate_after_seeking", func(t *testing.T) { + iter, err := newRawIterator(trieDB) + assert.NoError(t, err) - iter.Seek([]byte("not")) + found, err := iter.seek([]byte("not"), true) + assert.NoError(t, err) + assert.True(t, found) expected := inMemoryTrie.NextKey([]byte("not")) - actual := iter.NextKey() + actual, err := iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, actual) - assert.Equal(t, expected, actual) + assert.Equal(t, []byte("not"), actual.Key) + actual, err = iter.NextItem() + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected, actual.Key) }) } diff --git a/pkg/trie/triedb/triedb_test.go b/pkg/trie/triedb/triedb_test.go index 7a46373a3a..6d47106770 100644 --- a/pkg/trie/triedb/triedb_test.go +++ b/pkg/trie/triedb/triedb_test.go @@ -6,9 +6,13 @@ package triedb import ( "testing" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/triedb/codec" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibbles" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestInsertions(t *testing.T) { @@ -16,20 +20,21 @@ func TestInsertions(t *testing.T) { testCases := map[string]struct { trieEntries []trie.Entry - key []byte - value []byte - stored nodeStorage + key []uint8 + value []uint8 + stored nodeStorage[hash.H256] + dontCheck bool }{ "nil_parent": { trieEntries: []trie.Entry{}, - key: []byte{1}, + key: []byte{0x01}, value: []byte("leaf"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: 
[]StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{1}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x01}, Offset: 0}, + value: inline[hash.H256]([]byte("leaf")), }, }, }, @@ -38,27 +43,27 @@ func TestInsertions(t *testing.T) { "branch_parent": { trieEntries: []trie.Entry{ { - Key: []byte{1}, + Key: []byte{0x01}, Value: []byte("branch"), }, }, - key: []byte{1, 0}, + key: []byte{0x01, 0x01}, value: []byte("leaf"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x01}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x01}}, + value: inline[hash.H256]([]byte("branch")), children: [codec.ChildrenCapacity]NodeHandle{ - inMemory(0), nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + inMemory(0), nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, }, }, }, @@ -78,28 +83,28 @@ func TestInsertions(t *testing.T) { }, key: []byte{1, 0}, value: []byte("in between branch"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Branch{ - partialKey: []byte{}, - value: inline([]byte("in between branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("in between branch")), children: [codec.ChildrenCapacity]NodeHandle{ - nil, inMemory(1), nil, nil, nil, nil, + inMemory(1), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, }, }, }, NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x01}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ 
- Branch{ - partialKey: []byte{1}, - value: inline([]byte("branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}, Offset: 0}, + value: inline[hash.H256]([]byte("branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(0), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -122,28 +127,28 @@ func TestInsertions(t *testing.T) { }, key: []byte{1}, value: []byte("top branch"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{}, - value: inline([]byte("branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("branch")), children: [codec.ChildrenCapacity]NodeHandle{ - nil, inMemory(0), nil, nil, nil, nil, nil, nil, + inMemory(0), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, }, }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("top branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("top branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(1), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -166,18 +171,18 @@ func TestInsertions(t *testing.T) { }, key: []byte{1}, value: []byte("new branch"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("new branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: 
inline[hash.H256]([]byte("new branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(0), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -186,6 +191,7 @@ func TestInsertions(t *testing.T) { }, }, }, + dontCheck: true, }, "override_branch_value_same_value": { trieEntries: []trie.Entry{ @@ -200,18 +206,18 @@ func TestInsertions(t *testing.T) { }, key: []byte{1}, value: []byte("branch"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(0), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -234,12 +240,12 @@ func TestInsertions(t *testing.T) { }, key: []byte{1, 0}, value: []byte("leaf"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(1), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -247,9 +253,9 @@ func TestInsertions(t *testing.T) { }, }, NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, }, @@ -264,16 +270,17 @@ func TestInsertions(t *testing.T) { }, key: []byte{1}, value: []byte("new leaf"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ 
NewStoredNode{ - Leaf{ - partialKey: []byte{1}, - value: inline([]byte("new leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("new leaf")), }, }, }, }, + dontCheck: true, }, "write_same_leaf_value_to_leaf_parent": { trieEntries: []trie.Entry{ @@ -284,12 +291,12 @@ func TestInsertions(t *testing.T) { }, key: []byte{1}, value: []byte("same"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{1}, - value: inline([]byte("same")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("same")), }, }, }, @@ -298,35 +305,38 @@ func TestInsertions(t *testing.T) { "write_leaf_as_divergent_child_next_to_parent_leaf": { trieEntries: []trie.Entry{ { - Key: []byte{1, 2}, + Key: []byte{0x01, 0x02}, Value: []byte("original leaf"), }, }, - key: []byte{2, 3}, + key: []byte{0x02, 0x03}, value: []byte("leaf"), - stored: nodeStorage{ + stored: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{2}, - value: inline([]byte("original leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x02}}, + value: inline[hash.H256]([]byte("original leaf")), }, }, NewStoredNode{ - Leaf{ - partialKey: []byte{3}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x03}}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{}, + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x00}, Offset: 1}, value: nil, children: [codec.ChildrenCapacity]NodeHandle{ nil, - inMemory(0), inMemory(1), - nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, + inMemory(0), + inMemory(1), + nil, + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, nil, }, }, }, @@ -339,18 +349,24 @@ func TestInsertions(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() - // Setup trie - inmemoryDB := 
NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { - assert.NoError(t, trie.insert(entry.Key, entry.Value)) + require.NoError(t, trie.Put(entry.Key, entry.Value)) } - // Add new key-value pair - err := trie.insert(testCase.key, testCase.value) - assert.NoError(t, err) + err := trie.Put(testCase.key, testCase.value) + require.NoError(t, err) + + if !testCase.dontCheck { + // Check values for keys + for _, entry := range testCase.trieEntries { + require.Equal(t, entry.Value, trie.Get(entry.Key)) + } + } + require.Equal(t, testCase.value, trie.Get(testCase.key)) // Check we have what we expect assert.Equal(t, testCase.stored.nodes, trie.storage.nodes) @@ -364,7 +380,7 @@ func TestDeletes(t *testing.T) { testCases := map[string]struct { trieEntries []trie.Entry key []byte - expected nodeStorage + expected nodeStorage[hash.H256] }{ "nil_key": { trieEntries: []trie.Entry{ @@ -373,12 +389,12 @@ func TestDeletes(t *testing.T) { Value: []byte("leaf"), }, }, - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{1}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("leaf")), }, }, }, @@ -386,7 +402,7 @@ func TestDeletes(t *testing.T) { }, "empty_trie": { key: []byte{1}, - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{nil}, }, }, @@ -398,7 +414,7 @@ func TestDeletes(t *testing.T) { }, }, key: []byte{1}, - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{nil}, }, }, @@ -414,13 +430,13 @@ func TestDeletes(t *testing.T) { }, }, key: []byte{1}, - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ nil, NewStoredNode{ - Leaf{ - partialKey: []byte{1, 0}, - 
value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1, 0}}, + value: inline[hash.H256]([]byte("leaf")), }, }, }, @@ -438,23 +454,23 @@ func TestDeletes(t *testing.T) { }, }, key: []byte{1}, - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf1")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: make([]byte, 0)}, + value: inline[hash.H256]([]byte("leaf1")), }, }, NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf2")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: make([]byte, 0)}, + value: inline[hash.H256]([]byte("leaf2")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{0x00, 0x10}, Offset: 1}, children: [codec.ChildrenCapacity]NodeHandle{ inMemory(0), inMemory(1), }, @@ -471,15 +487,15 @@ func TestDeletes(t *testing.T) { t.Parallel() // Setup trie - inmemoryDB := NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { - assert.NoError(t, trie.insert(entry.Key, entry.Value)) + assert.NoError(t, trie.Put(entry.Key, entry.Value)) } // Remove key - err := trie.remove(testCase.key) + err := trie.Delete(testCase.key) assert.NoError(t, err) // Check we have what we expect @@ -495,7 +511,7 @@ func TestInsertAfterDelete(t *testing.T) { trieEntries []trie.Entry key []byte value []byte - expected nodeStorage + expected nodeStorage[hash.H256] }{ "insert_leaf_after_delete": { trieEntries: []trie.Entry{ @@ -506,12 +522,12 @@ func TestInsertAfterDelete(t *testing.T) { }, key: []byte{1}, value: []byte("new leaf"), - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{1}, - value: inline([]byte("new 
leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("new leaf")), }, }, }, @@ -530,18 +546,18 @@ func TestInsertAfterDelete(t *testing.T) { }, key: []byte{1}, value: []byte("new branch"), - expected: nodeStorage{ + expected: nodeStorage[hash.H256]{ nodes: []StoredNode{ NewStoredNode{ - Leaf{ - partialKey: []byte{}, - value: inline([]byte("leaf")), + Leaf[hash.H256]{ + partialKey: nodeKey{Data: []byte{0}, Offset: 1}, + value: inline[hash.H256]([]byte("leaf")), }, }, NewStoredNode{ - Branch{ - partialKey: []byte{1}, - value: inline([]byte("new branch")), + Branch[hash.H256]{ + partialKey: nodeKey{Data: []byte{1}}, + value: inline[hash.H256]([]byte("new branch")), children: [codec.ChildrenCapacity]NodeHandle{ inMemory(0), }, @@ -558,19 +574,19 @@ func TestInsertAfterDelete(t *testing.T) { t.Parallel() // Setup trie - inmemoryDB := NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) for _, entry := range testCase.trieEntries { - assert.NoError(t, trie.insert(entry.Key, entry.Value)) + assert.NoError(t, trie.insert(nibbles.NewNibbles(entry.Key), entry.Value)) } // Remove key - err := trie.remove(testCase.key) + err := trie.remove(nibbles.NewNibbles(testCase.key)) assert.NoError(t, err) // Add again - err = trie.insert(testCase.key, testCase.value) + err = trie.insert(nibbles.NewNibbles(testCase.key), testCase.value) assert.NoError(t, err) // Check we have what we expect @@ -585,8 +601,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) err := trie.Put([]byte("leaf"), []byte("leafvalue")) assert.NoError(t, err) @@ -605,8 
+621,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_branch_and_inlined_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) err := trie.Put([]byte("branchleaf"), []byte("leafvalue")) assert.NoError(t, err) @@ -629,8 +645,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_branch_and_hashed_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - tr := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) err := tr.Put([]byte("branchleaf"), make([]byte, 40)) assert.NoError(t, err) @@ -654,8 +670,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_leaf_with_hashed_value", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - tr := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) err := tr.Put([]byte("leaf"), make([]byte, 40)) @@ -676,8 +692,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_leaf_with_hashed_value_then_remove_it", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - tr := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) err := tr.Put([]byte("leaf"), make([]byte, 40)) @@ -700,8 +716,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_branch_and_hashed_leaf_with_hashed_value", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - tr := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + tr := NewEmptyTrieDB[hash.H256, 
runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) err := tr.Put([]byte("branchleaf"), make([]byte, 40)) @@ -727,8 +743,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_branch_and_hashed_leaf_with_hashed_value_then_delete_it", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - tr := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + tr := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) tr.SetVersion(trie.V1) err := tr.Put([]byte("branchleaf"), make([]byte, 40)) @@ -757,8 +773,8 @@ func TestDBCommits(t *testing.T) { t.Run("commit_branch_with_leaf_then_delete_leaf", func(t *testing.T) { t.Parallel() - inmemoryDB := NewMemoryDB(EmptyNode) - trie := NewEmptyTrieDB(inmemoryDB) + inmemoryDB := NewMemoryDB[hash.H256, runtime.BlakeTwo256](EmptyNode) + trie := NewEmptyTrieDB[hash.H256, runtime.BlakeTwo256](inmemoryDB) err := trie.Put([]byte("branchleaf"), []byte("leafvalue")) assert.NoError(t, err) diff --git a/pkg/trie/triedb/util_test.go b/pkg/trie/triedb/util_test.go index f0e782a8f0..4a16848319 100644 --- a/pkg/trie/triedb/util_test.go +++ b/pkg/trie/triedb/util_test.go @@ -5,39 +5,41 @@ package triedb import ( "bytes" + "strings" "github.com/ChainSafe/gossamer/internal/database" - "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie/db" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/hash" + "github.com/stretchr/testify/assert" ) // MemoryDB is an in-memory implementation of the Database interface backed by a // map. 
It uses blake2b as hashing algorithm type MemoryDB struct { - data map[common.Hash][]byte - hashedNullNode common.Hash + data map[string][]byte + hashedNullNode string nullNodeData []byte } -func NewMemoryDB(data []byte) *MemoryDB { +func NewMemoryDB[H hash.Hash, Hasher hash.Hasher[H]](data []byte) *MemoryDB { return &MemoryDB{ - data: make(map[common.Hash][]byte), - hashedNullNode: common.MustBlake2bHash(data), + data: make(map[string][]byte), + hashedNullNode: string((*new(Hasher)).Hash(data).Bytes()), nullNodeData: data, } } -func (db *MemoryDB) emplace(key common.Hash, value []byte) { +func (db *MemoryDB) emplace(key []byte, value []byte) { if bytes.Equal(value, db.nullNodeData) { return } - db.data[key] = value + db.data[string(key)] = value } func (db *MemoryDB) Get(key []byte) ([]byte, error) { - dbKey := common.NewHash(key) - if dbKey == db.hashedNullNode { + dbKey := string(key) + if strings.Contains(dbKey, db.hashedNullNode) { return db.nullNodeData, nil } if value, has := db.data[dbKey]; has { @@ -48,13 +50,12 @@ func (db *MemoryDB) Get(key []byte) ([]byte, error) { } func (db *MemoryDB) Put(key []byte, value []byte) error { - dbKey := common.NewHash(key) - db.emplace(dbKey, value) + db.emplace(key, value) return nil } func (db *MemoryDB) Del(key []byte) error { - dbKey := common.NewHash(key) + dbKey := string(key) delete(db.data, dbKey) return nil } @@ -84,3 +85,9 @@ func (b *MemoryBatch) ValueSize() int { } var _ database.Batch = &MemoryBatch{} + +func newTestDB(t assert.TestingT) database.Table { + db, err := database.NewPebble("", true) + assert.NoError(t, err) + return database.NewTable(db, "trie") +}