Skip to content

Commit

Permalink
Merge pull request #21 from opencompl/upstream-div-alt
Browse files Browse the repository at this point in the history
refactor: bundle `wn` and `wr` into DivModState
  • Loading branch information
alexkeizer authored Sep 24, 2024
2 parents 69aef1d + 2ba45c9 commit c9c274e
Showing 1 changed file with 131 additions and 87 deletions.
218 changes: 131 additions & 87 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,24 @@ theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d)

/-- `DivModState` is a structure that maintains the state of recursive `divrem` calls. -/
structure DivModState (w : Nat) : Type where
/-- The number of bits in the numerator that are not yet processed -/
wn : Nat
/-- The number of bits in the remainder (and quotient) -/
wr : Nat
/-- The current quotient. -/
q : BitVec w
/-- The current remainder. -/
r : BitVec w


/-- `DivModArgs` contains the arguments to a `divrem` call which remain constant throughout
execution -/
structure DivModArgs (w : Nat) where
/-- the numerator (aka, dividend) -/
n : BitVec w
/-- the denumerator (aka, divisor)-/
d : BitVec w

/-- A `DivModState` is lawful if the remainder width `wr` plus the numerator width `wn` equals `w`,
and the bitvectors `r` and `n` have values in the bounds given by bitwidths `wr`, resp. `wn`.
Expand All @@ -531,72 +544,79 @@ the values are within their respective bounds.
We start with `wn = w` and `wr = 0`, and then in each step, we decrement `wn` and increment `wr`.
In this way, we grow a legal remainder in each loop iteration.
-/
structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w)
(qr : DivModState w) : Prop where
structure DivModState.Lawful {w : Nat} (args : DivModArgs w) (qr : DivModState w) : Prop where
/-- The sum of widths of the dividend and remainder is `w`. -/
hwrn : wr + wn = w
hwrn : qr.wr + qr.wn = w
/-- The denominator is positive. -/
hdPos : 0 < d
hdPos : 0 < args.d
/-- The remainder is strictly less than the denominator. -/
hrLtDivisor : qr.r.toNat < d.toNat
hrLtDivisor : qr.r.toNat < args.d.toNat
/-- The remainder is morally a `Bitvec wr`, and so has value less than `2^wr`. -/
hrWidth : qr.r.toNat < 2^wr
hrWidth : qr.r.toNat < 2^qr.wr
/-- The quotient is morally a `Bitvec wr`, and so has value less than `2^wr`. -/
hqWidth : qr.q.toNat < 2^wr
hqWidth : qr.q.toNat < 2^qr.wr
/-- The low `(w - wn)` bits of `n` obey the invariant for division. -/
hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat
hdiv : args.n.toNat >>> qr.wn = args.d.toNat * qr.q.toNat + qr.r.toNat

/-- A lawful DivModState implies `w > 0`. -/
def DivModState.Lawful.hw {qr : DivModState w}
{h : DivModState.Lawful w wr wn n d qr} : 0 < w := by
def DivModState.Lawful.hw {args : DivModArgs w} {qr : DivModState w}
{h : DivModState.Lawful args qr} : 0 < w := by
have hd := h.hdPos
rcases w with rfl | w
· have hcontra : d = 0#0 := by apply Subsingleton.elim
· have hcontra : args.d = 0#0 := by apply Subsingleton.elim
rw [hcontra] at hd
simp at hd
· omega

/-- An initial value with both `q, r = 0`. -/
def DivModState.init (w : Nat) : DivModState w := {
wn := w
wr := 0
q := 0#w
r := 0#w
}

/-- The initial state is lawful. -/
def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) :
DivModState.Lawful w 0 w n d (DivModState.init w) := {
hwrn := by omega,
hdPos := by assumption
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption
hrWidth := by simp [DivModState.init],
hqWidth := by simp [DivModState.init],
hdiv := by
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero];
rw [Nat.shiftRight_eq_div_pow]
apply Nat.div_eq_of_lt n.isLt
}
def DivModState.lawful_init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) :
DivModState.Lawful args (DivModState.init w) := by
simp only [BitVec.DivModState.init]
exact {
hwrn := by simp only; omega,
hdPos := by assumption
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption
hrWidth := by simp [DivModState.init],
hqWidth := by simp [DivModState.init],
hdiv := by
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero];
rw [Nat.shiftRight_eq_div_pow]
apply Nat.div_eq_of_lt args.n.isLt
}

