diff --git a/Batteries/Data/Fin/Basic.lean b/Batteries/Data/Fin/Basic.lean index 7323f70eee..ef5a1f57f0 100644 --- a/Batteries/Data/Fin/Basic.lean +++ b/Batteries/Data/Fin/Basic.lean @@ -18,6 +18,34 @@ alias enum := Array.finRange @[deprecated (since := "2024-11-15")] alias list := List.finRange +/-- Dependent version of `Fin.foldr`. -/ +@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Sort _) + (f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 := + loop n (Nat.lt_succ_self n) init where + /-- Inner loop for `Fin.dfoldr`. + `Fin.dfoldr.loop n α f i h x = f 0 (f 1 (... (f i x)))` -/ + @[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α 0 := + match i with + | i + 1 => loop i (Nat.lt_of_succ_lt h) (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x) + | 0 => x + +/-- Dependent version of `Fin.foldrM`. -/ +@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _) + (f : ∀ (i : Fin n), α i.succ → m (α i.castSucc)) (init : α (last n)) : m (α 0) := + dfoldr n (fun i => m (α i)) (fun i x => x >>= f i) (pure init) + +/-- Dependent version of `Fin.foldl`. -/ +@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Sort _) + (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) := + loop 0 (Nat.zero_lt_succ n) init where + /-- Inner loop for `Fin.dfoldl`. `Fin.dfoldl.loop n α f i h x = f n (f (n-1) (... (f i x)))` -/ + @[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α (last n) := + if h' : i < n then + loop (i + 1) (Nat.succ_lt_succ h') (f ⟨i, h'⟩ x) + else + haveI : ⟨i, h⟩ = last n := by ext; simp; omega + _root_.cast (congrArg α this) x + /-- Dependent version of `Fin.foldlM`. -/ @[inline] def dfoldlM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _) (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (init : α 0) : m (α (last n)) := @@ -31,54 +59,9 @@ alias list := List.finRange pure xₙ ``` -/ - @[semireducible] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α (last n)) := + @[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α (last n)) := if h' : i < n then (f ⟨i, h'⟩ x) >>= loop (i + 1) (Nat.succ_lt_succ h') else haveI : ⟨i, h⟩ = last n := by ext; simp; omega _root_.cast (congrArg (fun i => m (α i)) this) (pure x) - -/-- Dependent version of `Fin.foldrM`. -/ -@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _) - (f : ∀ (i : Fin n), α i.succ → m (α i.castSucc)) (init : α (last n)) : m (α 0) := - loop n (Nat.lt_succ_self n) init where - /-- Inner loop for `Fin.foldRevM`. - ``` - Fin.foldRevM.loop n α f i h xᵢ = do - let xᵢ₋₁ ← f (i+1) xᵢ - ... - let x₁ ← f 1 x₂ - let x₀ ← f 0 x₁ - pure x₀ - ``` - -/ - @[semireducible] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α 0) := - if h' : i > 0 then - (f ⟨i - 1, by omega⟩ (by simpa [Nat.sub_one_add_one_eq_of_pos h'] using x)) - >>= loop (i - 1) (by omega) - else - haveI : ⟨i, h⟩ = 0 := by ext; simp; omega - _root_.cast (congrArg (fun i => m (α i)) this) (pure x) - -/-- Dependent version of `Fin.foldl`. -/ -@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Sort _) - (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) := - loop 0 (Nat.zero_lt_succ n) init where - /-- Inner loop for `Fin.dfoldl`. `Fin.dfoldl.loop n α f i h x = f n (f (n-1) (... (f i x)))` -/ - @[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α (last n) := - if h' : i < n then - loop (i + 1) (Nat.succ_lt_succ h') (f ⟨i, h'⟩ x) - else - haveI : ⟨i, h⟩ = last n := by ext; simp; omega - _root_.cast (congrArg α this) x - -/-- Dependent version of `Fin.foldr`. -/ -@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Sort _) - (f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 := - loop n (Nat.lt_succ_self n) init where - /-- Inner loop for `Fin.dfoldr`. - `Fin.dfoldr.loop n α f i h x = f 0 (f 1 (... (f i x)))` -/ - @[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α 0 := - match i with - | i + 1 => loop i (Nat.lt_of_succ_lt h) (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x) - | 0 => x diff --git a/Batteries/Data/Fin/Fold.lean b/Batteries/Data/Fin/Fold.lean index 6c2842e224..bbbe373535 100644 --- a/Batteries/Data/Fin/Fold.lean +++ b/Batteries/Data/Fin/Fold.lean @@ -8,74 +8,49 @@ import Batteries.Data.Fin.Basic namespace Fin -/-! ### dfoldlM -/ - -theorem dfoldlM_loop_lt [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (h : i < n) (x) : - dfoldlM.loop n α f i (Nat.lt_add_right 1 h) x = - (f ⟨i, h⟩ x) >>= (dfoldlM.loop n α f (i+1) (Nat.add_lt_add_right h 1)) := by - rw [dfoldlM.loop, dif_pos h] +/-! ### dfoldr -/ -theorem dfoldlM_loop_eq [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (x) : - dfoldlM.loop n α f n (Nat.le_refl _) x = pure x := by - rw [dfoldlM.loop, dif_neg (Nat.lt_irrefl _), cast_eq] +theorem dfoldr_loop_zero (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : + dfoldr.loop n α f 0 (Nat.zero_lt_succ n) x = x := rfl -@[simp] theorem dfoldlM_zero [Monad m] (f : (i : Fin 0) → α i.castSucc → m (α i.succ)) (x) : - dfoldlM 0 α f x = pure x := dfoldlM_loop_eq .. +theorem dfoldr_loop_succ (f : (i : Fin n) → α i.succ → α i.castSucc) (h : i < n) (x) : + dfoldr.loop n α f (i+1) (Nat.add_lt_add_right h 1) x = + dfoldr.loop n α f i (Nat.lt_add_right 1 h) (f ⟨i, h⟩ x) := rfl -theorem dfoldlM_loop [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (h : i < n+1) - (x) : dfoldlM.loop (n+1) α f i (Nat.lt_add_right 1 h) x = - f ⟨i, h⟩ x >>= (dfoldlM.loop n (α ∘ succ) (f ·.succ ·) i h .) := by - if h' : i < n then - rw [dfoldlM_loop_lt _ h _] - congr; funext - rw [dfoldlM_loop_lt _ h' _, dfoldlM_loop]; rfl - else - cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') - rw [dfoldlM_loop_lt] - congr; funext - rw [dfoldlM_loop_eq, dfoldlM_loop_eq] +theorem dfoldr_loop (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (h : i+1 ≤ n+1) (x) : + dfoldr.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x = + f 0 (dfoldr.loop n (α ∘ succ) (f ·.succ) i h x) := by + induction i with + | zero => rfl + | succ i ih => exact ih .. -theorem dfoldlM_succ [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (x) : - dfoldlM (n+1) α f x = f 0 x >>= (dfoldlM n (α ∘ succ) (f ·.succ ·) .) := - dfoldlM_loop .. +@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) : + dfoldr 0 α f x = x := rfl -theorem dfoldlM_eq_foldlM [Monad m] (f : (i : Fin n) → α → m α) (x : α) : - dfoldlM n (fun _ => α) f x = foldlM n (fun x i => f i x) x := by - induction n generalizing x with - | zero => simp only [dfoldlM_zero, foldlM_zero] - | succ n ih => - simp only [dfoldlM_succ, foldlM_succ, Function.comp_apply, Function.comp_def] - congr; ext; simp only [ih] +theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : + dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldr_loop .. -/-! ### dfoldrM -/ +theorem dfoldr_succ_last (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : + dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by + induction n with + | zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta] + | succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr -theorem dfoldrM_loop_zero [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (x) : - dfoldrM.loop n α f 0 (Nat.zero_lt_succ n) x = pure x := by - rw [dfoldrM.loop, dif_neg (Nat.not_lt_zero _), cast_eq] +theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : + dfoldr n α f x = dfoldrM (m:=Id) n α f x := rfl -theorem dfoldrM_loop_succ [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (h : i < n) - (x) : dfoldrM.loop n α f (i+1) (Nat.add_lt_add_right h 1) x = - f ⟨i, h⟩ x >>= dfoldrM.loop n α f i (Nat.lt_add_right 1 h) := by - rw [dfoldrM.loop, dif_pos (Nat.zero_lt_succ i)] - simp only [Nat.add_one_sub_one, castSucc_mk, succ_mk, eq_mpr_eq_cast, cast_eq] +theorem dfoldr_eq_foldr (f : Fin n → α → α) (x : α) : dfoldr n (fun _ => α) f x = foldr n f x := by + induction n with + | zero => simp only [dfoldr_zero, foldr_zero] + | succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_apply, Function.comp_def, ih] -theorem dfoldrM_loop [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc)) - (h : i+1 ≤ n+1) (x) : dfoldrM.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x = - dfoldrM.loop n (α ∘ succ) (f ·.succ) i h x >>= f 0 := by - induction i with - | zero => - rw [dfoldrM_loop_zero, dfoldrM_loop_succ, pure_bind] - conv => rhs; rw [←bind_pure (f 0 x)] - congr - | succ i ih => - rw [dfoldrM_loop_succ _ h, dfoldrM_loop_succ _ (Nat.succ_lt_succ_iff.mp h), bind_assoc] - congr; funext; exact ih .. +/-! ### dfoldrM -/ @[simp] theorem dfoldrM_zero [Monad m] (f : (i : Fin 0) → α i.succ → m (α i.castSucc)) (x) : - dfoldrM 0 α f x = pure x := dfoldrM_loop_zero .. + dfoldrM 0 α f x = pure x := rfl theorem dfoldrM_succ [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc)) - (x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldrM_loop .. + (x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldr_succ .. theorem dfoldrM_eq_foldrM [Monad m] [LawfulMonad m] (f : (i : Fin n) → α → m α) (x : α) : dfoldrM n (fun _ => α) f x = foldrM n f x := by @@ -126,47 +101,44 @@ theorem dfoldl_eq_foldl (f : Fin n → α → α) (x : α) : simp only [dfoldl_succ, foldl_succ, Function.comp_apply, Function.comp_def] congr; simp only [ih] -/-! ### dfoldr -/ - -theorem dfoldr_loop_zero (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : - dfoldr.loop n α f 0 (Nat.zero_lt_succ n) x = x := by - rw [dfoldr.loop] - -theorem dfoldr_loop_succ (f : (i : Fin n) → α i.succ → α i.castSucc) (h : i < n) (x) : - dfoldr.loop n α f (i+1) (Nat.add_lt_add_right h 1) x = - dfoldr.loop n α f i (Nat.lt_add_right 1 h) (f ⟨i, h⟩ x) := by - rw [dfoldr.loop] - -theorem dfoldr_loop (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (h : i+1 ≤ n+1) (x) : - dfoldr.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x = - f 0 (dfoldr.loop n (α ∘ succ) (f ·.succ) i h x) := by - induction i with - | zero => simp [dfoldr_loop_succ, dfoldr_loop_zero] - | succ i ih => rw [dfoldr_loop_succ _ h, dfoldr_loop_succ _ (Nat.succ_lt_succ_iff.mp h), - ih (Nat.le_of_succ_le h)]; rfl +/-! ### dfoldlM -/ -@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) : - dfoldr 0 α f x = x := dfoldr_loop_zero .. +theorem dfoldlM_loop_lt [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (h : i < n) (x) : + dfoldlM.loop n α f i (Nat.lt_add_right 1 h) x = + (f ⟨i, h⟩ x) >>= (dfoldlM.loop n α f (i+1) (Nat.add_lt_add_right h 1)) := by + rw [dfoldlM.loop, dif_pos h] -theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : - dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldr_loop .. +theorem dfoldlM_loop_eq [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (x) : + dfoldlM.loop n α f n (Nat.le_refl _) x = pure x := by + rw [dfoldlM.loop, dif_neg (Nat.lt_irrefl _), cast_eq] -theorem dfoldr_succ_last (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) : - dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by - induction n with - | zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta] - | succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr +@[simp] theorem dfoldlM_zero [Monad m] (f : (i : Fin 0) → α i.castSucc → m (α i.succ)) (x) : + dfoldlM 0 α f x = pure x := dfoldlM_loop_eq .. -theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) : - dfoldr n α f x = dfoldrM (m:=Id) n α f x := by - induction n <;> simp [dfoldr_succ, dfoldrM_succ, *] +theorem dfoldlM_loop [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (h : i < n+1) + (x) : dfoldlM.loop (n+1) α f i (Nat.lt_add_right 1 h) x = + f ⟨i, h⟩ x >>= (dfoldlM.loop n (α ∘ succ) (f ·.succ ·) i h .) := by + if h' : i < n then + rw [dfoldlM_loop_lt _ h _] + congr; funext + rw [dfoldlM_loop_lt _ h' _, dfoldlM_loop]; rfl + else + cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h') + rw [dfoldlM_loop_lt] + congr; funext + rw [dfoldlM_loop_eq, dfoldlM_loop_eq] -theorem dfoldr_eq_foldr (f : Fin n → α → α) (x : α) : dfoldr n (fun _ => α) f x = foldr n f x := by - induction n with - | zero => simp only [dfoldr_zero, foldr_zero] - | succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_apply, Function.comp_def, ih] +theorem dfoldlM_succ [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (x) : + dfoldlM (n+1) α f x = f 0 x >>= (dfoldlM n (α ∘ succ) (f ·.succ ·) .) := + dfoldlM_loop .. --- TODO: add `dfoldl_rev` and `dfoldr_rev` +theorem dfoldlM_eq_foldlM [Monad m] (f : (i : Fin n) → α → m α) (x : α) : + dfoldlM n (fun _ => α) f x = foldlM n (fun x i => f i x) x := by + induction n generalizing x with + | zero => simp only [dfoldlM_zero, foldlM_zero] + | succ n ih => + simp only [dfoldlM_succ, foldlM_succ, Function.comp_apply, Function.comp_def] + congr; ext; simp only [ih] /-! ### `Fin.fold{l/r}{M}` equals `List.fold{l/r}{M}` -/