From 750c6945c21fc917db1ac79b330656dbd41201a9 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 9 Sep 2024 02:43:57 -0500 Subject: [PATCH] feat: udiv/umod bitblasting for LeanSAT --- src/Init/Data/BitVec/Basic.lean | 7 + src/Init/Data/BitVec/Bitblast.lean | 388 +++++++++++++++++++++++++++++ src/Init/Data/BitVec/Lemmas.lean | 51 ++++ src/Init/Data/Bool.lean | 5 +- src/Init/Data/Nat/Lemmas.lean | 5 + 5 files changed, 454 insertions(+), 2 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 3662c0cf4872..dad2e46b0192 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -658,6 +658,13 @@ result of appending a single bit to the front in the naive implementation). That is, the new bit is the least significant bit. -/ def concat {n} (msbs : BitVec n) (lsb : Bool) : BitVec (n+1) := msbs ++ (ofBool lsb) +/-- +`x.shiftConcat b` shifts all bits of `x` to the left by `1` and sets the least significant bit to `b`. +It is a non-dependent version of `concat` that does not change the total bitwidth. +-/ +def shiftConcat (x : BitVec n) (b : Bool) : BitVec n := + x <<< 1 ||| (ofBool b).zeroExtend n + /-- Prepend a single bit to the front of a bitvector, using big endian order (see `append`). That is, the new bit is the most significant bit. -/ def cons {n} (msb : Bool) (lsbs : BitVec n) : BitVec (n+1) := diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index 8e0f3a81f170..acd6c9a369f2 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -430,6 +430,394 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) : · simp [of_length_zero] · simp [shiftLeftRec_eq] +/-! # udiv/urem recurrence for bitblasting + +In order to prove the correctness of the division algorithm on the integers, +one shows that `n.div d = q` and `n.mod d = r` iff `n = d * q + r` and `0 ≤ r < d`. +Mnemonic: `n` is the numerator, `d` is the denominator, `q` is the quotient, and `r` the remainder. + +This uniqueness property is not true for bitvectors. +Let us study an instructive counterexample: + +- Let `bitwidth = 3` +- Let `n = 0, d = 3` +- If we choose `q = 2, r = 2`, then `d * q + r = 6 + 2 = 8 ≃ 0 (mod 8)` and (`r < d`). +- But see that `q = 0, r = 0` also satisfies the constraints, as `0 * 3 + 0 = 0`. +- So for (`n = 0, d = 3`), both (a) `q = 2, r = 2` as well as (b) `q = 0, r = 0` are solutions! + +Such examples can be created by choosing `(q, r)` for a fixed `(d, n)` +such that `(d * q + r)` overflows and wraps around to equal `n`. + +This tells us that the division algorithm must have more restrictions that just the ones +we have for natural numbers. These restrictions are captured in `DivModState.Lawful`, +which captures the relationship necessary between `n, d, q, r`. The key idea is to state +the relationship in terms of the `{n, d, q, r}.toNat` values, which implies that the +relationship also holds at the bitvector level. + +References: +- Fast 32-bit Division on the DSP56800E: Minimized nonrestoring division algorithm by David Baca +- Bitwuzla sources for bitblasting.h +-/ + +private theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z) : + (x + y) / z = x / z := by + refine Nat.div_eq_of_lt_le ?lo ?hi + · apply Nat.le_trans + · exact div_mul_le_self x z + · omega + · simp only [succ_eq_add_one, Nat.add_mul, Nat.one_mul] + apply Nat.add_lt_add_of_le_of_lt + · apply Nat.le_of_eq + exact (Nat.div_eq_iff_eq_mul_left hz hx).mp rfl + · exact hy + +/-- If the division equation `d.toNat * q.toNat + r.toNat = n.toNat` holds, +then `n.udiv d = q`. -/ +theorem udiv_eq_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 < d) + (hrd : r < d) + (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : + n.udiv d = q := by + apply BitVec.eq_of_toNat_eq + rw [toNat_udiv] + replace hdqnr : (d.toNat * q.toNat + r.toNat) / d.toNat = n.toNat / d.toNat := by + simp [hdqnr] + rw [Nat.div_add_eq_left_of_lt] at hdqnr + · rw [← hdqnr] + exact mul_div_right q.toNat hd + · exact Nat.dvd_mul_right d.toNat q.toNat + · exact hrd + · exact hd + +/-- If the division equation `d.toNat * q.toNat + r.toNat = n.toNat` holds, +then `n.umod d = r` -/ +theorem umod_eq_of_mul_add_toNat {d n q r : BitVec w} (hrd : r < d) + (hdqnr : d.toNat * q.toNat + r.toNat = n.toNat) : + n.umod d = r := by + apply BitVec.eq_of_toNat_eq + rw [toNat_umod] + replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by + simp [hdqnr] + rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr + simp only [Nat.zero_add, mod_mod] at hdqnr + replace hrd : r.toNat < d.toNat := by + simpa [BitVec.lt_def] using hrd + rw [Nat.mod_eq_of_lt hrd] at hdqnr + simp [hdqnr] + +/-! ### DivModState -/ + +/-- Structure that maintains the state of recursive `divrem` calls. -/ +structure DivModState (w : Nat) : Type where + /-- The current quotient. -/ + q : BitVec w + /-- The current remainder. -/ + r : BitVec w + +/-- A `DivModState` is lawful if the remainder width `wr` plus the dividend width `wn` equals `w`, +and the bitvectors `r` and `n` have values in the bounds given by bitwidths `wr`, resp. `wn`. + +This is a proof engineering choice: An alternative world could have +`r : BitVec wr` and `n : BitVec wn`, but this required much more dependent typing coercions. + +Instead, we choose to declare all involved bitvectors as length `w`, and then prove that +the values are within their respective bounds. +-/ +structure DivModState.Lawful (w wr wn : Nat) (n d : BitVec w) + (qr : DivModState w) : Prop where + /-- The sum of widths of the dividend and remainder is `w`. -/ + hwrn : wr + wn = w + /-- The divisor is positive. -/ + hdPos : 0 < d + /-- The remainder is strictly less than the divisor. -/ + hrLtDivisor : qr.r.toNat < d.toNat + /-- The remainder is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ + hrWidth : qr.r.toNat < 2^wr + /-- The quotient is morally a `Bitvec wr`, and so has value less than `2^wr`. -/ + hqWidth : qr.q.toNat < 2^wr + /-- The low n bits of `n` obey the fundamental division equation. -/ + hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat + +/-- A lawful DivModState implies `w > 0`. -/ +def DivModState.Lawful.hw {qr : DivModState w} + {h : DivModState.Lawful w wr wn n d qr} : 0 < w := by + have hd := h.hdPos + rcases w with rfl | w + · have hcontra : d = 0#0 := by apply Subsingleton.elim + rw [hcontra] at hd + simp at hd + · omega + +/-- An initial value with both `q, r = 0`. -/ +def DivModState.init (w : Nat) : DivModState w := { + q := 0#w + r := 0#w +} + +/-- The initial state. -/ +def DivModState.Lawful.init (w : Nat) (n d : BitVec w) (hd : 0#w < d) : + DivModState.Lawful w 0 w n d (DivModState.init w) := { + hwrn := by omega, + hdPos := by assumption + hrLtDivisor := by simp [BitVec.lt_def] at hd ⊢; assumption + hrWidth := by simp [DivModState.init], + hqWidth := by simp [DivModState.init], + hdiv := by + simp only [DivModState.init, toNat_ofNat, zero_mod, Nat.mul_zero, Nat.add_zero]; + rw [Nat.shiftRight_eq_div_pow] + apply Nat.div_eq_of_lt n.isLt +} + +/-- +A lawful DivModState with a fully consumed dividend (`wn = 0`) witneses that the +quotient has been correctly computed. +-/ +theorem DivModState.udiv_eq_of_lawful_zero {qr : DivModState w} + (h : DivModState.Lawful w w 0 n d qr) : + n.udiv d = qr.q := by + apply udiv_eq_of_mul_add_toNat h.hdPos h.hrLtDivisor + have hdiv := h.hdiv + omega + +/-- +A lawful DivModState with a fully consumed dividend (`wn = 0`) witneses that the +remainder has been correctly computed. +-/ +theorem DivModState.umod_eq_of_lawful_zero {qr : DivModState w} + (h : DivModState.Lawful w w 0 n d qr) : + n.umod d = qr.r := by + apply umod_eq_of_mul_add_toNat h.hrLtDivisor + have hdiv := h.hdiv + simp only [shiftRight_zero] at hdiv + exact hdiv.symm + +/-! ### LawfulShiftSubtract -/ + +/-- +A `LawfulShiftSubtract` is a `Lawful` DivModState that is also a legal input to the shift subtractor. +So in particular, we must have at least one dividend bit left over `(0 < wn)` +to perform a round of shift subtraction. + +The input to the shift subtractor is a legal input to `divrem`, and we also need to have an +input bit to perform shift subtraction on, and thus we need `0 < wn`. +-/ +structure DivModState.LawfulShiftSubtract (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) + extends DivModState.Lawful w wr wn n d qr : Type where + /-- Only perform a round of shift-subtract if we have dividend bits. -/ + hwn_lt : 0 < wn + +/-- +In the shift subtract input, the dividend is at least one bit long (`wn > 0`), so +the remainder has bits to be computed (`wr < w`). +-/ +def DivModState.wr_lt_w {qr : DivModState w} (h : qr.LawfulShiftSubtract wr wn n d) : wr < w := by + have hwrn := h.hwrn + have hwn_lt := h.hwn_lt + omega + +/-- If we have extra bits to spare in `n`, +then the div rem input can be converted into a shift subtract input +to run a round of the shift subtracter. -/ +def DivModState.Lawful.toLawfulShiftSubtract {qr : DivModState w} + (h : qr.Lawful w wr (wn + 1) n d) : qr.LawfulShiftSubtract wr (wn + 1) n d where + hwrn := by have := h.hwrn; omega + hdPos := h.hdPos + hrLtDivisor := h.hrLtDivisor + hrWidth := h.hrWidth + hqWidth := h.hqWidth + hdiv := h.hdiv + hwn_lt := by omega + +/-! ### shiftConcat -/ + +/-- +If `n : Bitvec w` has only the low `k < w` bits set, +then `(n <<< 1 | b)` does not overflow. +-/ +private theorem mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two {n b k w : Nat} + (hn : n < 2 ^ k) (hb : b < 2) (hk : k < w) : + n * 2 + b < 2 ^ w := by + have : 2^(k + 1) ≤ 2 ^w := Nat.pow_le_pow_of_le_right (by decide) (by assumption) + omega + +@[simp, bv_toNat] +theorem toNat_shiftConcat {x : BitVec w} {b : Bool} : (x.shiftConcat b).toNat = + (x.toNat <<< 1 + b.toNat) % 2 ^ w := by + simp only [shiftConcat] + rw [← add_eq_or_of_and_eq_zero] -- Due to `add_eq_or_of_and_eq_zero`, this must live in `Bitblast`. + · simp + · ext i + simp + omega + +/-- `x.shiftConcat b` does not overflow if `x < 2^k` for `k < w`, and so +`x.shiftConcat b |>.toNat = x.toNat * 2 + b.toNat`. -/ +theorem toNat_shiftConcat_eq_of_lt_of_lt_two_pow {x : BitVec w} {b : Bool} {k : Nat} + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (x.shiftConcat b).toNat = x.toNat * 2 + b.toNat := by + simp [bv_toNat, Nat.shiftLeft_eq] + have h := @Nat.pow_lt_pow_of_lt 2 k w (by omega) (by omega) + have := (@Nat.pow_lt_pow_eq_pow_mul_le_pow 2 k w (by omega)).mp h + rw [Nat.mod_eq_of_lt (by cases b <;> simp [bv_toNat] <;> omega)] + +theorem toNat_shiftConcat_lt_of_lt_of_lt_two_pow {x : BitVec w} {b : Bool} {k : Nat} + (hk : k < w) (hx : x.toNat < 2 ^ k) : + (x.shiftConcat b).toNat < 2 ^ (k + 1) := by + rw [toNat_shiftConcat_eq_of_lt_of_lt_two_pow hk hx] + apply mul_two_add_lt_two_pow_of_lt_two_pow_of_lt_two hx + · cases b <;> decide + · omega + +/-! ### Division shift subtractor -/ + +/-- +One round of the division algorithm, that tries to perform a subtract shift. +Note that this is only called when `r.msb = false`, so we will not overflow. +-/ +def divSubtractShift (n : BitVec w) (d : BitVec w) (wn : Nat) (qr : DivModState w) : + DivModState w := + let r' := shiftConcat qr.r (n.getLsbD (wn - 1)) -- if r ≥ d, then we have a quotient bit. + if r' < d + then { + q := qr.q.shiftConcat false, -- If `r' < d`, then we do not have a quotient bit. + r := r' + } else { + q := qr.q.shiftConcat true, -- If `r' ≥ d`, then we have a quotient bit. + r := r' - d -- we subtract to maintain the invariant that `r < d`. + } + +/-- The value of shifting by `wn - 1` equals shifting by `wn` and grabbing the lsb at `(wn - 1)`. -/ +theorem DivModState.toNat_shiftRight_sub_one_eq + (qr : DivModState w) (h : qr.LawfulShiftSubtract wr wn n d): + n.toNat >>> (wn - 1) = (n.toNat >>> wn) * 2 + (n.getLsbD (wn - 1)).toNat := by + have hn := shiftRight_sub_one_eq_shiftConcat_getLsb_of_lt (n := n) (wn := wn) h.hwn_lt + obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).shiftConcat (n.getLsbD (wn - 1))).toNat := by + simp [hn] + simp only [toNat_ushiftRight] at hn + rw [toNat_shiftConcat_eq_of_lt_of_lt_two_pow (k := w - wn)] at hn + · rw [hn] + rw [toNat_ushiftRight] + · have := h.hwn_lt + have := h.hw + omega + · apply BitVec.toNat_ushiftRight_lt + have := h.hwrn + omega + +/-- +This is used when proving the correctness of the divison algorithm, +where we know that `r < d`. +We then want to show that `((r.shiftConcat b) - d) < d` as the loop invariant. +In arithmethic, this is the same as showing that +`r * 2 + 1 - d < d`, which this theorem establishes. +-/ +private theorem two_mul_add_sub_lt_of_lt_of_lt_two (h : a < x) (hy : y < 2) : + 2 * a + y - x < x := by omega + +/-- We show that the output of `divSubtractShift` is lawful, which tells us that it +obeys the division equation. -/ +def divSubtractShiftProof (qr : DivModState w) (h : qr.LawfulShiftSubtract wr wn n d) : + DivModState.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by + simp only [divSubtractShift, decide_eq_true_eq] + -- We add these hypotheses for `omega` to find them later. + have ⟨⟨hrwn, hd, hrd, hr, hn, hrnd⟩, hwn_lt⟩ := h + have : d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 := by rw [Nat.mul_assoc] + by_cases rltd : shiftConcat qr.r (n.getLsbD (wn - 1)) < d + · simp only [rltd, ↓reduceIte] + constructor <;> try bv_omega + case pos.hrWidth => apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega + case pos.hqWidth => apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega + case pos.hdiv => + simp [qr.toNat_shiftRight_sub_one_eq h, h.hdiv, this, + toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hrWidth, + toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hqWidth] + omega + · simp only [rltd, ↓reduceIte] + constructor <;> try bv_omega + case neg.hrLtDivisor => + simp only [lt_def, Nat.not_lt] at rltd + rw [BitVec.toNat_sub_of_le rltd, + toNat_shiftConcat_eq_of_lt_of_lt_two_pow (hk := qr.wr_lt_w h) (hx := h.hrWidth), + Nat.mul_comm] + apply two_mul_add_sub_lt_of_lt_of_lt_two <;> bv_omega + case neg.hrWidth => + simp only + have hdr' : d ≤ (qr.r.shiftConcat (n.getLsbD (wn - 1))) := + BitVec.le_iff_not_lt.mp rltd + have hr' : ((qr.r.shiftConcat (n.getLsbD (wn - 1)))).toNat < 2 ^ (wr + 1) := by + apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> bv_omega + rw [BitVec.toNat_sub_of_le hdr'] + omega + case neg.hqWidth => + apply toNat_shiftConcat_lt_of_lt_of_lt_two_pow <;> omega + case neg.hdiv => + have rltd' := (BitVec.le_iff_not_lt.mp rltd) + simp only [qr.toNat_shiftRight_sub_one_eq h, + BitVec.toNat_sub_of_le rltd', + toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hrWidth] + simp only [BitVec.le_def, + toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hrWidth] at rltd' + simp only [toNat_shiftConcat_eq_of_lt_of_lt_two_pow (qr.wr_lt_w h) h.hqWidth, h.hdiv, Nat.mul_add] + bv_omega + +/-! ### Core division algorithm circuit -/ + +/-- A recursive definition of division for bitblasting, in terms of a shift-subtraction circuit. -/ +def divRec (w wr wn : Nat) (n d : BitVec w) (qr : DivModState w) : + DivModState w := + match wn with + | 0 => qr + | wn + 1 => + divRec w (wr + 1) wn n d <| divSubtractShift n d (wn + 1) qr + +@[simp] +theorem divRec_zero (qr : DivModState w) : + divRec w w 0 n d qr = qr := rfl + +@[simp] +theorem divRec_succ (wn : Nat) (qr : DivModState w) : + divRec w wr (wn + 1) n d qr = + divRec w (wr + 1) wn n d + (divSubtractShift n d (wn + 1) qr) := rfl + +theorem divRec_correct {n d : BitVec w} (qr : DivModState w) + (h : DivModState.Lawful w wr wn n d qr) : DivModState.Lawful w w 0 n d (divRec w wr wn n d qr) := by + induction wn generalizing wr qr + case zero => + unfold divRec + simp [← h.hwrn, h] + case succ wn' ih => + simp only [divRec] + apply ih + apply divSubtractShiftProof (w := w) + (wr := wr) + (wn := wn' + 1) + exact h.toLawfulShiftSubtract + +/-- The result of `udiv` agrees with the result of the division recurrence. -/ +theorem udiv_eq_divRec (hd : 0#w < d) : + let out := divRec w 0 w n d (DivModState.init w) + n.udiv d = out.q := by + have := DivModState.Lawful.init w n d hd + have := divRec_correct (DivModState.init w) this + apply DivModState.udiv_eq_of_lawful_zero this + +/-- The result of `umod` agrees with the result of the division recurrence. -/ +theorem umod_eq_divRec (hd : 0#w < d) : + let out := divRec w 0 w n d (DivModState.init w) + n.umod d = out.r := by + have := DivModState.Lawful.init w n d hd + have := divRec_correct (DivModState.init w) this + apply DivModState.umod_eq_of_lawful_zero this + +@[simp] +theorem divRec_succ' (wn : Nat) (qr : DivModState w) : + divRec w wr (wn + 1) n d qr = + let r' := shiftConcat qr.r (n.getLsbD wn) + let input : DivModState w := + if r' < d then ⟨qr.q.shiftConcat false, r'⟩ else ⟨qr.q.shiftConcat true, r' - d⟩ + divRec w (wr + 1) wn n d input := by + simp [divRec_succ, divSubtractShift] + /- ### Arithmetic shift right (sshiftRight) recurrence -/ /-- diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index a537d407079c..b192013a875e 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -966,6 +966,17 @@ theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) : theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by simp [bv_toNat] +/-- +Shifting right by `n < w` yields a bitvector whose value is less than `2 ^ (w - n)`. +-/ +theorem toNat_ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) : + (x >>> n).toNat < 2 ^ (w - n) := by + rw [toNat_ushiftRight, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul] + · rw [Nat.pow_sub_mul_pow] + · apply x.isLt + · apply hn + · apply Nat.pow_pos (by decide) + /-! ### ushiftRight reductions from BitVec to Nat -/ @[simp] @@ -1421,6 +1432,28 @@ theorem getLsbD_concat (x : BitVec w) (b : Bool) (i : Nat) : (concat x a) ^^^ (concat y b) = concat (x ^^^ y) (xor a b) := by ext i; cases i using Fin.succRecOn <;> simp +/-! ### shiftConcat -/ + +@[simp] +theorem getLsb_shiftConcat {x : BitVec w} {b : Bool} {i : Nat} : + (x.shiftConcat b).getLsbD i = + ((decide (i < w) && !decide (i < 1) && x.getLsbD (i - 1)) || + decide (i < w) && (decide (i = 0) && b)) := by + simp [shiftConcat] + +theorem shiftRight_sub_one_eq_shiftConcat_getLsb_of_lt {n : BitVec w} (hwn : 0 < wn) : + n >>> (wn - 1) = (n >>> wn).shiftConcat (n.getLsbD (wn - 1)) := by + ext i + simp only [getLsbD_ushiftRight, getLsbD_or, getLsbD_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, + getLsbD_zeroExtend, getLsbD_ofBool] + by_cases (i : Nat) < 1 + case pos h => + simp [show (i : Nat) = 0 by omega] + omega + case neg h => + have hi : (i : Nat) ≠ 0 := by omega + simp [shiftConcat, h, hi, show wn - 1 + ↑i = wn + (↑i - 1) by omega] + /-! ### add -/ theorem add_def {n} (x y : BitVec n) : x + y = .ofNat n (x.toNat + y.toNat) := rfl @@ -1638,6 +1671,10 @@ protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x simp exact Nat.lt_of_le_of_ne +theorem le_iff_not_lt {x y : BitVec w} : (¬ x < y) ↔ y ≤ x := by + constructor <;> + (intro h; simp only [lt_def, Nat.not_lt, le_def] at h ⊢; omega) + /-! ### ofBoolList -/ @[simp] theorem getMsbD_ofBoolListBE : (ofBoolListBE bs).getMsbD i = bs.getD i false := by @@ -1876,6 +1913,11 @@ theorem twoPow_zero {w : Nat} : twoPow w 0 = 1#w := by theorem getLsbD_one {w i : Nat} : (1#w).getLsbD i = (decide (0 < w) && decide (0 = i)) := by rw [← twoPow_zero, getLsbD_twoPow] +theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) : + x <<< n = x * (BitVec.twoPow w n) := by + ext i + simp [getLsbD_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, mul_twoPow_eq_shiftLeft] + /- ### zeroExtend, truncate, and bitwise operations -/ /-- @@ -2009,4 +2051,13 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) := · simp [h] · rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)] +/-! ### Lemmas about non-overflowing computations -/ + +theorem toNat_sub_of_le {x y : BitVec w} (h : y ≤ x) : + (x - y).toNat = x.toNat - y.toNat := by + rw [BitVec.le_def] at h + simp only [toNat_sub, show 2 ^ w - y.toNat + x.toNat = 2 ^ w + (x.toNat - y.toNat) by omega, + Nat.add_mod_left] + rw [Nat.mod_eq_of_lt (by omega)] + end BitVec diff --git a/src/Init/Data/Bool.lean b/src/Init/Data/Bool.lean index e5910ef86eaa..ec573845b479 100644 --- a/src/Init/Data/Bool.lean +++ b/src/Init/Data/Bool.lean @@ -380,13 +380,14 @@ theorem and_or_inj_left_iff : /-- convert a `Bool` to a `Nat`, `false -> 0`, `true -> 1` -/ def toNat (b : Bool) : Nat := cond b 1 0 -@[simp] theorem toNat_false : false.toNat = 0 := rfl +@[simp, bv_toNat] theorem toNat_false : false.toNat = 0 := rfl -@[simp] theorem toNat_true : true.toNat = 1 := rfl +@[simp, bv_toNat] theorem toNat_true : true.toNat = 1 := rfl theorem toNat_le (c : Bool) : c.toNat ≤ 1 := by cases c <;> trivial +@[bv_toNat] theorem toNat_lt (b : Bool) : b.toNat < 2 := Nat.lt_succ_of_le (toNat_le _) diff --git a/src/Init/Data/Nat/Lemmas.lean b/src/Init/Data/Nat/Lemmas.lean index 268e89a1117a..a5697e5a75b0 100644 --- a/src/Init/Data/Nat/Lemmas.lean +++ b/src/Init/Data/Nat/Lemmas.lean @@ -739,6 +739,11 @@ protected theorem two_pow_pred_mod_two_pow (h : 0 < w) : rw [mod_eq_of_lt] apply Nat.pow_pred_lt_pow (by omega) h +protected theorem pow_lt_pow_eq_pow_mul_le_pow {a n m : Nat} (h : 1 < a) : + a ^ n < a ^ m ↔ a ^ n * a ≤ a ^ m := by + rw [←Nat.pow_add_one, Nat.pow_le_pow_iff_right (by omega), Nat.pow_lt_pow_iff_right (by omega)] + omega + /-! ### log2 -/ @[simp]