Skip to content

Commit

Permalink
chore: tobi cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
bollu committed Sep 4, 2024
1 parent 0cd3a29 commit ac02e6f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
42 changes: 21 additions & 21 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,6 @@ References:
- Bitwuzla sources for bitblasting.h
-/


/-- TODO: This theorem surely exists somewhere, but I can't find it. -/
private theorem Nat.div_add_eq_left_of_lt {x y z : Nat} (hx : z ∣ x) (hy : y < z) (hz : 0 < z):
(x + y) / z = x / z := by
refine Nat.div_eq_of_lt_le ?lo ?hi
Expand Down Expand Up @@ -495,7 +493,7 @@ theorem udiv_umod_characterized_of_mul_add_toNat {d n q r : BitVec w} (hd : 0 <
replace hdqnr : (d.toNat * q.toNat + r.toNat) % d.toNat = n.toNat % d.toNat := by
simp [hdqnr]
rw [Nat.add_mod, Nat.mul_mod_right] at hdqnr
simp at hdqnr
simp only [Nat.zero_add, mod_mod] at hdqnr
replace hrd : r.toNat < d.toNat := by
rw [BitVec.lt_def] at hrd
exact hrd -- TODO: golf
Expand Down Expand Up @@ -531,7 +529,6 @@ structure DivRemInput.Lawful (w wr wn : Nat) (n d : BitVec w)
/-- The low n bits of `n` obey the fundamental division equation. -/
hdiv : n.toNat >>> wn = d.toNat * qr.q.toNat + qr.r.toNat


/-- A lawful DivRemInput implies `w > 0`. -/
def DivRemInput.Lawful.hw {qr : DivRemInput w}
{h : DivRemInput.Lawful w wr wn n d qr} : 0 < w := by
Expand Down Expand Up @@ -746,7 +743,6 @@ theorem toNat_shiftConcat_lt (x : BitVec w) (b : Bool) (k : Nat)
· rcases b with rfl | rfl <;> decide
· omega


/-- The value of shifting by `wn - 1` equals
shifting by `wn` and grabbing the lsb at (wn - 1) -/
theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb
Expand All @@ -755,7 +751,7 @@ theorem ShiftSubtractInput.toNat_n_shiftr_wl_minus_one_eq_n_shiftr_wl_plus_nmsb
have hn := ShiftSubtractInput.n_shiftr_wl_minus_one_eq_n_shiftr_wl_or_nmsb (qr := qr) h
obtain hn : (n >>> (wn - 1)).toNat = ((n >>> wn).shiftConcat (n.getLsb (wn - 1))).toNat := by
simp [hn]
simp at hn
simp only [toNat_ushiftRight] at hn
rw [toNat_shiftConcat_eq (k := w - wn)] at hn
· rw [hn]
rw [toNat_ushiftRight]
Expand Down Expand Up @@ -791,7 +787,7 @@ def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d)
DivRemInput.Lawful w (wr + 1) (wn - 1) n d (divSubtractShift n d wn qr) := by
simp only [divSubtractShift, decide_eq_true_eq]
by_cases rltd : shiftConcat qr.r (n.getLsb (wn - 1)) < d
· simp [rltd]
· simp only [rltd, ↓reduceIte]
constructor
case pos.hwr =>
have := h.hwr
Expand All @@ -806,8 +802,7 @@ def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d)
assumption
case pos.hrd =>
simp only
simp [BitVec.lt_def] at rltd
assumption
simpa using rltd
case pos.hrwr =>
simp [rltd]
apply toNat_shiftConcat_lt
Expand All @@ -823,7 +818,7 @@ def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d)
(hk := qr.wr_lt_w h)
(hx := h.hrwr)]
rw [h.hdiv]
simp only
simp only [decide_True, Bool.not_true]
rw [toNat_shiftConcat_false_eq qr.q wr (qr.wr_lt_w h) h.hqwr]
rw [Nat.add_mul]
simp [show d.toNat * (qr.q.toNat * 2) = d.toNat * qr.q.toNat * 2 by
Expand All @@ -849,7 +844,7 @@ def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d)
assumption
case neg.hrd =>
simp only
simp [BitVec.lt_def] at rltd
simp only [lt_def, Nat.not_lt] at rltd
have hr := h.hrd
have hr' : qr.r < d := by simp only [lt_def]; exact hr
rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le rltd]
Expand Down Expand Up @@ -904,14 +899,18 @@ def divSubtractShiftProof (qr : ShiftSubtractInput w) (h : qr.Lawful wr wn n d)
have := h.hwn
omega

/- ### Core divsion recurrence.
/-! ### Core divsion algorithm
We have three widths at play:
- w, the total bitwidth
- wr, the effective bitwidth of the reminder
- wn, the effective bitwidth of the dividend.
- We have the invariant that wn + wr = w.
See that when it is called, we will know that:
- `w`, the total bitwidth
- `wr`, the effective bitwidth of the reminder
- `wn`, the effective bitwidth of the dividend.
We have the invariant that wn + wr = w.
See that when `divRec'` is called with a `DivRemInput.Lawful h`, we know that:
- r < [2^wr = 2^(w - wn)]
which allows us to safely shift left, since it is of length n.
In particular, since 'wn' decreases in the course of the recursion,
Expand All @@ -920,8 +919,9 @@ See that when it is called, we will know that:
Thus, at this step, we will stop and return a full remainder.
So, the remainder is morally of length `w - wn`.
- d > 0
- r < d
- r < d.
-/

def divRec' (w wr wn : Nat) (n d : BitVec w) (qr : DivRemInput w) :
DivRemInput w :=
match wn with
Expand All @@ -935,9 +935,9 @@ theorem divRec'_zero (qr : DivRemInput w) :

@[simp]
theorem divRec'_succ (wn : Nat) (qr : DivRemInput w) :
divRec' w wr (wn + 1) n d qr =
divRec' w (wr + 1) wn n d (divSubtractShift n d (wn + 1) (ShiftSubtractInput.ofDivRemInput qr)) := rfl

divRec' w wr (wn + 1) n d qr =
divRec' w (wr + 1) wn n d
(divSubtractShift n d (wn + 1) (ShiftSubtractInput.ofDivRemInput qr)) := rfl

theorem divRec'_correct {n d : BitVec w} (qr : DivRemInput w)
(h : DivRemInput.Lawful w wr wn n d qr) : DivRemInput.Lawful w w 0 n d (divRec' w wr wn n d qr) := by
Expand Down
6 changes: 2 additions & 4 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -815,10 +815,8 @@ theorem ushiftRight_or_distrib (x y : BitVec w) (n : Nat) :
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
simp [bv_toNat]


/--
Shifting right by `n < w` yields a bitvector whose value
is less than `2^(w - n)`
Shifting right by `n < w` yields a bitvector whose value is less than `2^(w - n)`.
-/
theorem toNat_ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) :
(x >>> n).toNat < 2 ^ (w - n) := by
Expand Down Expand Up @@ -1727,7 +1725,7 @@ theorem getLsb_one {w i : Nat} : (1#w).getLsb i = (decide (0 < w) && decide (0 =
theorem shiftLeft_eq_mul_twoPow (x : BitVec w) (n : Nat) :
x <<< n = x * (BitVec.twoPow w n) := by
ext i
simp only [getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, mul_twoPow_eq_shiftLeft]
simp [getLsb_shiftLeft, Fin.is_lt, decide_True, Bool.true_and, mul_twoPow_eq_shiftLeft]

/- ### zeroExtend, truncate, and bitwise operations -/

Expand Down

0 comments on commit ac02e6f

Please sign in to comment.