/--
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the
quotient has been correctly computed.
-/
theorem DivModState.udiv_eq_of_lawful_zero {qr : DivModState w}
(h : DivModState.Lawful w w 0 n d qr) :
theorem DivModState.udiv_eq_of_lawful {n d : BitVec w} {qr : DivModState w}
(h_lawful : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.udiv d = qr.q := by
apply udiv_eq_of_mul_add_toNat h.hdPos h.hrLtDivisor
have hdiv := h.hdiv
apply udiv_eq_of_mul_add_toNat h_lawful.hdPos h_lawful.hrLtDivisor
have hdiv := h_lawful.hdiv
simp only [h_final] at *
omega

/--
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the
remainder has been correctly computed.
-/
theorem DivModState.umod_eq_of_lawful_zero {qr : DivModState w}
(h : DivModState.Lawful w w 0 n d qr) :
theorem DivModState.umod_eq_of_lawful {qr : DivModState w}
(h : DivModState.Lawful {n, d} qr)
(h_final : qr.wn = 0) :
n.umod d = qr.r := by
apply umod_eq_of_mul_add_toNat h.hrLtDivisor
have hdiv := h.hdiv
simp only [shiftRight_zero] at hdiv
simp only [h_final] at *
exact hdiv.symm

/-! ### DivModState.Poised -/
Expand All @@ -608,52 +628,51 @@ one numerator bit left to process `(0 < wn)`
The input to the shift subtractor is a legal input to `divrem`, and we also need to have an
input bit to perform shift subtraction on, and thus we need `0 < wn`.
-/
structure DivModState.Poised (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w)
extends DivModState.Lawful w wr wn n d qr : Type where
structure DivModState.Poised {w : Nat} (args : DivModArgs w) (qr : DivModState w)
extends DivModState.Lawful args qr : Type where
/-- Only perform a round of shift-subtract if we have dividend bits. -/
hwn_lt : 0 < wn
hwn_lt : 0 < qr.wn

/--
In the shift subtract input, the dividend is at least one bit long (`wn > 0`), so
the remainder has bits to be computed (`wr < w`).
-/
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised wr wn n d) : wr < w := by
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised args) : qr.wr < w := by
have hwrn := h.hwrn
have hwn_lt := h.hwn_lt
omega

/-- If we have extra bits to spare in `n`,
then we know the div mod state is poised to run another round of the shift subtractor. -/
def DivModState.Lawful.toPoised {qr : DivModState w}
(h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d :=
{ h with hwn_lt := by omega }

/-! ### Division shift subtractor -/

/--
One round of the division algorithm, that tries to perform a subtract shift.
Note that this should only be called when `r.msb = false`, so we will not overflow.
-/
def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : DivModState w) :
DivModState w :=
let r' := shiftConcat qr.r (n.getLsbD (wn - 1))
def divSubtractShift (args : DivModArgs w) (qr : DivModState w) : DivModState w :=
let {n, d} := args
let wn := qr.wn - 1
let wr := qr.wr + 1
let r' := shiftConcat qr.r (n.getLsbD wn)
if r' < d
then {
q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit.
r := r'
wn, wr
} else {
q := qr.q.shiftConcat true, -- Otherwise, `r' ≥ d`, and we have a quotient bit.
r := r' - d -- we subtract to maintain the invariant that `r < d`.
wn, wr
}

/-- The value of shifting right by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/
theorem DivModState.toNat_shiftRight_sub_one_eq
(qr : DivModState w) (h : qr.Poised wr wn n d):
n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsbD (wn - 1)).toNat := by
show BitVec.toNat (n >>> (wn - 1)) = _
{args : DivModArgs w} {qr : DivModState w} (h : qr.Poised args) :
args.n.toNat >>> (qr.wn - 1)
= (args.n.toNat >>> qr.wn) * 2 + (args.n.getLsbD (qr.wn - 1)).toNat := by
show BitVec.toNat (args.n >>> (qr.wn - 1)) = _
have {..} := h -- break the structure down for `omega`
rw [shiftRight_sub_one_eq_shiftConcat n h.hwn_lt]
rw [toNat_shiftConcat_eq_of_lt (k := w - wn)]
rw [shiftRight_sub_one_eq_shiftConcat args.n h.hwn_lt]
rw [toNat_shiftConcat_eq_of_lt (k := w - qr.wn)]
· simp
· omega
· apply BitVec.toNat_ushiftRight_lt
Expand All @@ -671,13 +690,14 @@ private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) :

/-- We show that the output of `divSubtractShift` is lawful, which tells us that it
obeys the division equation. -/
def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by
theorem lawful_divSubtractShift (qr : DivModState w) (h : qr.Poised args) :
DivModState.Lawful args (divSubtractShift args qr) := by
rcases args with ⟨n, d⟩
simp only [divSubtractShift, decide_eq_true_eq]
-- We add these hypotheses for `omega` to find them later.
have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h
have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc]
by_cases rltd : shiftConcat qr.r (n.getLsbD (wn - 1)) < d
by_cases rltd : shiftConcat qr.r (n.getLsbD (qr.wn - 1)) < d
· simp only [rltd, ↓reduceIte]
constructor <;> try bv_omega
case pos.hrWidth => apply toNat_shiftConcat_lt_of_lt <;> omega
Expand All @@ -697,9 +717,9 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega
case neg.hrWidth =>
simp only
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (wn - 1))) :=
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (qr.wn - 1))) :=
BitVec.le_iff_not_lt.mp rltd
have hr' : ((qr.r.shiftConcat (n.getLsbD (wn - 1)))).toNat < 2 ^ (wr + 1) := by
have hr' : ((qr.r.shiftConcat (n.getLsbD (qr.wn - 1)))).toNat < 2 ^ (qr.wr + 1) := by
apply toNat_shiftConcat_lt_of_lt <;> bv_omega
rw [BitVec.toNat_sub_of_le hdr']
omega
Expand All @@ -718,61 +738,85 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) :
/-! ### Core division algorithm circuit -/

