Skip to content

Commit

Permalink
feat: add modify and umodify
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais committed May 28, 2024
1 parent 3f815fe commit 10d0530
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions Batteries/Data/DArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ private unsafe def setImpl (a : DArray n α) (i) (v : α i) : DArray n α :=
private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) :
DArray n α := unsafeCast <| a.data.uset i (unsafeCast v) lcProof

private unsafe def modifyFImpl [Functor f] (a : DArray n α) (i : Fin n)
(t : α i → f (α i)) : f (DArray n α) :=
let v := unsafeCast <| a.data.get ⟨i.val, lcProof⟩
-- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`.
let a := unsafeCast <| a.data.set ⟨i.val, lcProof⟩ (unsafeCast ())
setImpl a i <$> t v

private unsafe def umodifyFImpl [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n)
(t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) :=
let v := unsafeCast <| a.data.uget i lcProof
-- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`.
let a := unsafeCast <| a.data.uset i (unsafeCast ()) lcProof
usetImpl a i h <$> t v

private unsafe def copyImpl (a : DArray n α) : DArray n α :=
unsafeCast <| a.data.extract 0 n

Expand Down Expand Up @@ -99,6 +113,25 @@ protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.
protected abbrev setN (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) :=
a.set ⟨i, h⟩ v

/-- Modifies the `DArray` item at index `i` using transform `t` and the functor `f`. -/
@[implemented_by modifyFImpl]
protected def modifyF [Functor f] (a : DArray n α) (i : Fin n)
(t : α i → f (α i)) : f (DArray n α) := a.set i <$> t (a.get i)

/-- Modifies the `DArray` item at index `i` using transform `t`. -/
protected abbrev modify (a : DArray n α) (i : Fin n) (t : α i → α i) : DArray n α :=
a.modifyF (f:=Id) i t

/-- Modifies the `DArray` item at index `i : USize` using transform `t` and the functor `f`. -/
@[implemented_by umodifyFImpl]
protected def umodifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n)
(t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := a.uset i h <$> t (a.uget i h)

/-- Modifies the `DArray` item at index `i : USize` using transform `t`. -/
protected abbrev umodify (a : DArray n α) (i : USize) (h : i.toNat < n)
(t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : DArray n α :=
a.umodifyF (f:=Id) i h t

/-- Copies the `DArray` to an exclusive `DArray`. `O(1)` if exclusive else `O(n)`. -/
@[implemented_by copyImpl]
protected def copy (a : DArray n α) : DArray n α := mk a.get
Expand Down

0 comments on commit 10d0530

Please sign in to comment.