Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Define dependent version of Fin.foldl #1071

Merged
merged 15 commits into from
Dec 4, 2024
Merged
65 changes: 65 additions & 0 deletions Batteries/Data/Fin/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,68 @@ alias enum := Array.finRange

@[deprecated (since := "2024-11-15")]
alias list := List.finRange

/-- Dependent version of `Fin.foldlM`. -/
@[inline] def dfoldlM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (init : α 0) : m (α (last n)) :=
loop 0 (Nat.zero_lt_succ n) init where
/-- Inner loop for `Fin.dfoldlM`.
```
Fin.foldM.loop n α f i h xᵢ = do
let xᵢ₊₁ ← f i xᵢ
...
let xₙ ← f (n-1) xₙ₋₁
pure xₙ
```
-/
@[semireducible] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α (last n)) :=
if h' : i < n then
(f ⟨i, h'⟩ x) >>= loop (i + 1) (Nat.succ_lt_succ h')
else
haveI : ⟨i, h⟩ = last n := by ext; simp; omega
_root_.cast (congrArg (fun i => m (α i)) this) (pure x)

/-- Dependent version of `Fin.foldrM`. -/
@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.succ → m (α i.castSucc)) (init : α (last n)) : m (α 0) :=
loop n (Nat.lt_succ_self n) init where
/-- Inner loop for `Fin.foldRevM`.
```
Fin.foldRevM.loop n α f i h xᵢ = do
let xᵢ₋₁ ← f (i+1) xᵢ
...
let x₁ ← f 1 x₂
let x₀ ← f 0 x₁
pure x₀
```
-/
@[semireducible] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α 0) :=
if h' : i > 0 then
(f ⟨i - 1, by omega⟩ (by simpa [Nat.sub_one_add_one_eq_of_pos h'] using x))
>>= loop (i - 1) (by omega)
else
haveI : ⟨i, h⟩ = 0 := by ext; simp; omega
_root_.cast (congrArg (fun i => m (α i)) this) (pure x)

/-- Dependent version of `Fin.foldl`. -/
@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) :=
loop 0 (Nat.zero_lt_succ n) init where
/-- Inner loop for `Fin.dfoldl`. `Fin.dfoldl.loop n α f i h x = f n (f (n-1) (... (f i x)))` -/
@[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α (last n) :=
if h' : i < n then
loop (i + 1) (Nat.succ_lt_succ h') (f ⟨i, h'⟩ x)
else
haveI : ⟨i, h⟩ = last n := by ext; simp; omega
_root_.cast (congrArg α this) x

/-- Dependent version of `Fin.foldr`. -/
@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 :=
loop n (Nat.lt_succ_self n) init where
/-- Inner loop for `Fin.dfoldr`.
`Fin.dfoldr.loop n α f i h x = f 0 (f 1 (... (f i x)))` -/
@[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α 0 :=
match i with
| i + 1 => loop i (Nat.lt_of_succ_lt h) (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x)
| 0 => x
165 changes: 164 additions & 1 deletion Batteries/Data/Fin/Fold.lean
Original file line number Diff line number Diff line change
@@ -1,12 +1,175 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
Authors: François G. Dorais, Quang Dao
-/
import Batteries.Data.List.FinRange
import Batteries.Data.Fin.Basic

namespace Fin

/-! ### dfoldlM -/

theorem dfoldlM_loop_lt [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (h : i < n) (x) :
dfoldlM.loop n α f i (Nat.lt_add_right 1 h) x =
(f ⟨i, h⟩ x) >>= (dfoldlM.loop n α f (i+1) (Nat.add_lt_add_right h 1)) := by
rw [dfoldlM.loop, dif_pos h]

theorem dfoldlM_loop_eq [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (x) :
dfoldlM.loop n α f n (Nat.le_refl _) x = pure x := by
rw [dfoldlM.loop, dif_neg (Nat.lt_irrefl _), cast_eq]

@[simp] theorem dfoldlM_zero [Monad m] (f : (i : Fin 0) → α i.castSucc → m (α i.succ)) (x) :
dfoldlM 0 α f x = pure x := dfoldlM_loop_eq ..

theorem dfoldlM_loop [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (h : i < n+1)
(x) : dfoldlM.loop (n+1) α f i (Nat.lt_add_right 1 h) x =
f ⟨i, h⟩ x >>= (dfoldlM.loop n (α ∘ succ) (f ·.succ ·) i h .) := by
if h' : i < n then
rw [dfoldlM_loop_lt _ h _]
congr; funext
rw [dfoldlM_loop_lt _ h' _, dfoldlM_loop]; rfl
else
cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
rw [dfoldlM_loop_lt]
congr; funext
rw [dfoldlM_loop_eq, dfoldlM_loop_eq]

theorem dfoldlM_succ [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (x) :
dfoldlM (n+1) α f x = f 0 x >>= (dfoldlM n (α ∘ succ) (f ·.succ ·) .) :=
dfoldlM_loop ..

theorem dfoldlM_eq_foldlM [Monad m] (f : (i : Fin n) → α → m α) (x : α) :
dfoldlM n (fun _ => α) f x = foldlM n (fun x i => f i x) x := by
induction n generalizing x with
| zero => simp only [dfoldlM_zero, foldlM_zero]
| succ n ih =>
simp only [dfoldlM_succ, foldlM_succ, Function.comp_apply, Function.comp_def]
congr; ext; simp only [ih]

/-! ### dfoldrM -/

theorem dfoldrM_loop_zero [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (x) :
dfoldrM.loop n α f 0 (Nat.zero_lt_succ n) x = pure x := by
rw [dfoldrM.loop, dif_neg (Nat.not_lt_zero _), cast_eq]

theorem dfoldrM_loop_succ [Monad m] (f : (i : Fin n) → α i.succ → m (α i.castSucc)) (h : i < n)
(x) : dfoldrM.loop n α f (i+1) (Nat.add_lt_add_right h 1) x =
f ⟨i, h⟩ x >>= dfoldrM.loop n α f i (Nat.lt_add_right 1 h) := by
rw [dfoldrM.loop, dif_pos (Nat.zero_lt_succ i)]
simp only [Nat.add_one_sub_one, castSucc_mk, succ_mk, eq_mpr_eq_cast, cast_eq]

theorem dfoldrM_loop [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc))
(h : i+1 ≤ n+1) (x) : dfoldrM.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x =
dfoldrM.loop n (α ∘ succ) (f ·.succ) i h x >>= f 0 := by
induction i with
| zero =>
rw [dfoldrM_loop_zero, dfoldrM_loop_succ, pure_bind]
conv => rhs; rw [←bind_pure (f 0 x)]
congr
| succ i ih =>
rw [dfoldrM_loop_succ _ h, dfoldrM_loop_succ _ (Nat.succ_lt_succ_iff.mp h), bind_assoc]
congr; funext; exact ih ..

@[simp] theorem dfoldrM_zero [Monad m] (f : (i : Fin 0) → α i.succ → m (α i.castSucc)) (x) :
dfoldrM 0 α f x = pure x := dfoldrM_loop_zero ..

theorem dfoldrM_succ [Monad m] [LawfulMonad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc))
(x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldrM_loop ..

theorem dfoldrM_eq_foldrM [Monad m] [LawfulMonad m] (f : (i : Fin n) → α → m α) (x : α) :
dfoldrM n (fun _ => α) f x = foldrM n f x := by
induction n generalizing x with
| zero => simp only [dfoldrM_zero, foldrM_zero]
| succ n ih => simp only [dfoldrM_succ, foldrM_succ, Function.comp_def, ih]

/-! ### dfoldl -/

theorem dfoldl_loop_lt (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (h : i < n) (x) :
dfoldl.loop n α f i (Nat.lt_add_right 1 h) x =
dfoldl.loop n α f (i+1) (Nat.add_lt_add_right h 1) (f ⟨i, h⟩ x) := by
rw [dfoldl.loop, dif_pos h]

theorem dfoldl_loop_eq (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (x) :
dfoldl.loop n α f n (Nat.le_refl _) x = x := by
rw [dfoldl.loop, dif_neg (Nat.lt_irrefl _), cast_eq]

@[simp] theorem dfoldl_zero (f : (i : Fin 0) → α i.castSucc → α i.succ) (x) :
dfoldl 0 α f x = x := dfoldl_loop_eq ..

theorem dfoldl_loop (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (h : i < n+1) (x) :
dfoldl.loop (n+1) α f i (Nat.lt_add_right 1 h) x =
dfoldl.loop n (α ∘ succ) (f ·.succ ·) i h (f ⟨i, h⟩ x) := by
if h' : i < n then
rw [dfoldl_loop_lt _ h _]
rw [dfoldl_loop_lt _ h' _, dfoldl_loop]; rfl
else
cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
rw [dfoldl_loop_lt]
rw [dfoldl_loop_eq, dfoldl_loop_eq]

theorem dfoldl_succ (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) :
dfoldl (n+1) α f x = dfoldl n (α ∘ succ) (f ·.succ ·) (f 0 x) := dfoldl_loop ..

theorem dfoldl_succ_last (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) :
dfoldl (n+1) α f x = f (last n) (dfoldl n (α ∘ castSucc) (f ·.castSucc ·) x) := by
rw [dfoldl_succ]
induction n with
| zero => simp [dfoldl_succ, last]
| succ n ih => rw [dfoldl_succ, @ih (α ∘ succ) (f ·.succ ·), dfoldl_succ]; congr

theorem dfoldl_eq_foldl (f : Fin n → α → α) (x : α) :
dfoldl n (fun _ => α) f x = foldl n (fun x i => f i x) x := by
induction n generalizing x with
| zero => simp only [dfoldl_zero, foldl_zero]
| succ n ih =>
simp only [dfoldl_succ, foldl_succ, Function.comp_apply, Function.comp_def]
congr; simp only [ih]

/-! ### dfoldr -/

theorem dfoldr_loop_zero (f : (i : Fin n) → α i.succ → α i.castSucc) (x) :
dfoldr.loop n α f 0 (Nat.zero_lt_succ n) x = x := by
rw [dfoldr.loop]
fgdorais marked this conversation as resolved.
Show resolved Hide resolved

theorem dfoldr_loop_succ (f : (i : Fin n) → α i.succ → α i.castSucc) (h : i < n) (x) :
dfoldr.loop n α f (i+1) (Nat.add_lt_add_right h 1) x =
dfoldr.loop n α f i (Nat.lt_add_right 1 h) (f ⟨i, h⟩ x) := by
rw [dfoldr.loop]
fgdorais marked this conversation as resolved.
Show resolved Hide resolved

theorem dfoldr_loop (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (h : i+1 ≤ n+1) (x) :
dfoldr.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x =
f 0 (dfoldr.loop n (α ∘ succ) (f ·.succ) i h x) := by
induction i with
| zero => simp [dfoldr_loop_succ, dfoldr_loop_zero]
| succ i ih => rw [dfoldr_loop_succ _ h, dfoldr_loop_succ _ (Nat.succ_lt_succ_iff.mp h),
ih (Nat.le_of_succ_le h)]; rfl
fgdorais marked this conversation as resolved.
Show resolved Hide resolved

@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) :
dfoldr 0 α f x = x := dfoldr_loop_zero ..
fgdorais marked this conversation as resolved.
Show resolved Hide resolved

theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) :
dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldr_loop ..

theorem dfoldr_succ_last (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) :
dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by
induction n with
| zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta]
| succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr

theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) :
dfoldr n α f x = dfoldrM (m:=Id) n α f x := by
induction n <;> simp [dfoldr_succ, dfoldrM_succ, *]

theorem dfoldr_eq_foldr (f : Fin n → α → α) (x : α) : dfoldr n (fun _ => α) f x = foldr n f x := by
induction n with
| zero => simp only [dfoldr_zero, foldr_zero]
| succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_apply, Function.comp_def, ih]

-- TODO: add `dfoldl_rev` and `dfoldr_rev`

/-! ### `Fin.fold{l/r}{M}` equals `List.fold{l/r}{M}` -/

theorem foldlM_eq_foldlM_finRange [Monad m] (f : α → Fin n → m α) (x) :
foldlM n f x = (List.finRange n).foldlM f x := by
induction n generalizing x with
Expand Down
Loading