diff --git a/Batteries.lean b/Batteries.lean index 1efcac66e7..93caa6bbe5 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -19,6 +19,7 @@ import Batteries.Data.BinomialHeap import Batteries.Data.ByteArray import Batteries.Data.ByteSubarray import Batteries.Data.Char +import Batteries.Data.DArray import Batteries.Data.DList import Batteries.Data.Fin import Batteries.Data.FloatArray diff --git a/Batteries/Data/DArray.lean b/Batteries/Data/DArray.lean new file mode 100644 index 0000000000..65bab05079 --- /dev/null +++ b/Batteries/Data/DArray.lean @@ -0,0 +1,2 @@ +import Batteries.Data.DArray.Basic +import Batteries.Data.DArray.Lemmas diff --git a/Batteries/Data/DArray/Basic.lean b/Batteries/Data/DArray/Basic.lean new file mode 100644 index 0000000000..087b2d7030 --- /dev/null +++ b/Batteries/Data/DArray/Basic.lean @@ -0,0 +1,157 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +namespace Batteries + +/-! +# Dependent Arrays + +`DArray` is a heterogenous array where the type of each item depends on the index. The model +for this type is the dependent function type `(i : Fin n) → α i` where `α i` is the type assigned +to items at index `i`. + +The implementation of `DArray` is based on Lean's dynamic array type. This means that the array +values are stored in a contiguous memory region and can be accessed in constant time. Lean's arrays +also support destructive updates when the array is exclusive (RC=1). + +### Implementation Details + +Lean's array API does not directly support dependent arrays. Each `DArray n α` is internally stored +as an `Array NonScalar` with length `n`. This is sound since Lean's array implementation does not +record nor use the type of the items stored in the array. So it is safe to use `UnsafeCast` to +convert array items to the appropriate type when necessary. +-/ + +/-- `DArray` is a heterogenous array where the type of each item depends on the index. -/ +-- TODO: Use a structure once [lean4#2292](https://github.com/leanprover/lean4/pull/2292) is fixed. +inductive DArray (n) (α : Fin n → Type _) where + /-- Makes a new `DArray` with given item values. `O(n*g)` where `get i` is `O(g)`. -/ + | mk (fget : (i : Fin n) → α i) + +namespace DArray + +section unsafe_implementation + +private unsafe abbrev data : DArray n α → Array NonScalar := unsafeCast + +private unsafe def mkImpl (get : (i : Fin n) → α i) : DArray n α := + unsafeCast <| Array.ofFn fun i => (unsafeCast (get i) : NonScalar) + +private unsafe def fgetImpl (a : DArray n α) (i) : α i := + unsafeCast <| a.data.get i.val + +private unsafe def ugetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + unsafeCast <| a.data.uget i lcProof + +private unsafe def fsetImpl (a : DArray n α) (i) (v : α i) : DArray n α := + unsafeCast <| a.data.set i (unsafeCast v) lcProof + +private unsafe def usetImpl (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + DArray n α := unsafeCast <| a.data.uset i (unsafeCast v) lcProof + +private unsafe def modifyFImpl [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := + let v := unsafeCast <| a.data.get i + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.set i (unsafeCast ()) lcProof + fsetImpl a i <$> t v + +private unsafe def umodifyFImpl [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := + let v := unsafeCast <| a.data.uget i lcProof + -- Make sure `v` is unshared, if possible, by replacing its array entry by `box(0)`. + let a := unsafeCast <| a.data.uset i (unsafeCast ()) lcProof + usetImpl a i h <$> t v + +private unsafe def pushImpl (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + unsafeCast <| a.data.push <| unsafeCast v + +private unsafe def popImpl (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + unsafeCast <| a.data.pop + +private unsafe def copyImpl (a : DArray n α) : DArray n α := + unsafeCast <| a.data.extract 0 n + +end unsafe_implementation + +attribute [implemented_by mkImpl] DArray.mk + +instance (α : Fin n → Type _) [(i : Fin n) → Inhabited (α i)] : Inhabited (DArray n α) where + default := mk fun _ => default + +/-- Gets the `DArray` item at index `i`. `O(1)`. -/ +@[implemented_by fgetImpl] +protected def fget : DArray n α → (i : Fin n) → α i + | mk fget => fget + +@[inherit_doc DArray.fget, inline] +protected def get (a : DArray n α) (i) (h : i < n := by get_elem_tactic) : α ⟨i, h⟩ := + a.fget ⟨i, h⟩ + +/-- Gets the `DArray` item at index `i : USize`. Slightly faster than `get`; `O(1)`. -/ +@[implemented_by ugetImpl] +protected def uget (a : DArray n α) (i : USize) (h : i.toNat < n) : α ⟨i.toNat, h⟩ := + a.fget ⟨i.toNat, h⟩ + +private def casesOnImpl.{u} {motive : DArray n α → Sort u} (a : DArray n α) + (h : (fget : (i : Fin n) → α i) → motive (.mk fget)) : motive a := + h a.fget + +attribute [implemented_by casesOnImpl] DArray.casesOn + +/-- Sets the `DArray` item at index `i`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by fsetImpl] +protected def fset (a : DArray n α) (i : Fin n) (v : α i) : DArray n α := + mk fun j => if h : i = j then h ▸ v else a.get j + +/-- +Sets the `DArray` item at index `i : USize`. +Slightly faster than `set` and `O(1)` if exclusive else `O(n)`. +-/ +@[implemented_by usetImpl] +protected def uset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) := + a.fset ⟨i.toNat, h⟩ v + +@[simp, inherit_doc DArray.fset] +protected abbrev set (a : DArray n α) (i) (h : i < n := by get_elem_tactic) (v : α ⟨i, h⟩) := + a.fset ⟨i, h⟩ v + +/-- Modifies the `DArray` item at index `i` using transform `t` and the functor `f`. -/ +@[implemented_by modifyFImpl] +protected def modifyF [Functor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : f (DArray n α) := a.fset i <$> t (a.fget i) + +/-- Modifies the `DArray` item at index `i` using transform `t`. -/ +@[inline] +protected def modify (a : DArray n α) (i : Fin n) (t : α i → α i) : DArray n α := + a.modifyF (f:=Id) i t + +/-- Modifies the `DArray` item at index `i : USize` using transform `t` and the functor `f`. -/ +@[implemented_by umodifyFImpl] +protected def umodifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : f (DArray n α) := a.uset i h <$> t (a.uget i h) + +/-- Modifies the `DArray` item at index `i : USize` using transform `t`. -/ +@[inline] +protected def umodify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : DArray n α := + a.umodifyF (f:=Id) i h t + +/-- Copies the `DArray` to an exclusive `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by copyImpl] +protected def copy (a : DArray n α) : DArray n α := mk a.fget + +/-- Push an element onto the end of a `DArray`. `O(1)` if exclusive else `O(n)`. -/ +@[implemented_by pushImpl] +protected def push (a : DArray n α) (v : β) : + DArray (n+1) fun i => if h : i.val < n then α ⟨i.val, h⟩ else β := + mk fun i => if h : i.val < n then dif_pos h ▸ a.fget ⟨i.val, h⟩ else dif_neg h ▸ v + +/-- Delete the last item of a `DArray`. `O(1)`. -/ +@[implemented_by popImpl] +protected def pop (a : DArray (n+1) α) : DArray n fun i => α i.castSucc := + mk fun i => a.get i.castSucc diff --git a/Batteries/Data/DArray/Lemmas.lean b/Batteries/Data/DArray/Lemmas.lean new file mode 100644 index 0000000000..5e907b1b2f --- /dev/null +++ b/Batteries/Data/DArray/Lemmas.lean @@ -0,0 +1,74 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +import Batteries.Data.DArray.Basic + +namespace Batteries.DArray + +@[ext] +protected theorem ext : {a b : DArray n α} → (∀ i, a.fget i = b.fget i) → a = b + | mk _, mk _, h => congrArg _ <| funext fun i => h i + +@[simp] +theorem fget_mk (i : Fin n) : DArray.fget (.mk init) i = init i := rfl + +theorem fset_mk {α : Fin n → Type _} {init : (i : Fin n) → α i} (i : Fin n) (v : α i) : + DArray.fset (.mk init) i v = .mk fun j => if h : i = j then h ▸ v else init j := rfl + +@[simp] +theorem fget_fset (a : DArray n α) (i : Fin n) (v : α i) : (a.fset i v).fget i = v := by + simp only [DArray.fget, DArray.fset, dif_pos] + +theorem fget_fset_ne (a : DArray n α) (v : α i) (h : i ≠ j) : (a.fset i v).fget j = a.fget j := by + simp only [DArray.fget, DArray.fset, dif_neg h]; rfl + +@[simp] +theorem fset_fset (a : DArray n α) (i : Fin n) (v w : α i) : + (a.fset i v).fset i w = a.fset i w := by + ext j + if h : i = j then + rw [← h, fget_fset, fget_fset] + else + rw [fget_fset_ne _ _ h, fget_fset_ne _ _ h, fget_fset_ne _ _ h] + +theorem fget_modifyF [Functor f] [LawfulFunctor f] (a : DArray n α) (i : Fin n) + (t : α i → f (α i)) : (DArray.fget . i) <$> a.modifyF i t = t (a.fget i) := by + simp [DArray.modifyF] + +@[simp] +theorem fget_modify (a : DArray n α) (i : Fin n) (t : α i → α i) : + (a.modify i t).fget i = t (a.fget i) := fget_modifyF (f:=Id) a i t + +theorem fget_modify_ne (a : DArray n α) (t : α i → α i) (h : i ≠ j) : + (a.modify i t).fget j = a.fget j := fget_fset_ne _ _ h + +@[simp] +theorem set_modify (a : DArray n α) (i : Fin n) (t : α i → α i) (v : α i) : + (a.fset i v).modify i t = a.fset i (t v) := by + ext j + if h : i = j then + cases h; simp + else + simp [h, fget_modify_ne, fget_fset_ne] + +@[simp] +theorem uget_eq_fget (a : DArray n α) (i : USize) (h : i.toNat < n) : + a.uget i h = a.fget ⟨i.toNat, h⟩ := rfl + +@[simp] +theorem uset_eq_fset (a : DArray n α) (i : USize) (h : i.toNat < n) (v : α ⟨i.toNat, h⟩) : + a.uset i h v = a.fset ⟨i.toNat, h⟩ v := rfl + +@[simp] +theorem umodifyF_eq_modifyF [Functor f] (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → f (α ⟨i.toNat, h⟩)) : a.umodifyF i h t = a.modifyF ⟨i.toNat, h⟩ t := rfl + +@[simp] +theorem umodify_eq_modify (a : DArray n α) (i : USize) (h : i.toNat < n) + (t : α ⟨i.toNat, h⟩ → α ⟨i.toNat, h⟩) : a.umodify i h t = a.modify ⟨i.toNat, h⟩ t := rfl + +@[simp] +theorem copy_eq (a : DArray n α) : a.copy = a := rfl diff --git a/test/darray.lean b/test/darray.lean new file mode 100644 index 0000000000..0af1a42520 --- /dev/null +++ b/test/darray.lean @@ -0,0 +1,28 @@ +import Batteries.Data.DArray + +open Batteries + +def foo : DArray 3 fun | 0 => String | 1 => Nat | 2 => Array Nat := + .mk fun | 0 => "foo" | 1 => 42 | 2 => #[4, 2] + +def bar := foo.set 0 "bar" + +#guard foo.get 0 == "foo" +#guard foo.get 1 == 42 +#guard foo.get 2 == #[4, 2] + +#guard (foo.set 1 1).get 0 == "foo" +#guard (foo.set 1 1).get 1 == 1 +#guard (foo.set 1 1).get 2 == #[4, 2] + +#guard bar.get 0 == "bar" +#guard (bar.set 0 (foo.get 0)).get 0 == "foo" +#guard ((bar.set 0 "baz").set 1 1).get 0 == "baz" +#guard ((bar.set 0 "baz").set 0 "foo").get 0 == "foo" +#guard ((bar.set 0 "foo").set 0 "baz").get 0 == "baz" + +def Batteries.DArray.head : DArray (n+1) α → α 0 + | mk f => f 0 + +#guard foo.head == "foo" +#guard bar.head == "bar"