diff --git a/map/map.mbt b/map/map.mbt index a57ac56c3..ffbb73aa6 100644 --- a/map/map.mbt +++ b/map/map.mbt @@ -167,6 +167,42 @@ pub fn iter[K : Compare, V](self : Map[K, V], f : (K, V) -> Unit) -> Unit { } } +/// Iterate over the key-value pairs with index. +pub fn iteri[K, V](self : Map[K, V], f : (Int, K, V) -> Unit) -> Unit { + fn do_iteri(m : Map[K, V], f, i) { + match m { + Empty => () + Tree(k, v, _, l, r) => { + do_iteri(l, f, i) + f(size(l) + i, k, v) + do_iteri(r, f, size(l) + i + 1) + } + } + } + + do_iteri(self, f, 0) +} + +/// Maps over the values in the map. +pub fn map[K , X, Y](self : Map[K, X], f : (X) -> Y) -> Map[K, Y] { + match self { + Empty => Empty + Tree(k, v, s, l, r) => Tree(k, f(v), s, map(l, f), map(r, f)) + } +} + +/// Maps over the key-value pairs in the map. +pub fn map_with_key[K , X, Y]( + self : Map[K, X], + f : (K, X) -> Y +) -> Map[K, Y] { + match self { + Empty => Empty + Tree(k, v, s, l, r) => + Tree(k, f(k, v), s, map_with_key(l, f), map_with_key(r, f)) + } +} + /// The ratio between the sizes of the left and right subtrees. let ratio = 5 @@ -339,6 +375,31 @@ test "iter" { @assertion.assert_eq(s, "(0,zero)(1,one)(2,two)(3,three)(8,eight)")? } +test "iteri" { + let m = Map::[(3, "three"), (8, "eight"), (1, "one"), (2, "two"), (0, "zero")] + let mut s = "" + m.iteri(fn(i, k, v) { s = s + "(\(i),\(k),\(v))" }) + @assertion.assert_eq(s, "(0,0,zero)(1,1,one)(2,2,two)(3,3,three)(4,8,eight)")? +} + +test "map" { + let m = Map::[(3, "three"), (8, "eight"), (1, "one"), (2, "two"), (0, "zero")] + let n = m.map(fn(v) { v + "X" }) + @assertion.assert_eq( + n.debug_tree(), + "(3,threeX,(1,oneX,(0,zeroX,_,_),(2,twoX,_,_)),(8,eightX,_,_))", + )? +} + +test "map_with_key" { + let m = Map::[(3, "three"), (8, "eight"), (1, "one"), (2, "two"), (0, "zero")] + let n = m.map_with_key(fn(k, v) { "\(k)-\(v)" }) + @assertion.assert_eq( + n.debug_tree(), + "(3,3-three,(1,1-one,(0,0-zero,_,_),(2,2-two,_,_)),(8,8-eight,_,_))", + )? +} + test "singleton" { let m = singleton(3, "three") @assertion.assert_eq(m.debug_tree(), "(3,three,_,_)")?