Skip to content

Commit

Permalink
feat: add uset
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais committed May 28, 2024
1 parent b581611 commit 3f815fe
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions Batteries/Data/DArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ private unsafe def ugetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) : α
private unsafe def setImpl (a : DArray n α) (i) (v : α i) : DArray n α :=
unsafeCast <| a.data.set ⟨i.val, lcProof⟩ <| unsafeCast v

private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
DArray n α := unsafeCast <| a.data.uset i (unsafeCast v) lcProof

private unsafe def copyImpl (a : DArray n α) : DArray n α :=
unsafeCast <| a.data.extract 0 n

Expand Down Expand Up @@ -84,6 +87,14 @@ attribute [implemented_by casesOnImpl] DArray.casesOn
protected def set (a : DArray n α) (i : Fin n) (v : α i) : DArray n α :=
mk fun j => if h : i = j then h ▸ v else a.get j

/--
Sets the `DArray` item at index `i : USize`.
Slightly faster than `set` and `O(1)` if exclusive else `O(n)`.
-/
@[implemented_by usetImpl]
protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :=
a.set ⟨i.toNat, h⟩ v

@[simp, inherit_doc DArray.set]
protected abbrev setN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) :=
a.set ⟨i, h⟩ v
Expand All @@ -102,10 +113,6 @@ theorem get_mk (i : Fin n) : DArray.get (.mk init) i = init i := rfl
theorem set_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) :
DArray.set (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl

@[simp]
theorem uget_eq_get (a : DArray n α) (i : USize) (h : i.toNat < n) :
a.uget i h = a.get ⟨i.toNat, h⟩ := rfl

@[simp]
theorem get_set (a : DArray n α) (i : Fin n) (v : α i) : (a.set i v).get i = v := by
simp only [DArray.get, DArray.set, dif_pos]
Expand All @@ -121,5 +128,13 @@ theorem set_set (a : DArray n α) (i : Fin n) (v w : α i) : (a.set i v).set i w
else
rw [get_set_ne _ _ h, get_set_ne _ _ h, get_set_ne _ _ h]

@[simp]
theorem uget_eq_get (a : DArray n α) (i : USize) (h : i.toNat < n) :
a.uget i h = a.get ⟨i.toNat, h⟩ := rfl

@[simp]
theorem uset_eq_set (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
a.uset i h v = a.set ⟨i.toNat, h⟩ v := rfl

@[simp]
theorem copy_eq (a : DArray n α) : a.copy = a := rfl

0 comments on commit 3f815fe

Please sign in to comment.