Skip to content

Commit

Permalink
fix: use fget and fset
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais committed Dec 23, 2024
1 parent d3a45c1 commit 8510de1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 45 deletions.
44 changes: 22 additions & 22 deletions Batteries/Data/DArray/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ convert array items to the appropriate type when necessary.
-- TODO: Use a structure once [lean4#2292](https://github.com/leanprover/lean4/pull/2292) is fixed.
inductive DArray (n) (α : Fin n → Type _) where
/-- Makes a new `DArray` with given item values. `O(n*g)` where `get i` is `O(g)`. -/
| mk (get : (i : Fin n) → α i)
| mk (fget : (i : Fin n) → α i)

namespace DArray

Expand All @@ -40,13 +40,13 @@ private unsafe abbrev data : DArray n α → Array NonScalar := unsafeCast
private unsafe def mkImpl (get : (i : Fin n) → α i) : DArray n α :=
unsafeCast <| Array.ofFn fun i => (unsafeCast (get i) : NonScalar)

private unsafe def getImpl (a : DArray n α) (i) : α i :=
private unsafe def fgetImpl (a : DArray n α) (i) : α i :=
unsafeCast <| a.data.get i.val

private unsafe def ugetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ :=
unsafeCast <| a.data.uget i lcProof

private unsafe def setImpl (a : DArray n α) (i) (v : α i) : DArray n α :=
private unsafe def fsetImpl (a : DArray n α) (i) (v : α i) : DArray n α :=
unsafeCast <| a.data.set i (unsafeCast v) lcProof

private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
Expand All @@ -57,7 +57,7 @@ private unsafe def modifyFImpl [Functor f] (a : DArray n α) (i : Fin n)
let v := unsafeCast <| a.data.get i
-- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`.
let a := unsafeCast <| a.data.set i (unsafeCast ()) lcProof
setImpl a i <$> t v
fsetImpl a i <$> t v

private unsafe def umodifyFImpl [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n)
(t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) :=
Expand All @@ -84,28 +84,28 @@ instance (α : Fin n → Type _) [(i : Fin n) → Inhabited (α i)] : Inhabited
default := mk fun _ => default

/-- Gets the `DArray` item at index `i`. `O(1)`. -/
@[implemented_by getImpl]
protected def get : DArray n α → (i : Fin n) → α i
| mk get => get
@[implemented_by fgetImpl]
protected def fget : DArray n α → (i : Fin n) → α i
| mk fget => fget

@[simp, inherit_doc DArray.get]
protected abbrev getN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) : α ⟨i, h⟩ :=
a.get ⟨i, h⟩
@[inherit_doc DArray.fget, inline]
protected def get (a : DArray n α) (i) (h : i < n := by get_elem_tactic) : α ⟨i, h⟩ :=
a.fget ⟨i, h⟩

/-- Gets the `DArray` item at index `i : USize`. Slightly faster than `get`; `O(1)`. -/
@[implemented_by ugetImpl]
protected def uget (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ :=
a.get ⟨i.toNat, h⟩
a.fget ⟨i.toNat, h⟩

private def casesOnImpl.{u} {motive : DArray n α → Sort u} (a : DArray n α)
(h : (get : (i : Fin n) → α i) → motive (.mk get)) : motive a :=
h a.get
(h : (fget : (i : Fin n) → α i) → motive (.mk fget)) : motive a :=
h a.fget

attribute [implemented_by casesOnImpl] DArray.casesOn

/-- Sets the `DArray` item at index `i`. `O(1)` if exclusive else `O(n)`. -/
@[implemented_by setImpl]
protected def set (a : DArray n α) (i : Fin n) (v : α i) : DArray n α :=
@[implemented_by fsetImpl]
protected def fset (a : DArray n α) (i : Fin n) (v : α i) : DArray n α :=
mk fun j => if h : i = j then h ▸ v else a.get j

/--
Expand All @@ -114,16 +114,16 @@ Slightly faster than `set` and `O(1)` if exclusive else `O(n)`.
-/
@[implemented_by usetImpl]
protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :=
a.set ⟨i.toNat, h⟩ v
a.fset ⟨i.toNat, h⟩ v

@[simp, inherit_doc DArray.set]
protected abbrev setN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) :=
a.set ⟨i, h⟩ v
@[simp, inherit_doc DArray.fset]
protected abbrev set (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) :=
a.fset ⟨i, h⟩ v

/-- Modifies the `DArray` item at index `i` using transform `t` and the functor `f`. -/
@[implemented_by modifyFImpl]
protected def modifyF [Functor f] (a : DArray n α) (i : Fin n)
(t : α i → f (α i)) : f (DArray n α) := a.set i <$> t (a.get i)
(t : α i → f (α i)) : f (DArray n α) := a.fset i <$> t (a.fget i)

/-- Modifies the `DArray` item at index `i` using transform `t`. -/
@[inline]
Expand All @@ -143,13 +143,13 @@ protected def umodify (a : DArray n α) (i : USize) (h : i.toNat < n)

/-- Copies the `DArray` to an exclusive `DArray`. `O(1)` if exclusive else `O(n)`. -/
@[implemented_by copyImpl]
protected def copy (a : DArray n α) : DArray n α := mk a.get
protected def copy (a : DArray n α) : DArray n α := mk a.fget

/-- Push an element onto the end of a `DArray`. `O(1)` if exclusive else `O(n)`. -/
@[implemented_by pushImpl]
protected def push (a : DArray n α) (v : β) :
DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β :=
mk fun i => if h : i.val < n then dif_pos h ▸ a.get ⟨i.val, h⟩ else dif_neg h ▸ v
mk fun i => if h : i.val < n then dif_pos h ▸ a.fget ⟨i.val, h⟩ else dif_neg h ▸ v

/-- Delete the last item of a `DArray`. `O(1)`. -/
@[implemented_by popImpl]
Expand Down
47 changes: 24 additions & 23 deletions Batteries/Data/DArray/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,58 @@ import Batteries.Data.DArray.Basic
namespace Batteries.DArray

@[ext]
protected theorem ext : {a b : DArray n α} → (∀ i, a.get i = b.get i) → a = b
protected theorem ext : {a b : DArray n α} → (∀ i, a.fget i = b.fget i) → a = b
| mk _, mk _, h => congrArg _ <| funext fun i => h i

@[simp]
theorem get_mk (i : Fin n) : DArray.get (.mk init) i = init i := rfl
theorem fget_mk (i : Fin n) : DArray.fget (.mk init) i = init i := rfl

theorem set_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) :
DArray.set (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl
theorem fset_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) :
DArray.fset (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl

@[simp]
theorem get_set (a : DArray n α) (i : Fin n) (v : α i) : (a.set i v).get i = v := by
simp only [DArray.get, DArray.set, dif_pos]
theorem fget_fset (a : DArray n α) (i : Fin n) (v : α i) : (a.fset i v).fget i = v := by
simp only [DArray.fget, DArray.fset, dif_pos]

theorem get_set_ne (a : DArray n α) (v : α i) (h : i ≠ j) : (a.set i v).get j = a.get j := by
simp only [DArray.get, DArray.set, dif_neg h]
theorem fget_fset_ne (a : DArray n α) (v : α i) (h : i ≠ j) : (a.fset i v).fget j = a.fget j := by
simp only [DArray.fget, DArray.fset, dif_neg h]; rfl

@[simp]
theorem set_set (a : DArray n α) (i : Fin n) (v w : α i) : (a.set i v).set i w = a.set i w := by
theorem fset_fset (a : DArray n α) (i : Fin n) (v w : α i) :
(a.fset i v).fset i w = a.fset i w := by
ext j
if h : i = j then
rw [← h, get_set, get_set]
rw [← h, fget_fset, fget_fset]
else
rw [get_set_ne _ _ h, get_set_ne _ _ h, get_set_ne _ _ h]
rw [fget_fset_ne _ _ h, fget_fset_ne _ _ h, fget_fset_ne _ _ h]

theorem get_modifyF [Functor f] [LawfulFunctor f] (a : DArray n α) (i : Fin n) (t : α i → f (α i)) :
(DArray.get . i) <$> a.modifyF i t = t (a.get i) := by
theorem fget_modifyF [Functor f] [LawfulFunctor f] (a : DArray n α) (i : Fin n)
(t : α i → f (α i)) : (DArray.fget . i) <$> a.modifyF i t = t (a.fget i) := by
simp [DArray.modifyF]

@[simp]
theorem get_modify (a : DArray n α) (i : Fin n) (t : α i → α i) :
(a.modify i t).get i = t (a.get i) := get_modifyF (f:=Id) a i t
theorem fget_modify (a : DArray n α) (i : Fin n) (t : α i → α i) :
(a.modify i t).fget i = t (a.fget i) := fget_modifyF (f:=Id) a i t

theorem get_modify_ne (a : DArray n α) (t : α i → α i) (h : i ≠ j) :
(a.modify i t).get j = a.get j := get_set_ne _ _ h
theorem fget_modify_ne (a : DArray n α) (t : α i → α i) (h : i ≠ j) :
(a.modify i t).fget j = a.fget j := fget_fset_ne _ _ h

@[simp]
theorem set_modify (a : DArray n α) (i : Fin n) (t : α i → α i) (v : α i) :
(a.set i v).modify i t = a.set i (t v) := by
(a.fset i v).modify i t = a.fset i (t v) := by
ext j
if h : i = j then
cases h; simp
else
simp [h, get_modify_ne, get_set_ne]
simp [h, fget_modify_ne, fget_fset_ne]

@[simp]
theorem uget_eq_get (a : DArray n α) (i : USize) (h : i.toNat < n) :
a.uget i h = a.get ⟨i.toNat, h⟩ := rfl
theorem uget_eq_fget (a : DArray n α) (i : USize) (h : i.toNat < n) :
a.uget i h = a.fget ⟨i.toNat, h⟩ := rfl

@[simp]
theorem uset_eq_set (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
a.uset i h v = a.set ⟨i.toNat, h⟩ v := rfl
theorem uset_eq_fset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
a.uset i h v = a.fset ⟨i.toNat, h⟩ v := rfl

@[simp]
theorem umodifyF_eq_modifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n)
Expand Down

0 comments on commit 8510de1

Please sign in to comment.