diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 868be948d4bc..948b3f0b03e4 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -514,11 +514,24 @@ theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d) /-- `DivModState` is a structure that maintains the state of recursive `divrem` calls. -/ structure DivModState (w : Nat) : Type where + /-- The number of bits in the numerator that are not yet processed -/ + wn : Nat + /-- The number of bits in the remainder (and quotient) -/ + wr : Nat /-- The current quotient. -/ q : BitVec w /-- The current remainder. -/ r : BitVec w + +/-- `DivModArgs` contains the arguments to a `divrem` call which remain constant throughout +execution -/ +structure DivModArgs (w : Nat) where + /-- the numerator (aka, dividend) -/ + n : BitVec w + /-- the denumerator (aka, divisor)-/ + d : BitVec w + /-- A `DivModState` is lawful if the remainder width `wr` plus the numerator width `wn` equals `w`, and the bitvectors `r` and `n` have values in the bounds given by bitwidths `wr`, resp. `wn`. @@ -531,72 +544,79 @@ the values are within their respective bounds. We start with `wn = w` and `wr = 0`, and then in each step, we decrement `wn` and increment `wr`. In this way, we grow a legal remainder in each loop iteration. -/ -structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w) - (qr : DivModState w) : Prop where +structure DivModState.Lawful {w : Nat} (args : DivModArgs w) (qr : DivModState w) : Prop where /-- The sum of widths of the dividend and remainder is `w`. -/ - hwrn : wr + wn = w + hwrn : qr.wr + qr.wn = w /-- The denominator is positive. -/ - hdPos : 0 < d + hdPos : 0 < args.d /-- The remainder is strictly less than the denominator. -/ - hrLtDivisor : qr.r.toNat < d.toNat + hrLtDivisor : qr.r.toNat < args.d.toNat /-- The remainder is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ - hrWidth : qr.r.toNat < 2^wr + hrWidth : qr.r.toNat < 2^qr.wr /-- The quotient is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ - hqWidth : qr.q.toNat < 2^wr + hqWidth : qr.q.toNat < 2^qr.wr /-- The low `(w - wn)` bits of `n` obey the invariant for division. -/ - hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat + hdiv : args.n.toNat >>> qr.wn = args.d.toNat * qr.q.toNat + qr.r.toNat /-- A lawful DivModState implies `w > 0`. -/ -def DivModState.Lawful.hw {qr : DivModState w} - {h : DivModState.Lawful w wr wn n d qr} : 0 < w := by +def DivModState.Lawful.hw {args : DivModArgs w} {qr : DivModState w} + {h : DivModState.Lawful args qr} : 0 < w := by have hd := h.hdPos rcases w with rfl | w - · have hcontra : d = 0#0 := by apply Subsingleton.elim + · have hcontra : args.d = 0#0 := by apply Subsingleton.elim rw [hcontra] at hd simp at hd · omega /-- An initial value with both `q, r = 0`. -/ def DivModState.init (w : Nat) : DivModState w := { + wn := w + wr := 0 q := 0#w r := 0#w } /-- The initial state is lawful. -/ -def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) : - DivModState.Lawful w 0 w n d (DivModState.init w) := { - hwrn := by omega, - hdPos := by assumption - hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption - hrWidth := by simp [DivModState.init], - hqWidth := by simp [DivModState.init], - hdiv := by - simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; - rw [Nat.shiftRight_eq_div_pow] - apply Nat.div_eq_of_lt n.isLt -} +def DivModState.lawful_init {w : Nat} (args : DivModArgs w) (hd : 0#w < args.d) : + DivModState.Lawful args (DivModState.init w) := by + simp only [BitVec.DivModState.init] + exact { + hwrn := by simp only; omega, + hdPos := by assumption + hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption + hrWidth := by simp [DivModState.init], + hqWidth := by simp [DivModState.init], + hdiv := by + simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; + rw [Nat.shiftRight_eq_div_pow] + apply Nat.div_eq_of_lt args.n.isLt + } /-- A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the quotient has been correctly computed. -/ -theorem DivModState.udiv_eq_of_lawful_zero {qr : DivModState w} - (h : DivModState.Lawful w w 0 n d qr) : +theorem DivModState.udiv_eq_of_lawful {n d : BitVec w} {qr : DivModState w} + (h_lawful : DivModState.Lawful {n, d} qr) + (h_final : qr.wn = 0) : n.udiv d = qr.q := by - apply udiv_eq_of_mul_add_toNat h.hdPos h.hrLtDivisor - have hdiv := h.hdiv + apply udiv_eq_of_mul_add_toNat h_lawful.hdPos h_lawful.hrLtDivisor + have hdiv := h_lawful.hdiv + simp only [h_final] at * omega /-- A lawful DivModState with a fully consumed dividend (`wn = 0`) witnesses that the remainder has been correctly computed. -/ -theorem DivModState.umod_eq_of_lawful_zero {qr : DivModState w} - (h : DivModState.Lawful w w 0 n d qr) : +theorem DivModState.umod_eq_of_lawful {qr : DivModState w} + (h : DivModState.Lawful {n, d} qr) + (h_final : qr.wn = 0) : n.umod d = qr.r := by apply umod_eq_of_mul_add_toNat h.hrLtDivisor have hdiv := h.hdiv simp only [shiftRight_zero] at hdiv + simp only [h_final] at * exact hdiv.symm /-! ### DivModState.Poised -/ @@ -608,52 +628,51 @@ one numerator bit left to process `(0 < wn)` The input to the shift subtractor is a legal input to `divrem`, and we also need to have an input bit to perform shift subtraction on, and thus we need `0 < wn`. -/ -structure DivModState.Poised (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) - extends DivModState.Lawful w wr wn n d qr : Type where +structure DivModState.Poised {w : Nat} (args : DivModArgs w) (qr : DivModState w) + extends DivModState.Lawful args qr : Type where /-- Only perform a round of shift-subtract if we have dividend bits. -/ - hwn_lt : 0 < wn + hwn_lt : 0 < qr.wn /-- In the shift subtract input, the dividend is at least one bit long (`wn > 0`), so the remainder has bits to be computed (`wr < w`). -/ -def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised wr wn n d) : wr < w := by +def DivModState.wr_lt_w {qr : DivModState w} (h : qr.Poised args) : qr.wr < w := by have hwrn := h.hwrn have hwn_lt := h.hwn_lt omega -/-- If we have extra bits to spare in `n`, -then we know the div mod state is poised to run another round of the shift subtractor. -/ -def DivModState.Lawful.toPoised {qr : DivModState w} - (h : qr.Lawful w wr (wn + 1) n d) : qr.Poised wr (wn + 1) n d := - { h with hwn_lt := by omega } - /-! ### Division shift subtractor -/ /-- One round of the division algorithm, that tries to perform a subtract shift. Note that this should only be called when `r.msb = false`, so we will not overflow. -/ -def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : DivModState w) : - DivModState w := - let r' := shiftConcat qr.r (n.getLsbD (wn - 1)) +def divSubtractShift (args : DivModArgs w) (qr : DivModState w) : DivModState w := + let {n, d} := args + let wn := qr.wn - 1 + let wr := qr.wr + 1 + let r' := shiftConcat qr.r (n.getLsbD wn) if r' < d then { q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit. r := r' + wn, wr } else { q := qr.q.shiftConcat true, -- Otherwise, `r' ≥ d`, and we have a quotient bit. r := r' - d -- we subtract to maintain the invariant that `r < d`. + wn, wr } /-- The value of shifting right by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/ theorem DivModState.toNat_shiftRight_sub_one_eq - (qr : DivModState w) (h : qr.Poised wr wn n d): - n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsbD (wn - 1)).toNat := by - show BitVec.toNat (n >>> (wn - 1)) = _ + {args : DivModArgs w} {qr : DivModState w} (h : qr.Poised args) : + args.n.toNat >>> (qr.wn - 1) + = (args.n.toNat >>> qr.wn) * 2 + (args.n.getLsbD (qr.wn - 1)).toNat := by + show BitVec.toNat (args.n >>> (qr.wn - 1)) = _ have {..} := h -- break the structure down for `omega` - rw [shiftRight_sub_one_eq_shiftConcat n h.hwn_lt] - rw [toNat_shiftConcat_eq_of_lt (k := w - wn)] + rw [shiftRight_sub_one_eq_shiftConcat args.n h.hwn_lt] + rw [toNat_shiftConcat_eq_of_lt (k := w - qr.wn)] · simp · omega · apply BitVec.toNat_ushiftRight_lt @@ -671,13 +690,14 @@ private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) : /-- We show that the output of `divSubtractShift` is lawful, which tells us that it obeys the division equation. -/ -def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : - DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by +theorem lawful_divSubtractShift (qr : DivModState w) (h : qr.Poised args) : + DivModState.Lawful args (divSubtractShift args qr) := by + rcases args with ⟨n, d⟩ simp only [divSubtractShift, decide_eq_true_eq] -- We add these hypotheses for `omega` to find them later. have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc] - by_cases rltd : shiftConcat qr.r (n.getLsbD (wn - 1)) < d + by_cases rltd : shiftConcat qr.r (n.getLsbD (qr.wn - 1)) < d · simp only [rltd, ↓reduceIte] constructor <;> try bv_omega case pos.hrWidth => apply toNat_shiftConcat_lt_of_lt <;> omega @@ -697,9 +717,9 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega case neg.hrWidth => simp only - have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (wn - 1))) := + have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (qr.wn - 1))) := BitVec.le_iff_not_lt.mp rltd - have hr' : ((qr.r.shiftConcat (n.getLsbD (wn - 1)))).toNat < 2 ^ (wr + 1) := by + have hr' : ((qr.r.shiftConcat (n.getLsbD (qr.wn - 1)))).toNat < 2 ^ (qr.wr + 1) := by apply toNat_shiftConcat_lt_of_lt <;> bv_omega rw [BitVec.toNat_sub_of_le hdr'] omega @@ -718,61 +738,85 @@ def divSubtractShiftProof (qr : DivModState w) (h : qr.Poised wr wn n d) : /-! ### Core division algorithm circuit -/ /-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/ -def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) : +def divRec {w : Nat} (m : Nat) (args : DivModArgs w) (qr : DivModState w) : DivModState w := - match wn with + match m with | 0 => qr - | wn + 1 => - divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) qr + | m + 1 => divRec m args <| divSubtractShift args qr @[simp] theorem divRec_zero (qr : DivModState w) : - divRec w w 0 n d qr = qr := rfl + divRec 0 args qr = qr := rfl @[simp] -theorem divRec_succ (wn : Nat) (qr : DivModState w) : - divRec w wr (wn + 1) n d qr = - divRec w (wr + 1) wn n d - (divSubtractShift n d (wn + 1) qr) := rfl - -theorem divRec_correct {n d : BitVec w} (qr : DivModState w) - (h : DivModState.Lawful w wr wn n d qr) : - DivModState.Lawful w w 0 n d (divRec w wr wn n d qr) := by - induction wn generalizing wr qr +theorem divRec_succ (m : Nat) (args : DivModArgs w) (qr : DivModState w) : + divRec (m + 1) args qr = + divRec m args (divSubtractShift args qr) := rfl + +/-- The output of `divRec` is a lawful state -/ +theorem lawful_divRec {args : DivModArgs w} {qr : DivModState w} + (h : DivModState.Lawful args qr) : + DivModState.Lawful args (divRec qr.wn args qr) := by + generalize hm : qr.wn = m + induction m generalizing qr + case zero => + exact h + case succ wn' ih => + simp only [divRec_succ] + apply ih + · apply lawful_divSubtractShift + constructor + · assumption + · omega + · simp only [divSubtractShift, hm] + split <;> rfl + +/-- The output of `divRec` has no more bits left to process (i.e., `wn = 0`) -/ +@[simp] +theorem wn_divRec (args : DivModArgs w) (qr : DivModState w) : + (divRec qr.wn args qr).wn = 0 := by + generalize hm : qr.wn = m + induction m generalizing qr case zero => - unfold divRec - simp [← h.hwrn, h] + assumption case succ wn' ih => - simp only [divRec] apply ih - apply divSubtractShiftProof (w := w) - (wr := wr) - (wn := wn' + 1) - exact h.toPoised + simp only [divSubtractShift, hm] + split <;> rfl /-- The result of `udiv` agrees with the result of the division recurrence. -/ theorem udiv_eq_divRec (hd : 0#w < d) : - let out := divRec w 0 w n d (DivModState.init w) + let out := divRec w {n, d} (DivModState.init w) n.udiv d = out.q := by - have := DivModState.Lawful.init w n d hd - have := divRec_correct (DivModState.init w) this - apply DivModState.udiv_eq_of_lawful_zero this + have := DivModState.lawful_init {n, d} hd + have := lawful_divRec this + apply DivModState.udiv_eq_of_lawful this (wn_divRec ..) /-- The result of `umod` agrees with the result of the division recurrence. -/ theorem umod_eq_divRec (hd : 0#w < d) : - let out := divRec w 0 w n d (DivModState.init w) + let out := divRec w {n, d} (DivModState.init w) n.umod d = out.r := by - have := DivModState.Lawful.init w n d hd - have := divRec_correct (DivModState.init w) this - apply DivModState.umod_eq_of_lawful_zero this + have := DivModState.lawful_init {n, d} hd + have := lawful_divRec this + apply DivModState.umod_eq_of_lawful this (wn_divRec ..) @[simp] -theorem divRec_succ' (wn : Nat) (qr : DivModState w) : - divRec w wr (wn + 1) n d qr = - let r' := shiftConcat qr.r (n.getLsbD wn) - let input : DivModState w := - if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩ - divRec w (wr + 1) wn n d input := by +theorem divRec_succ' (m : Nat) (args : DivModArgs w) (qr : DivModState w) : + divRec (m+1) args qr = + let wn := qr.wn - 1 + let wr := qr.wr + 1 + let r' := shiftConcat qr.r (args.n.getLsbD wn) + let input : DivModState _ := + if r' < args.d then { + q := qr.q.shiftConcat false, + r := r' + wn, wr + } else { + q := qr.q.shiftConcat true, + r := r' - args.d + wn, wr + } + divRec m args input := by simp [divRec_succ, divSubtractShift] /- ### Arithmetic shift right (sshiftRight) recurrence -/