Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get lsb d get elem simp #20

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ instance : GetElem (BitVec w) Nat Bool fun _ i => i < w where
theorem getElem_eq_testBit_toNat (x : BitVec w) (i : Nat) (h : i < w) :
x[i] = x.toNat.testBit i := rfl

theorem getLsbD_eq_getElem {x : BitVec w} {i : Nat} (h : i < w) :
@[simp] theorem getLsbD_eq_getElem {x : BitVec w} {i : Nat} (h : i < w) :
x.getLsbD i = x[i] := rfl

end getElem
Expand Down
50 changes: 24 additions & 26 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,12 @@ def adcb (x y c : Bool) : Bool × Bool := (atLeastTwo x y c, x ^^ (y ^^ c))
def adc (x y : BitVec w) : Bool → Bool × BitVec w :=
iunfoldr fun (i : Fin w) c => adcb (x.getLsbD i) (y.getLsbD i) c

theorem getLsbD_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
getLsbD (x + y + setWidth w (ofBool c)) i =
(getLsbD x i ^^ (getLsbD y i ^^ carry i x y c)) := by
theorem getElem_add_add_setWidth_ofBool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
(x + y + setWidth w (ofBool c))[i] = (x[i] ^^ (y[i] ^^ carry i x y c)) := by
let ⟨x, x_lt⟩ := x
let ⟨y, y_lt⟩ := y
simp only [getLsbD, toNat_add, toNat_setWidth, i_lt, toNat_ofFin, toNat_ofBool,
Nat.mod_add_mod, Nat.add_mod_mod]
simp only [getElem_eq_testBit_toNat, toNat_add, toNat_setWidth, toNat_ofFin, toNat_ofBool,
mod_add_mod, Nat.add_mod_mod]
apply Eq.trans
rw [← Nat.div_add_mod x (2^i), ← Nat.div_add_mod y (2^i)]
simp only
Expand All @@ -159,10 +158,9 @@ theorem getLsbD_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool
]
simp [testBit_to_div_mod, carry, Nat.add_assoc]

theorem getLsbD_add {i : Nat} (i_lt : i < w) (x y : BitVec w) :
getLsbD (x + y) i =
(getLsbD x i ^^ (getLsbD y i ^^ carry i x y false)) := by
simpa using getLsbD_add_add_bool i_lt x y false
theorem getElem_add {i : Nat} (i_lt : i < w) (x y : BitVec w) :
(x + y)[i] = (x[i] ^^ (y[i] ^^ carry i x y false)) := by
simpa using getElem_add_add_setWidth_ofBool i_lt x y false

theorem adc_spec (x y : BitVec w) (c : Bool) :
adc x y c = (carry w x y c, x + y + setWidth w (ofBool c)) := by
Expand All @@ -175,7 +173,8 @@ theorem adc_spec (x y : BitVec w) (c : Bool) :
simp [carry, Nat.mod_one]
cases c <;> rfl
case step =>
simp [adcb, Prod.mk.injEq, carry_succ, getLsbD_add_add_bool]
simp [adcb, Prod.mk.injEq, carry_succ, getElem_add_add_setWidth_ofBool]


theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by
simp [adc_spec]
Expand Down Expand Up @@ -210,24 +209,26 @@ theorem add_eq_or_of_and_eq_zero {w : Nat} (x y : BitVec w)
/-! ### Negation -/

theorem bit_not_testBit (x : BitVec w) (i : Fin w) :
getLsbD (((iunfoldr (fun (i : Fin w) c => (c, !(x.getLsbD i)))) ()).snd) i.val = !(getLsbD x i.val) := by
(((iunfoldr (fun (i : Fin w) c => (c, !x[i.val]))) ()).snd)[i.val] = !x[i.val] := by
apply iunfoldr_getLsbD (fun _ => ()) i (by simp)

theorem bit_not_add_self (x : BitVec w) :
((iunfoldr (fun (i : Fin w) c => (c, !(x.getLsbD i)))) ()).snd + x = -1 := by
((iunfoldr (fun (i : Fin w) c => (c, !x[i.val]))) ()).snd + x = -1#w := by
simp only [add_eq_adc]
apply iunfoldr_replace_snd (fun _ => false) (-1) false rfl
intro i; simp only [ BitVec.not, adcb, testBit_toNat]
rw [iunfoldr_replace_snd (fun _ => ()) (((iunfoldr (fun i c => (c, !(x.getLsbD i)))) ()).snd)]
<;> simp [bit_not_testBit, negOne_eq_allOnes, getLsbD_allOnes]
apply iunfoldr_replace_snd (fun _ => false) (-1#w) false rfl
intro i; simp only [adcb, Fin.is_lt, getLsbD_eq_getElem, atLeastTwo_false_right, bne_false,
Prod.mk.injEq, and_eq_false_imp]
rw [iunfoldr_replace_snd (fun _ => ()) (((iunfoldr (fun i c => (c, !(x[i])))) ()).snd)]
<;> simp [bit_not_testBit, negOne_eq_allOnes]

theorem bit_not_eq_not (x : BitVec w) :
((iunfoldr (fun i c => (c, !(x.getLsbD i)))) ()).snd = ~~~ x := by
simp [←allOnes_sub_eq_not, BitVec.eq_sub_iff_add_eq.mpr (bit_not_add_self x), ←negOne_eq_allOnes]
((iunfoldr (fun i c => (c, !(x[i])))) ()).snd = ~~~ x := by
simp [← allOnes_sub_eq_not, BitVec.eq_sub_iff_add_eq.mpr (bit_not_add_self x), ←negOne_eq_allOnes]

theorem bit_neg_eq_neg (x : BitVec w) : -x = (adc (((iunfoldr (fun (i : Fin w) c => (c, !(x.getLsbD i)))) ()).snd) (BitVec.ofNat w 1) false).snd:= by
theorem bit_neg_eq_neg (x : BitVec w) :
-x = (adc (((iunfoldr (fun (i : Fin w) c => (c, !x[i.val]))) ()).snd) (BitVec.ofNat w 1) false).snd := by
simp only [← add_eq_adc]
rw [iunfoldr_replace_snd ((fun _ => ())) (((iunfoldr (fun (i : Fin w) c => (c, !(x.getLsbD i)))) ()).snd) _ rfl]
rw [iunfoldr_replace_snd ((fun _ => ())) (((iunfoldr (fun (i : Fin w) c => (c, !x[i.val]))) ()).snd) _ rfl]
· rw [BitVec.eq_sub_iff_add_eq.mpr (bit_not_add_self x), sub_toAdd, BitVec.add_comm _ (-x)]
simp [← sub_toAdd, BitVec.sub_add_cancel]
· simp [bit_not_testBit x _]
Expand Down Expand Up @@ -361,12 +362,9 @@ theorem mulRec_eq_mul_signExtend_setWidth (x y : BitVec w) (s : Nat) :
inherit_doc mulRec_eq_mul_signExtend_setWidth]
abbrev mulRec_eq_mul_signExtend_truncate := @mulRec_eq_mul_signExtend_setWidth

theorem getLsbD_mul (x y : BitVec w) (i : Nat) :
(x * y).getLsbD i = (mulRec x y w).getLsbD i := by
simp only [mulRec_eq_mul_signExtend_setWidth]
rw [setWidth_setWidth_of_le]
· simp
· omega
theorem getElem_mul (x y : BitVec w) (i : Nat) (h : i < w) :
(x * y)[i] = (mulRec x y w)[i] := by
simp [mulRec_eq_mul_signExtend_setWidth]

/-! ## shiftLeft recurrence for bitblasting -/

Expand Down
Loading
Loading