Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: bundle wn and wr into DivModState #21

Merged
merged 5 commits into from
Sep 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 140 additions & 86 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,24 @@ theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d)

/-- `DivModState` is a structure that maintains the state of recursive `divrem` calls. -/
structure DivModState (w : Nat) : Type where
/-- The number of bits in the numerator that are not yet processed -/
wn : Nat
/-- The number of bits in the remainder (and quotient) -/
wr : Nat
/-- The current quotient. -/
q : BitVec w
/-- The current remainder. -/
r : BitVec w


/-- `DivModArgs` contains the arguments to a `divrem` call which remain constant throughout
execution -/
structure DivModArgs (w : Nat) where
/-- the numerator (aka, dividend) -/
n : BitVec w
/-- the denumerator (aka, divisor)-/
d : BitVec w

/-- A `DivModState` is lawful if the remainder width `wr` plus the numerator width `wn` equals `w`,
and the bitvectors `r` and `n` have values in the bounds given by bitwidths `wr`, resp. `wn`.

Expand All @@ -531,129 +544,145 @@ the values are within their respective bounds.
We start with `wn = w` and `wr = 0`, and then in each step, we decrement `wn` and increment `wr`.
In this way, we grow a legal remainder in each loop iteration.
-/
structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w)
(qr : DivModState w) : Prop where
structure DivModState.Lawful {w : Nat} (args : DivModArgs w) (qr : DivModState w) : Prop where
/-- The sum of widths of the dividend and remainder is `w`. -/
hwrn : wr + wn = w
hwrn : qr.wr + qr.wn = w
/-- The denominator is positive. -/
hdPos : 0 < d
hdPos : 0 < args.d
/-- The remainder is strictly less than the denominator. -/
hrLtDivisor : qr.r.toNat < d.toNat
hrLtDivisor : qr.r.toNat < args.d.toNat
/-- The remainder is morally a `Bitvec wr`, and so has value less than `2^wr`. -/
hrWidth : qr.r.toNat < 2^wr
hrWidth : qr.r.toNat < 2^qr.wr
/-- The quotient is morally a `Bitvec wr`, and so has value less than `2^wr`. -/
hqWidth : qr.q.toNat < 2^wr
hqWidth : qr.q.toNat < 2^qr.wr
/-- The low `(w - wn)` bits of `n` obey the invariant for division. -/
hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat
hdiv : args.n.toNat >>> qr.wn = args.d.toNat * qr.q.toNat + qr.r.toNat

/-- A lawful DivModState implies `w > 0`. -/
def DivModState.Lawful.hw {qr : DivModState w}
{h : DivModState.Lawful w wr wn n d qr} : 0 < w := by
def DivModState.Lawful.hw {args : DivModArgs w} {qr : DivModState w}
{h : DivModState.Lawful args qr} : 0 < w := by
have hd := h.hdPos
rcases w with rfl | w
· have hcontra : d = 0#0 := by apply Subsingleton.elim
· have hcontra : args.d = 0#0 := by apply Subsingleton.elim
rw [hcontra] at hd
simp at hd
· omega

/-- An initial value with both `q, r = 0`. -/
def DivModState.init (w : Nat) : DivModState w := {
wn := w
wr := 0
q := 0#w
r := 0#w
}

/-- The initial state is lawful. -/
def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) :
DivModState.Lawful w 0 w n d (DivModState.init w) := {
hwrn := by omega,
hdPos := by assumption
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption
hrWidth := by simp [DivModState.init],
hqWidth := by simp [DivModState.init],
hdiv := by
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero];
rw [Nat.shiftRight_eq_div_pow]
apply Nat.div_eq_of_lt n.isLt
}
def DivModState.Lawful.init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) :
Copy link
Author

@alexkeizer alexkeizer Sep 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def DivModState.Lawful.init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) :
def DivModState.lawful_init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) :

DivModState.Lawful args (DivModState.init w) := by
simp only [BitVec.DivModState.init]
exact {
hwrn := by simp only; omega,
hdPos := by assumption
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption
hrWidth := by simp [DivModState.init],
hqWidth := by simp [DivModState.init],
hdiv := by
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero];
rw [Nat.shiftRight_eq_div_pow]
apply Nat.div_eq_of_lt args.n.isLt
}

