diff --git a/trie.go b/trie.go index 6fd36e5..0374c69 100644 --- a/trie.go +++ b/trie.go @@ -62,14 +62,18 @@ type Node[K cmp.Ordered, V any] interface { // Parent returns the parent of this node Parent() Node[K, V] - // Ancestors returns a sequence of ancestors of this node + // Ancestors returns a sequence of ancestors of this node. + // The first element is the root element, progressing all the way + // up to the parent of this node. Ancestors() iter.Seq[Node[K, V]] } // New creates a new Trie object. func New[L any, K cmp.Ordered, V any](tokenizer Tokenizer[L, K]) *Trie[L, K, V] { + node := newNode[K, V]() + node.isRoot = true return &Trie[L, K, V]{ - root: newNode[K, V](), + root: node, tokenizer: tokenizer, } } @@ -220,7 +224,7 @@ func put[K cmp.Ordered, V any](root Node[K, V], tokens []K, value V) { if cur == nil { newRoot = newNode } else { - cur.children = append(cur.children, newNode) + cur.AddChild(newNode) } cur = newNode } @@ -234,6 +238,7 @@ type node[K cmp.Ordered, V any] struct { mu sync.RWMutex key K value V + isRoot bool children []*node[K, V] parent *node[K, V] } @@ -255,13 +260,24 @@ func (n *node[K, V]) Parent() Node[K, V] { } func (n *node[K, V]) Ancestors() iter.Seq[Node[K, V]] { + var ancestors []*node[K, V] + for { + n = n.parent + if n == nil { + break + } + ancestors = append(ancestors, n) + } + return func(yield func(Node[K, V]) bool) { - cur := n.parent - for cur != nil { - if !yield(cur) { - break + for len(ancestors) > 0 { + cur := ancestors[len(ancestors)-1] + if cur != nil && !cur.isRoot { + if !yield(cur) { + break + } } - cur = cur.parent + ancestors = ancestors[:len(ancestors)-1] } } }