forked from leanprover/lean4
-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: bundle wn
and wr
into DivModState
#21
Merged
Merged
Changes from 1 commit
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
decf9df
refactor: bundle `wn` and `wr` into DivModState
alexkeizer df8887e
remove dead code
alexkeizer 2839738
Update src/Init/Data/BitVec/Bitblast.lean
alexkeizer 84dfd6a
rename lawful_init
alexkeizer 2ba45c9
Merge branch 'upstream-div-alt' of https://github.com/opencompl/lean4…
alexkeizer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -514,11 +514,24 @@ theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d) | |
|
||
/-- `DivModState` is a structure that maintains the state of recursive `divrem` calls. -/ | ||
structure DivModState (w : Nat) : Type where | ||
/-- The number of bits in the numerator that are not yet processed -/ | ||
wn : Nat | ||
/-- The number of bits in the remainder (and quotient) -/ | ||
wr : Nat | ||
/-- The current quotient. -/ | ||
q : BitVec w | ||
/-- The current remainder. -/ | ||
r : BitVec w | ||
|
||
|
||
/-- `DivModArgs` contains the arguments to a `divrem` call which remain constant throughout | ||
execution -/ | ||
structure DivModArgs (w : Nat) where | ||
/-- the numerator (aka, dividend) -/ | ||
n : BitVec w | ||
/-- the denumerator (aka, divisor)-/ | ||
d : BitVec w | ||
|
||
/-- A `DivModState` is lawful if the remainder width `wr` plus the numerator width `wn` equals `w`, | ||
and the bitvectors `r` and `n` have values in the bounds given by bitwidths `wr`, resp. `wn`. | ||
|
||
|
@@ -531,129 +544,145 @@ the values are within their respective bounds. | |
We start with `wn = w` and `wr = 0`, and then in each step, we decrement `wn` and increment `wr`. | ||
In this way, we grow a legal remainder in each loop iteration. | ||
-/ | ||
structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w) | ||
(qr : DivModState w) : Prop where | ||
structure DivModState.Lawful {w : Nat} (args : DivModArgs w) (qr : DivModState w) : Prop where | ||
/-- The sum of widths of the dividend and remainder is `w`. -/ | ||
hwrn : wr + wn = w | ||
hwrn : qr.wr + qr.wn = w | ||
/-- The denominator is positive. -/ | ||
hdPos : 0 < d | ||
hdPos : 0 < args.d | ||
/-- The remainder is strictly less than the denominator. -/ | ||
hrLtDivisor : qr.r.toNat < d.toNat | ||
hrLtDivisor : qr.r.toNat < args.d.toNat | ||
/-- The remainder is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ | ||
hrWidth : qr.r.toNat < 2^wr | ||
hrWidth : qr.r.toNat < 2^qr.wr | ||
/-- The quotient is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ | ||
hqWidth : qr.q.toNat < 2^wr | ||
hqWidth : qr.q.toNat < 2^qr.wr | ||
/-- The low `(w - wn)` bits of `n` obey the invariant for division. -/ | ||
hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat | ||
hdiv : args.n.toNat >>> qr.wn = args.d.toNat * qr.q.toNat + qr.r.toNat | ||
|
||
/-- A lawful DivModState implies `w > 0`. -/ | ||
def DivModState.Lawful.hw {qr : DivModState w} | ||
{h : DivModState.Lawful w wr wn n d qr} : 0 < w := by | ||
def DivModState.Lawful.hw {args : DivModArgs w} {qr : DivModState w} | ||
{h : DivModState.Lawful args qr} : 0 < w := by | ||
have hd := h.hdPos | ||
rcases w with rfl | w | ||
· have hcontra : d = 0#0 := by apply Subsingleton.elim | ||
· have hcontra : args.d = 0#0 := by apply Subsingleton.elim | ||
rw [hcontra] at hd | ||
simp at hd | ||
· omega | ||
|
||
/-- An initial value with both `q, r = 0`. -/ | ||
def DivModState.init (w : Nat) : DivModState w := { | ||
wn := w | ||
wr := 0 | ||
q := 0#w | ||
r := 0#w | ||
} | ||
|
||
/-- The initial state is lawful. -/ | ||
def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) : | ||
DivModState.Lawful w 0 w n d (DivModState.init w) := { | ||
hwrn := by omega, | ||
hdPos := by assumption | ||
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption | ||
hrWidth := by simp [DivModState.init], | ||
hqWidth := by simp [DivModState.init], | ||
hdiv := by | ||
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; | ||
rw [Nat.shiftRight_eq_div_pow] | ||
apply Nat.div_eq_of_lt n.isLt | ||
} | ||
def DivModState.Lawful.init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) : | ||
DivModState.Lawful args (DivModState.init w) := by | ||
simp only [BitVec.DivModState.init] | ||
exact { | ||
hwrn := by simp only; omega, | ||
hdPos := by assumption | ||
hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption | ||
hrWidth := by simp [DivModState.init], | ||
hqWidth := by simp [DivModState.init], | ||
hdiv := by | ||
simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; | ||
rw [Nat.shiftRight_eq_div_pow] | ||
apply Nat.div_eq_of_lt args.n.isLt | ||
} | ||
|
||
/-- | ||
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the | ||
quotient has been correctly computed. | ||
-/ | ||
theorem DivModState.udiv_eq_of_lawful_zero {qr : DivModState w} | ||
(h : DivModState.Lawful w w 0 n d qr) : | ||
theorem DivModState.udiv_eq_of_lawful {n d : BitVec w} {qr : DivModState w} | ||
(h_lawful : DivModState.Lawful {n, d} qr) | ||
(h_final : qr.wn = 0) : | ||
n.udiv d = qr.q := by | ||
apply udiv_eq_of_mul_add_toNat h.hdPos h.hrLtDivisor | ||
have hdiv := h.hdiv | ||
apply udiv_eq_of_mul_add_toNat h_lawful.hdPos h_lawful.hrLtDivisor | ||
have hdiv := h_lawful.hdiv | ||
simp only [h_final] at * | ||
omega | ||
|
||
/-- | ||
A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the | ||
remainder has been correctly computed. | ||
-/ | ||
theorem DivModState.umod_eq_of_lawful_zero {qr : DivModState w} | ||
(h : DivModState.Lawful w w 0 n d qr) : | ||
theorem DivModState.umod_eq_of_lawful {qr : DivModState w} | ||
(h : DivModState.Lawful {n, d} qr) | ||
(h_final : qr.wn = 0) : | ||
n.umod d = qr.r := by | ||
apply umod_eq_of_mul_add_toNat h.hrLtDivisor | ||
have hdiv := h.hdiv | ||
simp only [shiftRight_zero] at hdiv | ||
simp only [h_final] at * | ||
exact hdiv.symm | ||
|
||
/-! ### DivModState.Poised -/ | ||
|
||
/- | ||
TODO: Can we redefine `qr.Poised` as simply `¬qr.IsFinal`? | ||
That would mean we have extra `Lawful` assumptions in some spots, but in others we'll have less, | ||
as we won't have to carry around the `args` -/ | ||
/-- | ||
A `Poised` DivModState is a state which is `Lawful` and furthermore, has at least | ||
one numerator bit left to process `(0 < wn)` | ||
|
||
The input to the shift subtractor is a legal input to `divrem`, and we also need to have an | ||
input bit to perform shift subtraction on, and thus we need `0 < wn`. | ||
-/ | ||
structure DivModState.Poised (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) | ||
extends DivModState.Lawful w wr wn n d qr : Type where | ||
structure DivModState.Poised {w : Nat} (args : DivModArgs w) (qr : DivModState w) | ||
extends DivModState.Lawful args qr : Type where | ||
/-- Only perform a round of shift-subtract if we have dividend bits. -/ | ||
hwn_lt : 0 < wn | ||
hwn_lt : 0 < qr.wn | ||
|
||
/-- | ||
In the shift subtract input, the dividend is at least one bit long (`wn > 0`), so | ||
the remainder has bits to be computed (`wr < w`). | ||
-/ | ||
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised wr wn n d) : wr < w := by | ||
def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised args) : qr.wr < w := by | ||
have hwrn := h.hwrn | ||
have hwn_lt := h.hwn_lt | ||
omega | ||
|
||
/-- If we have extra bits to spare in `n`, | ||
then we know the div mod state is poised to run another round of the shift subtractor. -/ | ||
def DivModState.Lawful.toPoised {qr : DivModState w} | ||
(h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d := | ||
{ h with hwn_lt := by omega } | ||
-- /-- If we have extra bits to spare in `n`, | ||
-- then we know the div mod state is poised to run another round of the shift subtractor. -/ | ||
-- def DivModState.Lawful.toPoised {qr : DivModState w} | ||
-- (h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d := | ||
-- { h with hwn_lt := by omega } | ||
|
||
/-! ### Division shift subtractor -/ | ||
|
||
/-- | ||
One round of the division algorithm, that tries to perform a subtract shift. | ||
Note that this should only be called when `r.msb = false`, so we will not overflow. | ||
-/ | ||
def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : DivModState w) : | ||
DivModState w := | ||
let r' := shiftConcat qr.r (n.getLsbD (wn - 1)) | ||
def divSubtractShift (args : DivModArgs w) (qr : DivModState w) : DivModState w := | ||
let {n, d} := args | ||
let wn := qr.wn - 1 | ||
let wr := qr.wr + 1 | ||
let r' := shiftConcat qr.r (n.getLsbD wn) | ||
if r' < d | ||
then { | ||
q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit. | ||
r := r' | ||
wn, wr | ||
} else { | ||
q := qr.q.shiftConcat true, -- Otherwise, `r' ≥ d`, and we have a quotient bit. | ||
r := r' - d -- we subtract to maintain the invariant that `r < d`. | ||
wn, wr | ||
} | ||
|
||
/-- The value of shifting right by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/ | ||
theorem DivModState.toNat_shiftRight_sub_one_eq | ||
(qr : DivModState w) (h : qr.Poised wr wn n d): | ||
n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsbD (wn - 1)).toNat := by | ||
show BitVec.toNat (n >>> (wn - 1)) = _ | ||
{args : DivModArgs w} {qr : DivModState w} (h : qr.Poised args) : | ||
args.n.toNat >>> (qr.wn - 1) | ||
= (args.n.toNat >>> qr.wn) * 2 + (args.n.getLsbD (qr.wn - 1)).toNat := by | ||
show BitVec.toNat (args.n >>> (qr.wn - 1)) = _ | ||
have {..} := h -- break the structure down for `omega` | ||
rw [shiftRight_sub_one_eq_shiftConcat n h.hwn_lt] | ||
rw [toNat_shiftConcat_eq_of_lt (k := w - wn)] | ||
rw [shiftRight_sub_one_eq_shiftConcat args.n h.hwn_lt] | ||
rw [toNat_shiftConcat_eq_of_lt (k := w - qr.wn)] | ||
· simp | ||
· omega | ||
· apply BitVec.toNat_ushiftRight_lt | ||
|
@@ -671,13 +700,14 @@ private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) : | |
|
||
/-- We show that the output of `divSubtractShift` is lawful, which tells us that it | ||
obeys the division equation. -/ | ||
def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : | ||
DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by | ||
def lawful_divSubtractShift (qr : DivModState w) (h : qr.Poised args) : | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a fan of
alexkeizer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DivModState.Lawful args (divSubtractShift args qr) := by | ||
rcases args with ⟨n, d⟩ | ||
simp only [divSubtractShift, decide_eq_true_eq] | ||
-- We add these hypotheses for `omega` to find them later. | ||
have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h | ||
have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc] | ||
by_cases rltd : shiftConcat qr.r (n.getLsbD (wn - 1)) < d | ||
by_cases rltd : shiftConcat qr.r (n.getLsbD (qr.wn - 1)) < d | ||
· simp only [rltd, ↓reduceIte] | ||
constructor <;> try bv_omega | ||
case pos.hrWidth => apply toNat_shiftConcat_lt_of_lt <;> omega | ||
|
@@ -697,9 +727,9 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : | |
apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega | ||
case neg.hrWidth => | ||
simp only | ||
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (wn - 1))) := | ||
have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (qr.wn - 1))) := | ||
BitVec.le_iff_not_lt.mp rltd | ||
have hr' : ((qr.r.shiftConcat (n.getLsbD (wn - 1)))).toNat < 2 ^ (wr + 1) := by | ||
have hr' : ((qr.r.shiftConcat (n.getLsbD (qr.wn - 1)))).toNat < 2 ^ (qr.wr + 1) := by | ||
apply toNat_shiftConcat_lt_of_lt <;> bv_omega | ||
rw [BitVec.toNat_sub_of_le hdr'] | ||
omega | ||
|
@@ -718,61 +748,85 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : | |
/-! ### Core division algorithm circuit -/ | ||
|
||
/-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/ | ||
def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) : | ||
def divRec {w : Nat} (m : Nat) (args : DivModArgs w) (qr : DivModState w) : | ||
DivModState w := | ||
match wn with | ||
match m with | ||
| 0 => qr | ||
| wn + 1 => | ||
divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) qr | ||
| m + 1 => divRec m args <| divSubtractShift args qr | ||
|
||
@[simp] | ||
theorem divRec_zero (qr : DivModState w) : | ||
divRec w w 0 n d qr = qr := rfl | ||
divRec 0 args qr = qr := rfl | ||
|
||
@[simp] | ||
theorem divRec_succ (wn : Nat) (qr : DivModState w) : | ||
divRec w wr (wn + 1) n d qr = | ||
divRec w (wr + 1) wn n d | ||
(divSubtractShift n d (wn + 1) qr) := rfl | ||
|
||
theorem divRec_correct {n d : BitVec w} (qr : DivModState w) | ||
(h : DivModState.Lawful w wr wn n d qr) : | ||
DivModState.Lawful w w 0 n d (divRec w wr wn n d qr) := by | ||
induction wn generalizing wr qr | ||
theorem divRec_succ (m : Nat) (args : DivModArgs w) (qr : DivModState w) : | ||
divRec (m + 1) args qr = | ||
divRec m args (divSubtractShift args qr) := rfl | ||
|
||
/-- The output of `divRec` is a lawful state -/ | ||
theorem lawful_divRec {args : DivModArgs w} {qr : DivModState w} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I decided to rename this from the somewhat ambiguous "correct" to "lawful" |
||
(h : DivModState.Lawful args qr) : | ||
DivModState.Lawful args (divRec qr.wn args qr) := by | ||
generalize hm : qr.wn = m | ||
induction m generalizing qr | ||
case zero => | ||
exact h | ||
case succ wn' ih => | ||
simp only [divRec_succ] | ||
apply ih | ||
· apply lawful_divSubtractShift | ||
constructor | ||
· assumption | ||
· omega | ||
· simp only [divSubtractShift, hm] | ||
split <;> rfl | ||
|
||
/-- The output of `divRec` has no more bits left to process (i.e., `wn = 0`) -/ | ||
@[simp] | ||
theorem wn_divRec (args : DivModArgs w) (qr : DivModState w) : | ||
(divRec qr.wn args qr).wn = 0 := by | ||
generalize hm : qr.wn = m | ||
induction m generalizing qr | ||
case zero => | ||
unfold divRec | ||
simp [← h.hwrn, h] | ||
assumption | ||
case succ wn' ih => | ||
simp only [divRec] | ||
apply ih | ||
apply divSubtractShiftProof (w := w) | ||
(wr := wr) | ||
(wn := wn' + 1) | ||
exact h.toPoised | ||
simp only [divSubtractShift, hm] | ||
split <;> rfl | ||
|
||
/-- The result of `udiv` agrees with the result of the division recurrence. -/ | ||
theorem udiv_eq_divRec (hd : 0#w < d) : | ||
let out := divRec w 0 w n d (DivModState.init w) | ||
let out := divRec w {n, d} (DivModState.init w) | ||
n.udiv d = out.q := by | ||
have := DivModState.Lawful.init w n d hd | ||
have := divRec_correct (DivModState.init w) this | ||
apply DivModState.udiv_eq_of_lawful_zero this | ||
have := DivModState.Lawful.init {n, d} hd | ||
have := lawful_divRec this | ||
apply DivModState.udiv_eq_of_lawful this (wn_divRec ..) | ||
|
||
/-- The result of `umod` agrees with the result of the division recurrence. -/ | ||
theorem umod_eq_divRec (hd : 0#w < d) : | ||
let out := divRec w 0 w n d (DivModState.init w) | ||
let out := divRec w {n, d} (DivModState.init w) | ||
n.umod d = out.r := by | ||
have := DivModState.Lawful.init w n d hd | ||
have := divRec_correct (DivModState.init w) this | ||
apply DivModState.umod_eq_of_lawful_zero this | ||
have := DivModState.Lawful.init {n, d} hd | ||
have := lawful_divRec this | ||
apply DivModState.umod_eq_of_lawful this (wn_divRec ..) | ||
|
||
@[simp] | ||
theorem divRec_succ' (wn : Nat) (qr : DivModState w) : | ||
divRec w wr (wn + 1) n d qr = | ||
let r' := shiftConcat qr.r (n.getLsbD wn) | ||
let input : DivModState w := | ||
if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩ | ||
divRec w (wr + 1) wn n d input := by | ||
theorem divRec_succ' (m : Nat) (args : DivModArgs w) (qr : DivModState w) : | ||
divRec (m+1) args qr = | ||
let wn := qr.wn - 1 | ||
let wr := qr.wr + 1 | ||
let r' := shiftConcat qr.r (args.n.getLsbD wn) | ||
let input : DivModState _ := | ||
if r' < args.d then { | ||
q := qr.q.shiftConcat false, | ||
r := r' | ||
wn, wr | ||
} else { | ||
q := qr.q.shiftConcat true, | ||
r := r' - args.d | ||
wn, wr | ||
} | ||
divRec m args input := by | ||
simp [divRec_succ, divSubtractShift] | ||
|
||
/- ### Arithmetic shift right (sshiftRight) recurrence -/ | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.