/--
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the
quotient has been correctly computed.
-/
theorem DivModState.udiv_eq_of_lawful_zero {qr : DivModState w}
(h : DivModState.Lawful w w 0 n d qr) :
theorem DivModState.udiv_eq_of_lawful {n d : BitVec w} {qr : DivModState w}
(h_lawful : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.udiv d = qr.q := by
apply udiv_eq_of_mul_add_toNat h.hdPos h.hrLtDivisor
have hdiv := h.hdiv
apply udiv_eq_of_mul_add_toNat h_lawful.hdPos h_lawful.hrLtDivisor
have hdiv := h_lawful.hdiv
simp only [h_final] at *
omega

/--
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the
remainder has been correctly computed.
-/
theorem DivModState.umod_eq_of_lawful_zero {qr : DivModState w}
(h : DivModState.Lawful w w 0 n d qr) :
theorem DivModState.umod_eq_of_lawful {qr : DivModState w}
(h : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.umod d = qr.r := by
apply umod_eq_of_mul_add_toNat h.hrLtDivisor
have hdiv := h.hdiv
simp only [shiftRight_zero] at hdiv
simp only [h_final] at *
exact hdiv.symm

/-! ### DivModState.Poised -/

/-
TODO: Can we redefine `qr.Poised` as simply `¬qr.IsFinal`?
That would mean we have extra `Lawful` assumptions in some spots, but in others we'll have less,
as we won't have to carry around the `args` -/
/--
A `Poised` DivModState is a state which is `Lawful` and furthermore, has at least
one numerator bit left to process `(0 < wn)`

The input to the shift subtractor is a legal input to `divrem`, and we also need to have an
input bit to perform shift subtraction on, and thus we need `0 < wn`.
-/
structure DivModState.Poised (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w)
extends DivModState.Lawful w wr wn n d qr : Type where
structure DivModState.Poised {w : Nat} (args : DivModArgs w) (qr : DivModState w)
extends DivModState.Lawful args qr : Type where
/-- Only perform a round of shift-subtract if we have dividend bits. -/
hwn_lt : 0 < wn
hwn_lt : 0 < qr.wn

/--
In the shift subtract input, the dividend is at least one bit long (`wn > 0`), so
the remainder has bits to be computed (`wr < w`).
-/
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised wr wn n d) : wr < w := by
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised args) : qr.wr < w := by
have hwrn := h.hwrn
have hwn_lt := h.hwn_lt
omega

/-- If we have extra bits to spare in `n`,
then we know the div mod state is poised to run another round of the shift subtractor. -/
def DivModState.Lawful.toPoised {qr : DivModState w}
(h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d :=
{ h with hwn_lt := by omega }
-- /-- If we have extra bits to spare in `n`,
-- then we know the div mod state is poised to run another round of the shift subtractor. -/
-- def DivModState.Lawful.toPoised {qr : DivModState w}
-- (h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d :=
-- { h with hwn_lt := by omega }

/-! ### Division shift subtractor -/

/--
One round of the division algorithm, that tries to perform a subtract shift.
Note that this should only be called when `r.msb = false`, so we will not overflow.
-/
def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : DivModState w) :
DivModState w :=
let r' := shiftConcat qr.r (n.getLsbD (wn - 1))
def divSubtractShift (args : DivModArgs w) (qr : DivModState w) : DivModState w :=
let {n, d} := args
let wn := qr.wn - 1
let wr := qr.wr + 1
let r' := shiftConcat qr.r (n.getLsbD wn)
if r' < d
then {
q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit.
r := r'
wn, wr
} else {
q := qr.q.shiftConcat true, -- Otherwise, `r' ≥ d`, and we have a quotient bit.
r := r' - d -- we subtract to maintain the invariant that `r < d`.
wn, wr
}

/-- The value of shifting right by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/
theorem DivModState.toNat_shiftRight_sub_one_eq
(qr : DivModState w) (h : qr.Poised wr wn n d):
n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsbD (wn - 1)).toNat := by
show BitVec.toNat (n >>> (wn - 1)) = _
{args : DivModArgs w} {qr : DivModState w} (h : qr.Poised args) :
args.n.toNat >>> (qr.wn - 1)
= (args.n.toNat >>> qr.wn) * 2 + (args.n.getLsbD (qr.wn - 1)).toNat := by
show BitVec.toNat (args.n >>> (qr.wn - 1)) = _
have {..} := h -- break the structure down for `omega`
rw [shiftRight_sub_one_eq_shiftConcat n h.hwn_lt]
rw [toNat_shiftConcat_eq_of_lt (k := w - wn)]
rw [shiftRight_sub_one_eq_shiftConcat args.n h.hwn_lt]
rw [toNat_shiftConcat_eq_of_lt (k := w - qr.wn)]
· simp
· omega
· apply BitVec.toNat_ushiftRight_lt
Expand All @@ -671,13 +700,14 @@ private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) :

/-- We show that the output of `divSubtractShift` is lawful, which tells us that it
obeys the division equation. -/
def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by
def lawful_divSubtractShift (qr : DivModState w) (h : qr.Poised args) :
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of divSubtractShiftProof as a name, so I renamed it to lawful_...

