From 93c55755499a855a52e46301361be19e49a99463 Mon Sep 17 00:00:00 2001 From: Aleksandr Logunov Date: Tue, 24 Sep 2024 17:39:08 +0400 Subject: [PATCH] feat: full memtrie logic for range retain (#12130) Next step on #12074. Supporting all cases I came up with where nodes restructuring is required. Surprisingly, `squash_node` is enough to call. I only needed to implement trivial cases which weren't possible for single key deletion. I implemented the tests first, and majority of them failed before changing the logic. Each test is comparing naive approach with `retain_multi_range`. ### Notes * I'm a bit scared that I didn't realise the need to squash Extension before. Fortunately, `squash_node` handles that, but if you feel some cases are not covered here, feel free to post suggestions! * Reused + copypasted some Robin' tooling to generate interesting nodes conversions. * Note that test for reading "extra" child node is not required because we always read all children. ### Next steps * A bit more testing * Similar logic for partial trie * Generating intervals needed for resharding * Use that to implement shard switch on chain --- core/store/src/trie/mem/loading.rs | 64 ++--- core/store/src/trie/mem/mod.rs | 2 + core/store/src/trie/mem/nibbles_utils.rs | 46 ++++ core/store/src/trie/mem/resharding.rs | 334 ++++++++++++++++++++--- core/store/src/trie/mem/updating.rs | 142 +++++----- 5 files changed, 435 insertions(+), 153 deletions(-) create mode 100644 core/store/src/trie/mem/nibbles_utils.rs diff --git a/core/store/src/trie/mem/loading.rs b/core/store/src/trie/mem/loading.rs index 882392b1552..b98e3b15d5c 100644 --- a/core/store/src/trie/mem/loading.rs +++ b/core/store/src/trie/mem/loading.rs @@ -195,6 +195,7 @@ mod tests { }; use crate::trie::mem::loading::load_trie_from_flat_state; use crate::trie::mem::lookup::memtrie_lookup; + use crate::trie::mem::nibbles_utils::{all_two_nibble_nibbles, multi_hex_to_nibbles}; use crate::{DBCol, KeyLookupMode, NibbleSlice, ShardTries, Store, Trie, TrieUpdate}; use near_primitives::congestion_info::CongestionInfo; use near_primitives::hash::CryptoHash; @@ -300,18 +301,6 @@ mod tests { check_maybe_parallelize(keys, true); } - fn nibbles(hex: &str) -> Vec { - if hex == "_" { - return vec![]; - } - assert!(hex.len() % 2 == 0); - hex::decode(hex).unwrap() - } - - fn all_nibbles(hexes: &str) -> Vec> { - hexes.split_whitespace().map(|x| nibbles(x)).collect() - } - #[test] fn test_memtrie_empty() { check(vec![]); @@ -319,61 +308,42 @@ mod tests { #[test] fn test_memtrie_root_is_leaf() { - check(all_nibbles("_")); - check(all_nibbles("00")); - check(all_nibbles("01")); - check(all_nibbles("ff")); - check(all_nibbles("0123456789abcdef")); + check(multi_hex_to_nibbles("_")); + check(multi_hex_to_nibbles("00")); + check(multi_hex_to_nibbles("01")); + check(multi_hex_to_nibbles("ff")); + check(multi_hex_to_nibbles("0123456789abcdef")); } #[test] fn test_memtrie_root_is_extension() { - check(all_nibbles("1234 13 14")); - check(all_nibbles("12345678 1234abcd")); + check(multi_hex_to_nibbles("1234 13 14")); + check(multi_hex_to_nibbles("12345678 1234abcd")); } #[test] fn test_memtrie_root_is_branch() { - check(all_nibbles("11 22")); - check(all_nibbles("12345678 22345678 32345678")); - check(all_nibbles("11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff")); + check(multi_hex_to_nibbles("11 22")); + check(multi_hex_to_nibbles("12345678 22345678 32345678")); + check(multi_hex_to_nibbles("11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff")); } #[test] fn test_memtrie_root_is_branch_with_value() { - check(all_nibbles("_ 11")); + check(multi_hex_to_nibbles("_ 11")); } #[test] fn test_memtrie_prefix_patterns() { - check(all_nibbles("10 21 2210 2221 222210 222221 22222210 22222221")); - check(all_nibbles("11111112 11111120 111112 111120 1112 1120 12 20")); - check(all_nibbles("11 1111 111111 11111111 1111111111 111111111111")); - check(all_nibbles("_ 11 1111 111111 11111111 1111111111 111111111111")); + check(multi_hex_to_nibbles("10 21 2210 2221 222210 222221 22222210 22222221")); + check(multi_hex_to_nibbles("11111112 11111120 111112 111120 1112 1120 12 20")); + check(multi_hex_to_nibbles("11 1111 111111 11111111 1111111111 111111111111")); + check(multi_hex_to_nibbles("_ 11 1111 111111 11111111 1111111111 111111111111")); } #[test] fn test_full_16ary_trees() { - check(all_nibbles( - " - 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f - 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f - 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f - 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f - 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f - 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f - 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f - 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f - 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f - 90 91 92 93 94 95 96 97 98 99 9a 9b 9c 9d 9e 9f - a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 aa ab ac ad ae af - b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ba bb bc bd be bf - c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 ca cb cc cd ce cf - d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 da db dc dd de df - e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 ea eb ec ed ee ef - f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 fa fb fc fd fe ff - ", - )) + check(all_two_nibble_nibbles()) } #[test] diff --git a/core/store/src/trie/mem/mod.rs b/core/store/src/trie/mem/mod.rs index 03fa125e495..f04381f3004 100644 --- a/core/store/src/trie/mem/mod.rs +++ b/core/store/src/trie/mem/mod.rs @@ -7,6 +7,8 @@ pub mod loading; mod lookup; pub mod mem_tries; pub mod metrics; +#[cfg(test)] +pub(crate) mod nibbles_utils; pub mod node; mod parallel_loader; pub mod resharding; diff --git a/core/store/src/trie/mem/nibbles_utils.rs b/core/store/src/trie/mem/nibbles_utils.rs new file mode 100644 index 00000000000..c0f5d168dac --- /dev/null +++ b/core/store/src/trie/mem/nibbles_utils.rs @@ -0,0 +1,46 @@ +/// Utilties for generating vectors of nibbles from human-readable strings. +/// +/// Input for a single vector is a hex string, e.g. 5da3593f. +/// It has even length, as tries support only keys in bytes, thus keys of +/// odd nibble length do not occur. +/// Each symbol is interpreted as a nibble (half-byte). +/// Result is a vector of decoded hexes as nibbles, e.g. +/// [5, 13, 10, 3, 5, 9, 3, 15]. + +pub(crate) fn hex_to_nibbles(hex: &str) -> Vec { + if hex == "_" { + return vec![]; + } + assert!(hex.len() % 2 == 0); + hex::decode(hex).unwrap() +} + +/// Converts a string of hex strings separated by whitespaces into a vector of +/// vectors of nibbles. For example, "01 02 10" is converted to +/// [[0, 1], [0, 2], [1, 0]]. +pub(crate) fn multi_hex_to_nibbles(hexes: &str) -> Vec> { + hexes.split_whitespace().map(|x| hex_to_nibbles(x)).collect() +} + +pub(crate) fn all_two_nibble_nibbles() -> Vec> { + multi_hex_to_nibbles( + " + 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f + 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f + 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f + 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f + 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f + 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f + 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f + 90 91 92 93 94 95 96 97 98 99 9a 9b 9c 9d 9e 9f + a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 aa ab ac ad ae af + b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ba bb bc bd be bf + c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 ca cb cc cd ce cf + d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 da db dc dd de df + e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 ea eb ec ed ee ef + f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 fa fb fc fd fe ff + ", + ) +} diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs index 281e828abba..6bc1b98ff80 100644 --- a/core/store/src/trie/mem/resharding.rs +++ b/core/store/src/trie/mem/resharding.rs @@ -103,6 +103,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } else { self.place_node(node_id, UpdatedMemTrieNode::Leaf { extension, value }); } + return; } UpdatedMemTrieNode::Branch { mut children, mut value } => { if !intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) { @@ -128,9 +129,6 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } } - // TODO(#12074): squash the branch if needed. Consider reusing - // `squash_nodes`. - self.place_node(node_id, UpdatedMemTrieNode::Branch { children, value }); } UpdatedMemTrieNode::Extension { extension, child } => { @@ -140,19 +138,16 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { let child_key = [key_nibbles, extension_nibbles].concat(); self.retain_multi_range_recursive(new_child_id, child_key, intervals_nibbles); - if self.updated_nodes[new_child_id] == Some(UpdatedMemTrieNode::Empty) { - self.place_node(node_id, UpdatedMemTrieNode::Empty); - } else { - self.place_node( - node_id, - UpdatedMemTrieNode::Extension { - extension, - child: OldOrUpdatedNodeId::Updated(new_child_id), - }, - ); - } + let node = UpdatedMemTrieNode::Extension { + extension, + child: OldOrUpdatedNodeId::Updated(new_child_id), + }; + self.place_node(node_id, node); } } + + // We may need to change node type to keep the trie structure unique. + self.squash_node(node_id); } } @@ -190,28 +185,86 @@ fn retain_decision(key: &[u8], intervals: &[Range>]) -> RetainDecision { } // TODO(#12074): tests for -// - multiple retain ranges -// - result is empty, or no changes are made -// - removing keys one-by-one gives the same result as corresponding range retain // - `retain_split_shard` API -// - all results of squashing branch -// - checking not accessing not-inlined nodes +// - checking not accessing not-inlined values // - proof correctness #[cfg(test)] mod tests { + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + use std::ops::Range; use std::sync::Arc; use itertools::Itertools; use near_primitives::{shard_layout::ShardUId, types::StateRoot}; use crate::{ + test_utils::TestTriesBuilder, trie::{ - mem::{iter::MemTrieIterator, mem_tries::MemTries}, + mem::{ + iter::MemTrieIterator, + mem_tries::MemTries, + nibbles_utils::{all_two_nibble_nibbles, hex_to_nibbles, multi_hex_to_nibbles}, + }, trie_storage::TrieMemoryPartialStorage, }, Trie, }; + // Logic for a single test. + // Creates trie from initial entries, applies retain multi range to it and + // compares the result with naive approach. + fn run(initial_entries: Vec<(Vec, Vec)>, retain_multi_ranges: Vec>>) { + // Generate naive result and state root. + let mut retain_result_naive = initial_entries + .iter() + .filter(|&(key, _)| retain_multi_ranges.iter().any(|range| range.contains(key))) + .cloned() + .collect_vec(); + retain_result_naive.sort(); + + let shard_tries = TestTriesBuilder::new().build(); + let changes = retain_result_naive + .iter() + .map(|(key, value)| (key.clone(), Some(value.clone()))) + .collect_vec(); + let expected_state_root = crate::test_utils::test_populate_trie( + &shard_tries, + &Trie::EMPTY_ROOT, + ShardUId::single_shard(), + changes, + ); + + let mut memtries = MemTries::new(ShardUId::single_shard()); + let mut update = memtries.update(Trie::EMPTY_ROOT, false).unwrap(); + for (key, value) in initial_entries { + update.insert(&key, value); + } + let memtrie_changes = update.to_mem_trie_changes_only(); + let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes); + + let update = memtries.update(state_root, true).unwrap(); + let (mut trie_changes, _) = update.retain_multi_range(&retain_multi_ranges); + let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); + let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes); + + let entries = if new_state_root != StateRoot::default() { + let state_root_ptr = memtries.get_root(&new_state_root).unwrap(); + let trie = + Trie::new(Arc::new(TrieMemoryPartialStorage::default()), new_state_root, None); + MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec() + } else { + vec![] + }; + + // Check entries first to provide more context in case of failure. + assert_eq!(entries, retain_result_naive); + + // Check state root, because it must be unique. + assert_eq!(new_state_root, expected_state_root); + } + #[test] /// Applies single range retain to the trie and checks the result. fn test_retain_single_range() { @@ -222,27 +275,234 @@ mod tests { (b"david".to_vec(), vec![4]), ]; let retain_range = b"amy".to_vec()..b"david".to_vec(); - let retain_result = vec![(b"bob".to_vec(), vec![2]), (b"charlie".to_vec(), vec![3])]; + run(initial_entries, vec![retain_range]); + } - let mut memtries = MemTries::new(ShardUId::single_shard()); - let empty_state_root = StateRoot::default(); - let mut update = memtries.update(empty_state_root, false).unwrap(); - for (key, value) in initial_entries { - update.insert(&key, value); + #[test] + /// Applies two ranges retain to the trie and checks the result. + fn test_retain_two_ranges() { + let initial_entries = vec![ + (b"alice".to_vec(), vec![1]), + (b"bob".to_vec(), vec![2]), + (b"charlie".to_vec(), vec![3]), + (b"david".to_vec(), vec![4]), + (b"edward".to_vec(), vec![5]), + (b"frank".to_vec(), vec![6]), + ]; + let retain_ranges = + vec![b"bill".to_vec()..b"bowl".to_vec(), b"daaa".to_vec()..b"france".to_vec()]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when no keys are retained. + fn test_empty_result() { + let initial_entries = vec![ + (b"alice".to_vec(), vec![1]), + (b"miles".to_vec(), vec![2]), + (b"willy".to_vec(), vec![3]), + ]; + let retain_ranges = vec![b"ellie".to_vec()..b"key".to_vec()]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when all keys are retained. + fn test_full_result() { + let initial_entries = vec![ + (b"f23".to_vec(), vec![1]), + (b"f32".to_vec(), vec![2]), + (b"f44".to_vec(), vec![3]), + ]; + let retain_ranges = vec![b"f11".to_vec()..b"f45".to_vec()]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks empty trie. + fn test_empty_trie() { + let initial_entries = vec![]; + let retain_ranges = vec![b"bar".to_vec()..b"foo".to_vec()]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when all keys are prefixes of some string. + fn test_prefixes() { + let initial_entries = vec![ + (b"a".to_vec(), vec![1]), + (b"aa".to_vec(), vec![2]), + (b"aaa".to_vec(), vec![3]), + (b"aaaa".to_vec(), vec![1]), + (b"aaaaa".to_vec(), vec![2]), + (b"aaaaaa".to_vec(), vec![3]), + ]; + let retain_ranges = vec![b"aa".to_vec()..b"aaaaa".to_vec()]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when branch and extension nodes are explored but completely + /// removed. + fn test_descend_and_remove() { + let keys = multi_hex_to_nibbles("00 0000 0011"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("0001")..hex_to_nibbles("0010")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when branch is converted to leaf. + fn test_branch_to_leaf() { + let keys = multi_hex_to_nibbles("ba bc ca"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("bc")..hex_to_nibbles("be")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when branch with value is converted to leaf. + fn test_branch_with_value_to_leaf() { + let keys = multi_hex_to_nibbles("d4 d4a3 d4b9 d5 e6"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("d4")..hex_to_nibbles("d4a0")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when branch without value is converted to extension. + fn test_branch_to_extension() { + let keys = multi_hex_to_nibbles("21 2200 2201"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("2200")..hex_to_nibbles("2202")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when result is a single key, and all nodes on the way are + /// squashed, in particular, extension nodes are joined into one. + fn test_extend_extensions() { + let keys = multi_hex_to_nibbles("dd d0 d1 dddd00 dddd01 dddddd"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("dddddd")..hex_to_nibbles("ddddde")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case when branch is visited but not restructured. + fn test_branch_not_restructured() { + let keys = multi_hex_to_nibbles("60 61 62 70"); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("61")..hex_to_nibbles("71")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks case with branching on every step but when only prefixes of some + /// key are retained. + fn test_branch_prefixes() { + let keys = multi_hex_to_nibbles( + " + 00 + 10 + 01 + 0000 + 0010 + 0001 + 000000 + 000010 + 000001 + 00000000 + 00000010 + 00000001 + 0000000000 + 0000000010 + 0000000001 + 000000000000 + 000000000010 + 000000000011 + ", + ); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![hex_to_nibbles("0000")..hex_to_nibbles("00000000")]; + run(initial_entries, retain_ranges); + } + + #[test] + /// Checks multiple ranges retain on full 16-ary tree. + fn test_full_16ary() { + let keys = all_two_nibble_nibbles(); + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![ + hex_to_nibbles("0f")..hex_to_nibbles("10"), + hex_to_nibbles("20")..hex_to_nibbles("2fff"), + hex_to_nibbles("55")..hex_to_nibbles("56"), + hex_to_nibbles("a5aa")..hex_to_nibbles("c3"), + hex_to_nibbles("c3")..hex_to_nibbles("c5"), + hex_to_nibbles("c8")..hex_to_nibbles("ca"), + hex_to_nibbles("cb")..hex_to_nibbles("cc"), + ]; + run(initial_entries, retain_ranges); + } + + fn random_key(max_key_len: usize, rng: &mut StdRng) -> Vec { + let key_len = rng.gen_range(0..=max_key_len); + let mut key = Vec::new(); + for _ in 0..key_len { + let byte: u8 = rng.gen(); + key.push(byte); } - let memtrie_changes = update.to_mem_trie_changes_only(); - let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes); + key + } - let update = memtries.update(state_root, true).unwrap(); - let (mut trie_changes, _) = update.retain_multi_range(&[retain_range]); - let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap(); - let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes); + fn check_random(max_key_len: usize, max_keys_count: usize, test_count: usize) { + let mut rng = StdRng::seed_from_u64(442); + for _ in 0..test_count { + let key_cnt = rng.gen_range(1..=max_keys_count); + let mut keys = Vec::new(); + for _ in 0..key_cnt { + keys.push(random_key(max_key_len, &mut rng)); + } + keys.sort(); + keys.dedup(); + keys.shuffle(&mut rng); - let state_root_ptr = memtries.get_root(&new_state_root).unwrap(); - let trie = Trie::new(Arc::new(TrieMemoryPartialStorage::default()), new_state_root, None); - let entries = - MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec(); + let mut boundary_left = random_key(max_key_len, &mut rng); + let mut boundary_right = random_key(max_key_len, &mut rng); + if boundary_left == boundary_right { + continue; + } + if boundary_left > boundary_right { + std::mem::swap(&mut boundary_left, &mut boundary_right); + } + let initial_entries = keys.into_iter().map(|key| (key, vec![1])).collect_vec(); + let retain_ranges = vec![boundary_left..boundary_right]; + run(initial_entries, retain_ranges); + } + } - assert_eq!(entries, retain_result); + #[test] + fn test_rand_small() { + check_random(3, 20, 10); + } + + #[test] + fn test_rand_many_keys() { + check_random(5, 1000, 10); + } + + #[test] + fn test_rand_long_keys() { + check_random(20, 100, 10); + } + + #[test] + fn test_rand_long_long_keys() { + check_random(1000, 1000, 1); + } + + #[test] + fn test_rand_large_data() { + check_random(32, 100000, 1); } } diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index a833ddca487..f6718da6020 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -508,76 +508,86 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { } } - self.squash_nodes(path); + // We may need to change node type to keep the trie structure unique. + for node_id in path.into_iter().rev() { + self.squash_node(node_id); + } } - /// As we delete a key, it may be necessary to change the types of the nodes - /// along the path from the root to the key, in order to keep the trie - /// structure unique. For example, if a branch node has only one child and - /// no value, it must be converted to an extension node. If that extension - /// node also has a parent that is an extension node, they must be combined - /// into a single extension node. This function takes care of all these - /// cases. - fn squash_nodes(&mut self, path: Vec) { - // Correctness can be shown by induction on path prefix. - for node_id in path.into_iter().rev() { - let node = self.take_node(node_id); - match node { - UpdatedMemTrieNode::Empty => { - // Empty node will be absorbed by its parent node, so defer that. - self.place_node(node_id, UpdatedMemTrieNode::Empty); - } - UpdatedMemTrieNode::Leaf { .. } => { - // It's impossible that we would squash a leaf node, because if we - // had deleted a leaf it would become Empty instead. - unreachable!(); - } - UpdatedMemTrieNode::Branch { mut children, value } => { - // Remove any children that are now empty (removed). - for child in children.iter_mut() { - if let Some(OldOrUpdatedNodeId::Updated(child_node_id)) = child { - if let UpdatedMemTrieNode::Empty = - self.updated_nodes[*child_node_id as usize].as_ref().unwrap() - { - *child = None; - } + /// When we delete keys, it may be necessary to change types of some nodes, + /// in order to keep the trie structure unique. For example, if a branch + /// had two children, but after deletion ended up with one child and no + /// value, it must be converted to an extension node. Or, if an extension + /// node ended up having a child which is also an extension node, they must + /// be combined into a single extension node. This function takes care of + /// all these cases for a single node. + /// + /// To restructure trie correctly, this function must be called in + /// post-order traversal for every modified node. It may be proven by + /// induction on subtrees. + /// For single key removal, it is called for every node on the path from + /// the leaf to the root. + /// For range removal, it is called in the end of recursive range removal + /// function, which is the definition of post-order traversal. + pub(crate) fn squash_node(&mut self, node_id: UpdatedMemTrieNodeId) { + let node = self.take_node(node_id); + match node { + UpdatedMemTrieNode::Empty => { + // Empty node will be absorbed by its parent node, so defer that. + self.place_node(node_id, UpdatedMemTrieNode::Empty); + } + UpdatedMemTrieNode::Leaf { .. } => { + // It's impossible that we would squash a leaf node, because if we + // had deleted a leaf it would become Empty instead. + unreachable!(); + } + UpdatedMemTrieNode::Branch { mut children, value } => { + // Remove any children that are now empty (removed). + for child in children.iter_mut() { + if let Some(OldOrUpdatedNodeId::Updated(child_node_id)) = child { + if let UpdatedMemTrieNode::Empty = + self.updated_nodes[*child_node_id as usize].as_ref().unwrap() + { + *child = None; } } - let num_children = children.iter().filter(|node| node.is_some()).count(); - if num_children == 0 { - // Branch with zero children becomes leaf. It's not possible for it to - // become empty, because a branch had at least two children or a value - // and at least one child, so deleting a single value could not - // eliminate both of them. - let leaf_node = UpdatedMemTrieNode::Leaf { - extension: NibbleSlice::new(&[]) - .encoded(true) - .into_vec() - .into_boxed_slice(), - value: value.unwrap(), - }; - self.place_node(node_id, leaf_node); - } else if num_children == 1 && value.is_none() { - // Branch with 1 child but no value becomes extension. - let (idx, child) = children - .into_iter() - .enumerate() - .find_map(|(idx, node)| node.map(|node| (idx, node))) - .unwrap(); - let extension = NibbleSlice::new(&[(idx << 4) as u8]) - .encoded_leftmost(1, false) - .into_vec() - .into_boxed_slice(); - self.extend_child(node_id, extension, child); - } else { - // Branch with more than 1 children stays branch. - self.place_node(node_id, UpdatedMemTrieNode::Branch { children, value }); - } } - UpdatedMemTrieNode::Extension { extension, child } => { + let num_children = children.iter().filter(|node| node.is_some()).count(); + if num_children == 0 { + match value { + None => self.place_node(node_id, UpdatedMemTrieNode::Empty), + Some(value) => { + // Branch with zero children and a value becomes leaf. + let leaf_node = UpdatedMemTrieNode::Leaf { + extension: NibbleSlice::new(&[]) + .encoded(true) + .into_vec() + .into_boxed_slice(), + value, + }; + self.place_node(node_id, leaf_node); + } + } + } else if num_children == 1 && value.is_none() { + // Branch with 1 child but no value becomes extension. + let (idx, child) = children + .into_iter() + .enumerate() + .find_map(|(idx, node)| node.map(|node| (idx, node))) + .unwrap(); + let extension = NibbleSlice::new(&[(idx << 4) as u8]) + .encoded_leftmost(1, false) + .into_vec() + .into_boxed_slice(); self.extend_child(node_id, extension, child); + } else { + // Branch with more than 1 children stays branch. + self.place_node(node_id, UpdatedMemTrieNode::Branch { children, value }); } } + UpdatedMemTrieNode::Extension { extension, child } => { + self.extend_child(node_id, extension, child); + } } } @@ -596,13 +606,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> { let child_node = self.take_node(child_id); match child_node { UpdatedMemTrieNode::Empty => { - // This case is not possible. In a trie in general, an extension - // node could only have a child that is a branch (possibly with - // value) node. But a branch node either has a value and at least - // one child, or has at least two children. In either case, it's - // impossible for a single deletion to cause the child to become - // empty. - unreachable!(); + self.place_node(node_id, UpdatedMemTrieNode::Empty); } // If the child is a leaf (which could happen if a branch node lost // all its branches and only had a value left, or is left with only