Skip to content

Commit

Permalink
feat: RBNode.reverse (#737)
Browse files Browse the repository at this point in the history
* chore: RBMap.min -> min?

* feat: RBNode.reverse
  • Loading branch information
digama0 authored Apr 17, 2024
1 parent b6bc371 commit 6361c24
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 49 deletions.
6 changes: 6 additions & 0 deletions Std/Classes/Order.lean
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ theorem cmp_congr_right [TransCmp cmp] (yz : cmp y z = .eq) : cmp x y = cmp x z

end TransCmp

instance [inst : OrientedCmp cmp] : OrientedCmp (flip cmp) where
symm _ _ := inst.symm ..

instance [inst : TransCmp cmp] : TransCmp (flip cmp) where
le_trans h1 h2 := inst.le_trans h2 h1

end Std

namespace Ordering
Expand Down
5 changes: 5 additions & 0 deletions Std/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,11 @@ theorem getLastD_mem_cons : ∀ (l : List α) (a : α), getLastD l a ∈ a::l
| [], _ => .head ..
| _::_, _ => .tail _ <| getLast_mem _

@[simp] theorem getLast?_reverse (l : List α) : l.reverse.getLast? = l.head? := by cases l <;> simp

@[simp] theorem head?_reverse (l : List α) : l.reverse.head? = l.getLast? := by
rw [← getLast?_reverse, reverse_reverse]

/-! ### dropLast -/

/-! NB: `dropLast` is the specification for `Array.pop`, so theorems about `List.dropLast`
Expand Down
13 changes: 9 additions & 4 deletions Std/Data/RBMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ def isOrdered (cmp : α → α → Ordering)

/-- The second half of Okasaki's `balance`, concerning red-red sequences in the right child. -/
@[inline] def balance2 : RBNode α → α → RBNode α → RBNode α
| a, x, node red (node red b y c) z d
| a, x, node red b y (node red c z d) => node red (node black a x b) y (node black c z d)
| a, x, node red b y (node red c z d)
| a, x, node red (node red b y c) z d => node red (node black a x b) y (node black c z d)
| a, x, b => node black a x b

/-- Returns `red` if the node is red, otherwise `black`. (Nil nodes are treated as `black`.) -/
Expand All @@ -284,11 +284,16 @@ Returns `black` if the node is black, otherwise `red`.
| node c .. => c
| _ => red

/-- Change the color of the root to `black`. -/
/-- Changes the color of the root to `black`. -/
def setBlack : RBNode α → RBNode α
| nil => nil
| node _ l v r => node black l v r

/-- `O(n)`. Reverses the ordering of the tree without any rebalancing. -/
@[simp] def reverse : RBNode α → RBNode α
| nil => nil
| node c l v r => node c r.reverse v l.reverse

section Insert

/--
Expand Down Expand Up @@ -897,7 +902,7 @@ variable {α : Type u} {β : Type v} {σ : Type w} {cmp : α → α → Ordering

/-- `O(n)`. Run monadic function `f` on each element of the tree (in increasing order). -/
@[inline] def forM [Monad m] (f : α → β → m PUnit) (t : RBMap α β cmp) : m PUnit :=
t.foldlM (fun _ k v => f k v) ⟨⟩
t.1.forM (fun (a, b) => f a b)

instance : ForIn m (RBMap α β cmp) (α × β) := inferInstanceAs (ForIn _ (RBSet ..) _)

Expand Down
40 changes: 40 additions & 0 deletions Std/Data/RBMap/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ theorem WF.depth_bound {t : RBNode α} (h : t.WF cmp) : t.depth ≤ 2 * (t.size

end depth

@[simp] theorem min?_reverse (t : RBNode α) : t.reverse.min? = t.max? := by
unfold RBNode.max?; split <;> simp [RBNode.min?]
unfold RBNode.min?; rw [min?.match_1.eq_3]
· apply min?_reverse
· simpa [reverse_eq_iff]

@[simp] theorem max?_reverse (t : RBNode α) : t.reverse.max? = t.min? := by
rw [← min?_reverse, reverse_reverse]

@[simp] theorem mem_nil {x} : ¬x ∈ (.nil : RBNode α) := by simp [(·∈·), EMem]
@[simp] theorem mem_node {y c a x b} :
y ∈ (.node c a x b : RBNode α) ↔ y = x ∨ y ∈ a ∨ y ∈ b := by simp [(·∈·), EMem]
Expand Down Expand Up @@ -367,15 +376,39 @@ theorem foldr_cons (t : RBNode α) (l) : t.foldr (·::·) l = t.toList ++ l := b
@[simp] theorem toList_node : (.node c a x b : RBNode α).toList = a.toList ++ x :: b.toList := by
rw [toList, foldr, foldr_cons]; rfl

@[simp] theorem toList_reverse (t : RBNode α) : t.reverse.toList = t.toList.reverse := by
induction t <;> simp [*]

@[simp] theorem mem_toList {t : RBNode α} : x ∈ t.toList ↔ x ∈ t := by
induction t <;> simp [*, or_left_comm]

@[simp] theorem mem_reverse {t : RBNode α} : a ∈ t.reverse ↔ a ∈ t := by rw [← mem_toList]; simp

theorem min?_eq_toList_head? {t : RBNode α} : t.min? = t.toList.head? := by
induction t with
| nil => rfl
| node _ l _ _ ih =>
cases l <;> simp [RBNode.min?, ih]
next ll _ _ => cases toList ll <;> rfl

theorem max?_eq_toList_getLast? {t : RBNode α} : t.max? = t.toList.getLast? := by
rw [← min?_reverse, min?_eq_toList_head?]; simp

theorem foldr_eq_foldr_toList {t : RBNode α} : t.foldr f init = t.toList.foldr f init := by
induction t generalizing init <;> simp [*]

theorem foldl_eq_foldl_toList {t : RBNode α} : t.foldl f init = t.toList.foldl f init := by
induction t generalizing init <;> simp [*]

theorem foldl_reverse {α β : Type _} {t : RBNode α} {f : β → α → β} {init : β} :
t.reverse.foldl f init = t.foldr (flip f) init := by
simp (config := {unfoldPartialApp := true})
[foldr_eq_foldr_toList, foldl_eq_foldl_toList, flip]

theorem foldr_reverse {α β : Type _} {t : RBNode α} {f : α → β → β} {init : β} :
t.reverse.foldr f init = t.foldl (flip f) init :=
foldl_reverse.symm.trans (by simp; rfl)

theorem forM_eq_forM_toList [Monad m] [LawfulMonad m] {t : RBNode α} :
t.forM (m := m) f = t.toList.forM f := by induction t <;> simp [*]

Expand Down Expand Up @@ -467,6 +500,13 @@ theorem Ordered.toList_sorted {t : RBNode α} : t.Ordered cmp → t.toList.Pairw
theorem size_eq {t : RBNode α} : t.size = t.toList.length := by
induction t <;> simp [*, size]; rfl

@[simp] theorem reverse_size (t : RBNode α) : t.reverse.size = t.size := by simp [size_eq]

@[simp] theorem find?_reverse (t : RBNode α) (cut : α → Ordering) :
t.reverse.find? cut = t.find? (cut · |>.swap) := by
induction t <;> simp [*, find?]
cases cut _ <;> simp [Ordering.swap]

namespace Path

attribute [simp] RootOrdered Ordered
Expand Down
131 changes: 86 additions & 45 deletions Std/Data/RBMap/WF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ theorem All.trivial (H : ∀ {x : α}, p x) : ∀ {t : RBNode α}, t.All p
theorem All_and {t : RBNode α} : t.All (fun a => p a ∧ q a) ↔ t.All p ∧ t.All q := by
induction t <;> simp [*, and_assoc, and_left_comm]

protected theorem cmpLT.flip (h₁ : cmpLT cmp x y) : cmpLT (flip cmp) y x :=
have : TransCmp cmp := inferInstanceAs (TransCmp (flip (flip cmp))); h₁.1

theorem cmpLT.trans (h₁ : cmpLT cmp x y) (h₂ : cmpLT cmp y z) : cmpLT cmp x z :=
⟨TransCmp.lt_trans h₁.1 h₂.1

Expand All @@ -42,6 +45,36 @@ theorem cmpEq.lt_congr_left (H : cmpEq cmp x y) : cmpLT cmp x z ↔ cmpLT cmp y
theorem cmpEq.lt_congr_right (H : cmpEq cmp y z) : cmpLT cmp x y ↔ cmpLT cmp x z :=
fun ⟨h⟩ => ⟨TransCmp.cmp_congr_right H.1 ▸ h⟩, fun ⟨h⟩ => ⟨TransCmp.cmp_congr_right H.1 ▸ h⟩⟩

@[simp] theorem reverse_reverse (t : RBNode α) : t.reverse.reverse = t := by
induction t <;> simp [*]

theorem reverse_eq_iff {t t' : RBNode α} : t.reverse = t' ↔ t = t'.reverse := by
constructor <;> rintro rfl <;> simp

@[simp] theorem reverse_balance1 (l : RBNode α) (v : α) (r : RBNode α) :
(balance1 l v r).reverse = balance2 r.reverse v l.reverse := by
unfold balance1 balance2; split <;> simp
· rw [balance2.match_1.eq_2]; simp [reverse_eq_iff]; intros; solve_by_elim
· rw [balance2.match_1.eq_3] <;> (simp [reverse_eq_iff]; intros; solve_by_elim)

@[simp] theorem reverse_balance2 (l : RBNode α) (v : α) (r : RBNode α) :
(balance2 l v r).reverse = balance1 r.reverse v l.reverse := by
refine Eq.trans ?_ (reverse_reverse _); rw [reverse_balance1]; simp

@[simp] theorem All.reverse {t : RBNode α} : t.reverse.All p ↔ t.All p := by
induction t <;> simp [*, and_comm]

/-- The `reverse` function reverses the ordering invariants. -/
protected theorem Ordered.reverse : ∀ {t : RBNode α}, t.Ordered cmp → t.reverse.Ordered (flip cmp)
| .nil, _ => ⟨⟩
| .node .., ⟨lv, vr, hl, hr⟩ =>
⟨(All.reverse.2 vr).imp cmpLT.flip, (All.reverse.2 lv).imp cmpLT.flip, hr.reverse, hl.reverse⟩

protected theorem Balanced.reverse {t : RBNode α} : t.Balanced c n → t.reverse.Balanced c n
| .nil => .nil
| .black hl hr => .black hr.reverse hl.reverse
| .red hl hr => .red hr.reverse hl.reverse

/-- The `balance1` function preserves the ordering invariants. -/
protected theorem Ordered.balance1 {l : RBNode α} {v : α} {r : RBNode α}
(lv : l.All (cmpLT cmp · v)) (vr : r.All (cmpLT cmp v ·))
Expand All @@ -63,19 +96,17 @@ protected theorem Ordered.balance1 {l : RBNode α} {v : α} {r : RBNode α}
protected theorem Ordered.balance2 {l : RBNode α} {v : α} {r : RBNode α}
(lv : l.All (cmpLT cmp · v)) (vr : r.All (cmpLT cmp v ·))
(hl : l.Ordered cmp) (hr : r.Ordered cmp) : (balance2 l v r).Ordered cmp := by
unfold balance2; split
· next b y c z d =>
have ⟨_, ⟨vy, vb, _⟩, _⟩ := vr; have ⟨⟨yz, _, cz⟩, zd, ⟨by_, yc, hy, hz⟩, hd⟩ := hr
exact ⟨⟨vy, vy.trans_r lv, by_⟩, ⟨yz, yc, yz.trans_l zd⟩, ⟨lv, vb, hl, hy⟩, cz, zd, hz, hd⟩
· next a x b y c _ =>
have ⟨vx, va, _⟩ := vr; have ⟨ax, xy, ha, hy⟩ := hr
exact ⟨⟨vx, vx.trans_r lv, ax⟩, xy, ⟨lv, va, hl, ha⟩, hy⟩
· exact ⟨lv, vr, hl, hr⟩
rw [← reverse_reverse (balance2 ..), reverse_balance2]
exact .reverse <| hr.reverse.balance1
((All.reverse.2 vr).imp cmpLT.flip) ((All.reverse.2 lv).imp cmpLT.flip) hl.reverse

@[simp] theorem balance2_All {l : RBNode α} {v : α} {r : RBNode α} :
(balance2 l v r).All p ↔ p v ∧ l.All p ∧ r.All p := by
unfold balance2; split <;> simp [and_assoc, and_left_comm]

@[simp] theorem reverse_setBlack {t : RBNode α} : (setBlack t).reverse = setBlack t.reverse := by
unfold setBlack; split <;> simp

protected theorem Ordered.setBlack {t : RBNode α} : (setBlack t).Ordered cmp ↔ t.Ordered cmp := by
unfold setBlack; split <;> simp [Ordered]

Expand All @@ -85,9 +116,10 @@ protected theorem Balanced.setBlack : t.Balanced c n → ∃ n', (setBlack t).Ba

theorem setBlack_idem {t : RBNode α} : t.setBlack.setBlack = t.setBlack := by cases t <;> rfl

theorem insert_setBlack {t : RBNode α} :
(t.insert cmp v).setBlack = (t.ins cmp v).setBlack := by
unfold insert; split <;> simp [setBlack_idem]
@[simp] theorem reverse_ins [inst : @OrientedCmp α cmp] {t : RBNode α} :
(ins cmp x t).reverse = ins (flip cmp) x t.reverse := by
induction t <;> [skip; (rename_i c a y b iha ihb; cases c)] <;> simp [ins, flip]
<;> rw [← inst.symm x y] <;> split <;> simp [*, Ordering.swap, iha, ihb]

protected theorem All.ins {x : α} {t : RBNode α}
(h₁ : p x) (h₂ : t.All p) : (ins cmp x t).All p := by
Expand All @@ -112,6 +144,17 @@ protected theorem Ordered.ins : ∀ {t : RBNode α}, t.Ordered cmp → (ins cmp
ay.imp fun ⟨h'⟩ => ⟨(TransCmp.cmp_congr_right h).trans h'⟩,
yb.imp fun ⟨h'⟩ => ⟨(TransCmp.cmp_congr_left h).trans h'⟩, ha, hb⟩)

@[simp] theorem isRed_reverse {t : RBNode α} : t.reverse.isRed = t.isRed := by
cases t <;> simp [isRed]

@[simp] theorem reverse_insert [inst : @OrientedCmp α cmp] {t : RBNode α} :
(insert cmp t x).reverse = insert (flip cmp) t.reverse x := by
simp [insert] <;> split <;> simp

theorem insert_setBlack {t : RBNode α} :
(t.insert cmp v).setBlack = (t.ins cmp v).setBlack := by
unfold insert; split <;> simp [setBlack_idem]

/-- The `insert` function preserves the ordering invariants. -/
protected theorem Ordered.insert (h : t.Ordered cmp) : (insert cmp t v).Ordered cmp := by
unfold RBNode.insert; split <;> simp [Ordered.setBlack, h.ins (x := v)]
Expand Down Expand Up @@ -145,6 +188,10 @@ protected theorem RedRed.imp (h : p → q) : RedRed p t n → RedRed q t n
| .balanced h => .balanced h
| .redred hp ha hb => .redred (h hp) ha hb

protected theorem RedRed.reverse : RedRed p t n → RedRed p t.reverse n
| .balanced h => .balanced h.reverse
| .redred hp ha hb => .redred hp hb.reverse ha.reverse

/-- If `t` has the red-red invariant, then setting the root to black yields a balanced tree. -/
protected theorem RedRed.setBlack : t.RedRed p n → ∃ n', (setBlack t).Balanced black n'
| .balanced h => h.setBlack
Expand All @@ -164,15 +211,8 @@ protected theorem RedRed.balance1 {l : RBNode α} {v : α} {r : RBNode α}

/-- The `balance2` function repairs the balance invariant when the second argument is red-red. -/
protected theorem RedRed.balance2 {l : RBNode α} {v : α} {r : RBNode α}
(hl : l.Balanced c n) (hr : r.RedRed p n) : ∃ c, (balance2 l v r).Balanced c (n + 1) := by
unfold balance2; split
· have .redred _ (.red ha hb) hc := hr; exact ⟨_, .red (.black hl ha) (.black hb hc)⟩
· have .redred _ ha (.red hb hc) := hr; exact ⟨_, .red (.black hl ha) (.black hb hc)⟩
· next H1 H2 => match hr with
| .balanced hr => exact ⟨_, .black hl hr⟩
| .redred _ (c₁ := black) (c₂ := black) ha hb => exact ⟨_, .black hl (.red ha hb)⟩
| .redred _ (c₁ := red) (.red ..) _ => cases H1 _ _ _ _ _ rfl
| .redred _ (c₂ := red) _ (.red ..) => cases H2 _ _ _ _ _ rfl
(hl : l.Balanced c n) (hr : r.RedRed p n) : ∃ c, (balance2 l v r).Balanced c (n + 1) :=
(hr.reverse.balance1 hl.reverse (v := v)).imp fun _ h => by simpa using h.reverse

/-- The `balance1` function does nothing if the first argument is already balanced. -/
theorem balance1_eq {l : RBNode α} {v : α} {r : RBNode α}
Expand All @@ -181,8 +221,8 @@ theorem balance1_eq {l : RBNode α} {v : α} {r : RBNode α}

/-- The `balance2` function does nothing if the second argument is already balanced. -/
theorem balance2_eq {l : RBNode α} {v : α} {r : RBNode α}
(hr : r.Balanced c n) : balance2 l v r = node black l v r := by
unfold balance2; split <;> first | rfl | nomatch hr
(hr : r.Balanced c n) : balance2 l v r = node black l v r :=
(reverse_reverse _).symm.trans <| by simp [balance1_eq hr.reverse]

/-! ## insert -/

Expand Down Expand Up @@ -225,13 +265,28 @@ theorem Balanced.insert {t : RBNode α} (h : t.Balanced c n) :
| _, .balanced h => split <;> [exact ⟨_, h.setBlack⟩; exact ⟨_, _, h⟩]
| _, .redred _ ha hb => have .node red .. := t; exact ⟨_, _, .black ha hb⟩

@[simp] theorem reverse_setRed {t : RBNode α} : (setRed t).reverse = setRed t.reverse := by
unfold setRed; split <;> simp

protected theorem All.setRed {t : RBNode α} (h : t.All p) : (setRed t).All p := by
unfold setRed; split <;> simp_all

/-- The `setRed` function preserves the ordering invariants. -/
protected theorem Ordered.setRed {t : RBNode α} : (setRed t).Ordered cmp ↔ t.Ordered cmp := by
unfold setRed; split <;> simp [Ordered]

@[simp] theorem reverse_balLeft (l : RBNode α) (v : α) (r : RBNode α) :
(balLeft l v r).reverse = balRight r.reverse v l.reverse := by
unfold balLeft balRight; split
· simp
· rw [balLeft.match_2.eq_2 _ _ _ _ (by simp [reverse_eq_iff]; intros; solve_by_elim)]
split <;> simp
rw [balRight.match_1.eq_3] <;> (simp [reverse_eq_iff]; intros; solve_by_elim)

@[simp] theorem reverse_balRight (l : RBNode α) (v : α) (r : RBNode α) :
(balRight l v r).reverse = balLeft r.reverse v l.reverse := by
rw [← reverse_reverse (balLeft ..)]; simp

protected theorem All.balLeft
(hl : l.All p) (hv : p v) (hr : r.All p) : (balLeft l v r).All p := by
unfold balLeft; split <;> (try simp_all); split <;> simp_all [All.setRed]
Expand Down Expand Up @@ -267,38 +322,24 @@ protected theorem Balanced.balLeft (hl : l.RedRed True n) (hr : r.Balanced cr (n
let ⟨c, h⟩ := RedRed.balance2 hb (.redred trivial hc hd); .redred rfl (.black hl ha) h

protected theorem All.balRight
(hl : l.All p) (hv : p v) (hr : r.All p) : (balRight l v r).All p := by
unfold balRight; split <;> (try simp_all); split <;> simp_all [All.setRed]
(hl : l.All p) (hv : p v) (hr : r.All p) : (balRight l v r).All p :=
All.reverse.1 <| reverse_balRight .. ▸ (All.reverse.2 hr).balLeft hv (All.reverse.2 hl)

/-- The `balRight` function preserves the ordering invariants. -/
protected theorem Ordered.balRight {l : RBNode α} {v : α} {r : RBNode α}
(lv : l.All (cmpLT cmp · v)) (vr : r.All (cmpLT cmp v ·))
(hl : l.Ordered cmp) (hr : r.Ordered cmp) : (balRight l v r).Ordered cmp := by
unfold balRight; split
· exact ⟨lv, vr, hl, hr⟩
split
· exact hl.balance1 lv vr hr
· have ⟨yv, _, cv⟩ := lv.2.2; have ⟨ax, ⟨xy, xb, _⟩, ha, by_, yc, hb, hc⟩ := hl
exact ⟨balance1_All.2 ⟨xy, (xy.trans_r ax).setRed, by_⟩, ⟨yv, yc, yv.trans_l vr⟩,
(Ordered.setRed.2 ha).balance1 ax.setRed xb hb, cv, vr, hc, hr⟩
· exact ⟨lv, vr, hl, hr⟩
rw [← reverse_reverse (balRight ..), reverse_balRight]
exact .reverse <| hr.reverse.balLeft
((All.reverse.2 vr).imp cmpLT.flip) ((All.reverse.2 lv).imp cmpLT.flip) hl.reverse

/-- The balancing properties of the `balRight` function. -/
protected theorem Balanced.balRight (hl : l.Balanced cl (n + 1)) (hr : r.RedRed True n) :
(balRight l v r).RedRed (cl = red) (n + 1) := by
unfold balRight; split
· next b y c => exact
let ⟨cb, cc, hb, hc⟩ := hr.of_red
match cl with
| red => .redred rfl hl (.black hb hc)
| black => .balanced (.red hl (.black hb hc))
· next H => exact match hr with
| .redred .. => nomatch H _ _ _ rfl
| .balanced hr => match hl with
| .black hb hc =>
let ⟨c, h⟩ := RedRed.balance1 (.redred trivial hb hc) hr; .balanced h
| .red (.black ha hb) (.black hc hd) =>
let ⟨c, h⟩ := RedRed.balance1 (.redred trivial ha hb) hc; .redred rfl h (.black hd hr)
rw [← reverse_reverse (balRight ..), reverse_balRight]
exact .reverse <| hl.reverse.balLeft hr.reverse

-- note: reverse_append is false!

protected theorem All.append (hl : l.All p) (hr : r.All p) : (append l r).All p := by
unfold append; split <;> try simp [*]
Expand Down

0 comments on commit 6361c24

Please sign in to comment.