alexkeizer marked this conversation as resolved.
Show resolved Hide resolved
DivModState.Lawful args (divSubtractShift args qr) := by
rcases args with ⟨n, d⟩
simp only [divSubtractShift, decide_eq_true_eq]
-- We add these hypotheses for `omega` to find them later.
have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h
have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc]
by_cases rltd : shiftConcat qr.r (n.getLsbD (wn - 1)) < d
by_cases rltd : shiftConcat qr.r (n.getLsbD (qr.wn - 1)) < d
· simp only [rltd, ↓reduceIte]
constructor <;> try bv_omega
case pos.hrWidth => apply toNat_shiftConcat_lt_of_lt <;> omega
Expand All @@ -697,9 +727,9 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega
case neg.hrWidth =>
simp only
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (wn - 1))) :=
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (qr.wn - 1))) :=
BitVec.le_iff_not_lt.mp rltd
have hr' : ((qr.r.shiftConcat (n.getLsbD (wn - 1)))).toNat < 2 ^ (wr + 1) := by
have hr' : ((qr.r.shiftConcat (n.getLsbD (qr.wn - 1)))).toNat < 2 ^ (qr.wr + 1) := by
apply toNat_shiftConcat_lt_of_lt <;> bv_omega
rw [BitVec.toNat_sub_of_le hdr']
omega
Expand All @@ -718,61 +748,85 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
/-! ### Core division algorithm circuit -/

/-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/
def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) :
def divRec {w : Nat} (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
DivModState w :=
match wn with
match m with
| 0 => qr
| wn + 1 =>
divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) qr
| m + 1 => divRec m args <| divSubtractShift args qr

@[simp]
theorem divRec_zero (qr : DivModState w) :
divRec w w 0 n d qr = qr := rfl
divRec 0 args qr = qr := rfl

@[simp]
theorem divRec_succ (wn : Nat) (qr : DivModState w) :
divRec w wr (wn + 1) n d qr =
divRec w (wr + 1) wn n d
(divSubtractShift n d (wn + 1) qr) := rfl

theorem divRec_correct {n d : BitVec w} (qr : DivModState w)
(h : DivModState.Lawful w wr wn n d qr) :
DivModState.Lawful w w 0 n d (divRec w wr wn n d qr) := by
induction wn generalizing wr qr
theorem divRec_succ (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
divRec (m + 1) args qr =
divRec m args (divSubtractShift args qr) := rfl

/-- The output of `divRec` is a lawful state -/
theorem lawful_divRec {args : DivModArgs w} {qr : DivModState w}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to rename this from the somewhat ambiguous "correct" to "lawful"

(h : DivModState.Lawful args qr) :
DivModState.Lawful args (divRec qr.wn args qr) := by
generalize hm : qr.wn = m
induction m generalizing qr
case zero =>
exact h
case succ wn' ih =>
simp only [divRec_succ]
apply ih
· apply lawful_divSubtractShift
constructor
· assumption
· omega
· simp only [divSubtractShift, hm]
split <;> rfl

/-- The output of `divRec` has no more bits left to process (i.e., `wn = 0`) -/
@[simp]
theorem wn_divRec (args : DivModArgs w) (qr : DivModState w) :
(divRec qr.wn args qr).wn = 0 := by
generalize hm : qr.wn = m
induction m generalizing qr
case zero =>
unfold divRec
simp [← h.hwrn, h]
assumption
case succ wn' ih =>
simp only [divRec]
apply ih
apply divSubtractShiftProof (w := w)
(wr := wr)
(wn := wn' + 1)
exact h.toPoised
simp only [divSubtractShift, hm]
split <;> rfl

/-- The result of `udiv` agrees with the result of the division recurrence. -/
theorem udiv_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
let out := divRec w {n, d} (DivModState.init w)
n.udiv d = out.q := by
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.udiv_eq_of_lawful_zero this
have := DivModState.Lawful.init {n, d} hd
have := lawful_divRec this
apply DivModState.udiv_eq_of_lawful this (wn_divRec ..)

/-- The result of `umod` agrees with the result of the division recurrence. -/
theorem umod_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
let out := divRec w {n, d} (DivModState.init w)
n.umod d = out.r := by
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.umod_eq_of_lawful_zero this
have := DivModState.Lawful.init {n, d} hd
have := lawful_divRec this
apply DivModState.umod_eq_of_lawful this (wn_divRec ..)

@[simp]
theorem divRec_succ' (wn : Nat) (qr : DivModState w) :
divRec w wr (wn + 1) n d qr =
let r' := shiftConcat qr.r (n.getLsbD wn)
let input : DivModState w :=
if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩
divRec w (wr + 1) wn n d input := by
theorem divRec_succ' (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
divRec (m+1) args qr =
let wn := qr.wn - 1
let wr := qr.wr + 1
let r' := shiftConcat qr.r (args.n.getLsbD wn)
let input : DivModState _ :=
if r' < args.d then {
q := qr.q.shiftConcat false,
r := r'
wn, wr
} else {
q := qr.q.shiftConcat true,
r := r' - args.d
wn, wr
}
divRec m args input := by
simp [divRec_succ, divSubtractShift]

/- ### Arithmetic shift right (sshiftRight) recurrence -/
Expand Down
Loading