/-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/
def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) :
def divRec {w : Nat} (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
DivModState w :=
match wn with
match m with
| 0 => qr
| wn + 1 =>
divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) qr
| m + 1 => divRec m args <| divSubtractShift args qr

@[simp]
theorem divRec_zero (qr : DivModState w) :
divRec w w 0 n d qr = qr := rfl
divRec 0 args qr = qr := rfl

@[simp]
theorem divRec_succ (wn : Nat) (qr : DivModState w) :
divRec w wr (wn + 1) n d qr =
divRec w (wr + 1) wn n d
(divSubtractShift n d (wn + 1) qr) := rfl

theorem divRec_correct {n d : BitVec w} (qr : DivModState w)
(h : DivModState.Lawful w wr wn n d qr) :
DivModState.Lawful w w 0 n d (divRec w wr wn n d qr) := by
induction wn generalizing wr qr
theorem divRec_succ (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
divRec (m + 1) args qr =
divRec m args (divSubtractShift args qr) := rfl

/-- The output of `divRec` is a lawful state -/
theorem lawful_divRec {args : DivModArgs w} {qr : DivModState w}
(h : DivModState.Lawful args qr) :
DivModState.Lawful args (divRec qr.wn args qr) := by
generalize hm : qr.wn = m
induction m generalizing qr
case zero =>
exact h
case succ wn' ih =>
simp only [divRec_succ]
apply ih
· apply lawful_divSubtractShift
constructor
· assumption
· omega
· simp only [divSubtractShift, hm]
split <;> rfl

/-- The output of `divRec` has no more bits left to process (i.e., `wn = 0`) -/
@[simp]
theorem wn_divRec (args : DivModArgs w) (qr : DivModState w) :
(divRec qr.wn args qr).wn = 0 := by
generalize hm : qr.wn = m
induction m generalizing qr
case zero =>
unfold divRec
simp [← h.hwrn, h]
assumption
case succ wn' ih =>
simp only [divRec]
apply ih
apply divSubtractShiftProof (w := w)
(wr := wr)
(wn := wn' + 1)
exact h.toPoised
simp only [divSubtractShift, hm]
split <;> rfl

/-- The result of `udiv` agrees with the result of the division recurrence. -/
theorem udiv_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
let out := divRec w {n, d} (DivModState.init w)
n.udiv d = out.q := by
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.udiv_eq_of_lawful_zero this
have := DivModState.lawful_init {n, d} hd
have := lawful_divRec this
apply DivModState.udiv_eq_of_lawful this (wn_divRec ..)

/-- The result of `umod` agrees with the result of the division recurrence. -/
theorem umod_eq_divRec (hd : 0#w < d) :
let out := divRec w 0 w n d (DivModState.init w)
let out := divRec w {n, d} (DivModState.init w)
n.umod d = out.r := by
have := DivModState.Lawful.init w n d hd
have := divRec_correct (DivModState.init w) this
apply DivModState.umod_eq_of_lawful_zero this
have := DivModState.lawful_init {n, d} hd
have := lawful_divRec this
apply DivModState.umod_eq_of_lawful this (wn_divRec ..)

@[simp]
theorem divRec_succ' (wn : Nat) (qr : DivModState w) :
divRec w wr (wn + 1) n d qr =
let r' := shiftConcat qr.r (n.getLsbD wn)
let input : DivModState w :=
if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩
divRec w (wr + 1) wn n d input := by
theorem divRec_succ' (m : Nat) (args : DivModArgs w) (qr : DivModState w) :
divRec (m+1) args qr =
let wn := qr.wn - 1
let wr := qr.wr + 1
let r' := shiftConcat qr.r (args.n.getLsbD wn)
let input : DivModState _ :=
if r' < args.d then {
q := qr.q.shiftConcat false,
r := r'
wn, wr
} else {
q := qr.q.shiftConcat true,
r := r' - args.d
wn, wr
}
divRec m args input := by
simp [divRec_succ, divSubtractShift]

/- ### Arithmetic shift right (sshiftRight) recurrence -/
Expand Down

0 comments on commit c9c274e

Please sign in to comment.