Skip to content

Commit

Permalink
feat: support at in ac_nf and use it in bv_normalize (#5618)
Browse files Browse the repository at this point in the history
... while at it also call `trivial` to close goals that can be trivially
closed.

---------

Co-authored-by: Siddharth <[email protected]>
Co-authored-by: Henrik Böving <[email protected]>
  • Loading branch information
3 people authored Oct 7, 2024
1 parent a3ee111 commit c0617da
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 57 deletions.
34 changes: 21 additions & 13 deletions src/Init/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,6 @@ example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by ac_rfl
-/
syntax (name := acRfl) "ac_rfl" : tactic

/--
`ac_nf` normalizes equalities up to application of an associative and commutative operator.
```
instance : Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩
example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
-- goal: a + (b + (c + d)) = a + (b + (c + d))
```
-/
syntax (name := acNf) "ac_nf" : tactic

/--
The `sorry` tactic closes the goal using `sorryAx`. This is intended for stubbing out incomplete
parts of a proof while still having a syntactically correct proof skeleton. Lean will give
Expand Down Expand Up @@ -1172,6 +1159,9 @@ Currently the preprocessor is implemented as `try simp only [bv_toNat] at *`.
-/
macro "bv_omega" : tactic => `(tactic| (try simp only [bv_toNat] at *) <;> omega)

/-- Implementation of `ac_nf` (the full `ac_nf` calls `trivial` afterwards). -/
syntax (name := acNf0) "ac_nf0" (location)? : tactic

/-- Implementation of `norm_cast` (the full `norm_cast` calls `trivial` afterwards). -/
syntax (name := normCast0) "norm_cast0" (location)? : tactic

Expand Down Expand Up @@ -1222,6 +1212,24 @@ See also `push_cast`, which moves casts inwards rather than lifting them outward
macro "norm_cast" loc:(location)? : tactic =>
`(tactic| norm_cast0 $[$loc]? <;> try trivial)

/--
`ac_nf` normalizes equalities up to application of an associative and commutative operator.
- `ac_nf` normalizes all hypotheses and the goal target of the goal.
- `ac_nf at l` normalizes at location(s) `l`, where `l` is either `*` or a
list of hypotheses in the local context. In the latter case, a turnstile `⊢` or `|-`
can also be used, to signify the target of the goal.
```
instance : Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩
example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
-- goal: a + (b + (c + d)) = a + (b + (c + d))
```
-/
macro "ac_nf" loc:(location)? : tactic =>
`(tactic| ac_nf0 $[$loc]? <;> try trivial)

/--
`push_cast` rewrites the goal to move certain coercions (*casts*) inward, toward the leaf nodes.
This uses `norm_cast` lemmas in the forward direction.
Expand Down
5 changes: 5 additions & 0 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ register_builtin_option debug.bv.graphviz : Bool := {
descr := "Output the AIG of bv_decide as graphviz into a file called aig.gv in the working directory of the Lean process."
}

register_builtin_option bv.ac_nf : Bool := {
defValue := true
descr := "Canonicalize with respect to associativity and commutativitiy."
}

builtin_initialize bvNormalizeExt : Meta.SimpExtension ←
Meta.registerSimpAttr `bv_normalize "simp theorems used by bv_normalize"

Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/BVDecide/Frontend/LRAT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function together with a correctness theorem for it.
`∀ (b : α) (c : LratCert), verifier b c = true → unsat b`
-/
def LratCert.toReflectionProof [ToExpr α] (cert : LratCert) (cfg : TacticContext) (reflected : α)
(verifier : Name) (unsat_of_verifier_eq_true : Name) :
(verifier : Name) (unsat_of_verifier_eq_true : Name) :
MetaM Expr := do
withTraceNode `sat (fun _ => return "Compiling expr term") do
mkAuxDecl cfg.exprDef (toExpr reflected) (toTypeExpr α)
Expand Down
30 changes: 27 additions & 3 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Henrik Böving
-/
prelude
import Lean.Meta.AppBuilder
import Lean.Meta.Tactic.AC.Main
import Lean.Elab.Tactic.Simp
import Lean.Elab.Tactic.FalseOrByContra
import Lean.Elab.Tactic.BVDecide.Frontend.Attr
Expand Down Expand Up @@ -112,19 +113,44 @@ def rewriteRulesPass : Pass := fun goal => do
let some (_, newGoal) := result? | return none
return newGoal

/--
Normalize with respect to Associativity and Commutativity.
-/
def acNormalizePass : Pass := fun goal => do
let mut newGoal := goal
for hyp in (← goal.getNondepPropHyps) do
let result ← Lean.Meta.AC.acNfHypMeta newGoal hyp

if let .some nextGoal := result then
newGoal := nextGoal
else
return none

return newGoal

/--
The normalization passes used by `bv_normalize` and thus `bv_decide`.
-/
def defaultPipeline : List Pass := [rewriteRulesPass]

def passPipeline : MetaM (List Pass) := do
let opts ← getOptions

let mut passPipeline := defaultPipeline

if bv.ac_nf.get opts then
passPipeline := passPipeline ++ [acNormalizePass]

return passPipeline

end Pass

def bvNormalize (g : MVarId) : MetaM (Option MVarId) := do
withTraceNode `bv (fun _ => return "Normalizing goal") do
-- Contradiction proof
let some g ← g.falseOrByContra | return none
trace[Meta.Tactic.bv] m!"Running preprocessing pipeline on:\n{g}"
Pass.fixpointPipeline Pass.defaultPipeline g
Pass.fixpointPipeline (← Pass.passPipeline) g

@[builtin_tactic Lean.Parser.Tactic.bvNormalize]
def evalBVNormalize : Tactic := fun
Expand All @@ -137,5 +163,3 @@ def evalBVNormalize : Tactic := fun

end Frontend.Normalize
end Lean.Elab.Tactic.BVDecide


82 changes: 54 additions & 28 deletions src/Lean/Meta/Tactic/AC/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ where
| .op l r => mkApp2 preContext.op (convertTarget vars l) (convertTarget vars r)
| .var x => vars[x]!

def post (e : Expr) : SimpM Simp.Step := do
let ctx ← Simp.getContext
match e, ctx.parent? with
| bin op₁ l r, some (bin op₂ _ _) =>
if ←isDefEq op₁ op₂ then
return Simp.Step.done { expr := e }
match ←preContext op₁ with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| bin op l r, _ =>
match ←preContext op with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| e, _ => return Simp.Step.done { expr := e }

def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do
let simpCtx :=
{
Expand All @@ -150,41 +169,48 @@ def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do
let tgt ← instantiateMVars (← mvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
applySimpResultToTarget mvarId tgt res
where
post (e : Expr) : SimpM Simp.Step := do
let ctx ← Simp.getContext
match e, ctx.parent? with
| bin op₁ l r, some (bin op₂ _ _) =>
if ←isDefEq op₁ op₂ then
return Simp.Step.done { expr := e }
match ←preContext op₁ with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| bin op l r, _ =>
match ←preContext op with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| e, _ => return Simp.Step.done { expr := e }

def rewriteUnnormalizedRefl (goal : MVarId) : MetaM Unit := do
let newGoal ← rewriteUnnormalized goal
newGoal.refl

def rewriteUnnormalizedNormalForm (goal : MVarId) : TacticM Unit := do
let newGoal ← rewriteUnnormalized goal
replaceMainGoal [newGoal]
(← rewriteUnnormalized goal).refl

@[builtin_tactic acRfl] def acRflTactic : Lean.Elab.Tactic.Tactic := fun _ => do
let goal ← getMainGoal
goal.withContext <| rewriteUnnormalizedRefl goal

@[builtin_tactic acNf] def acNfTactic : Lean.Elab.Tactic.Tactic := fun _ => do
let goal ← getMainGoal
goal.withContext <| rewriteUnnormalizedNormalForm goal
def acNfHypMeta (goal : MVarId) (fvarId : FVarId) : MetaM (Option MVarId) := do
goal.withContext do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let tgt ← instantiateMVars (← fvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
return (← applySimpResultToLocalDecl goal fvarId res false).map (·.snd)

/-- Implementation of the `ac_nf` tactic when operating on the main goal. -/
def acNfTargetTactic : TacticM Unit :=
liftMetaTactic1 fun goal => rewriteUnnormalized goal

/-- Implementation of the `ac_nf` tactic when operating on a hypothesis. -/
def acNfHypTactic (fvarId : FVarId) : TacticM Unit :=
liftMetaTactic1 fun goal => acNfHypMeta goal fvarId

@[builtin_tactic acNf0]
def evalNf0 : Tactic := fun stx => do
match stx with
| `(tactic| ac_nf0 $[$loc?]?) =>
let loc := if let some loc := loc? then expandLocation loc else Location.targets #[] true
withMainContext do
match loc with
| Location.targets hyps target =>
if target then acNfTargetTactic
(← getFVarIds hyps).forM acNfHypTactic
| Location.wildcard =>
acNfTargetTactic
(← (← getMainGoal).getNondepPropHyps).forM acNfHypTactic
| _ => Lean.Elab.throwUnsupportedSyntax

builtin_initialize
registerTraceClass `Meta.AC
Expand Down
23 changes: 12 additions & 11 deletions tests/lean/run/ac_rfl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ example : [1, 2] ++ ([] ++ [2+4, 8] ++ [4]) = [1, 2] ++ [4+2, 8] ++ [4] := by ac
example (a b c d : BitVec w) :
a * b * c * d = d * c * b * a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a * b * c * d = d * c * b * a := by
Expand All @@ -52,7 +51,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a + b + c + d = d + c + b + a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a + b + c + d = d + c + b + a := by
Expand All @@ -63,7 +61,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a * (b * (c * d)) = ((a * b) * c) * d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a * (b * (c * d)) = ((a * b) * c) * d := by
Expand All @@ -72,7 +69,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a + (b + (c + d)) = ((a + b) + c) + d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a + (b + (c + d)) = ((a + b) + c) + d := by
Expand All @@ -83,7 +79,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ^^^ b ^^^ c ^^^ d = d ^^^ c ^^^ b ^^^ a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ^^^ b ^^^ c ^^^ d = d ^^^ c ^^^ b ^^^ a := by
Expand All @@ -92,7 +87,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a &&& b &&& c &&& d = d &&& c &&& b &&& a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a &&& b &&& c &&& d = d &&& c &&& b &&& a := by
Expand All @@ -101,7 +95,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ||| b ||| c ||| d = d ||| c ||| b ||| a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ||| b ||| c ||| d = d ||| c ||| b ||| a := by
Expand All @@ -112,7 +105,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a &&& (b &&& (c &&& d)) = ((a &&& b) &&& c) &&& d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a &&& (b &&& (c &&& d)) = ((a &&& b) &&& c) &&& d := by
Expand All @@ -121,7 +113,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ||| (b ||| (c ||| d)) = ((a ||| b) ||| c) ||| d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ||| (b ||| (c ||| d)) = ((a ||| b) ||| c) ||| d := by
Expand All @@ -130,12 +121,22 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ^^^ (b ^^^ (c ^^^ d)) = ((a ^^^ b) ^^^ c) ^^^ d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ^^^ (b ^^^ (c ^^^ d)) = ((a ^^^ b) ^^^ c) ^^^ d := by
ac_rfl

example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
rfl

example (a b c d : Nat) (h₁ h₂ : a + b + c + d = d + (b + c) + a) :
a + b + c + d = a + (b + c) + d := by

ac_nf at h₁
guard_hyp h₁ :ₛ a + (b + (c + d)) = a + (b + (c + d))

guard_hyp h₂ :ₛ a + b + c + d = d + (b + c) + a
ac_nf at h₂
guard_hyp h₂ :ₛ a + (b + (c + d)) = a + (b + (c + d))

ac_nf at *
2 changes: 2 additions & 0 deletions tests/lean/run/bv_arith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Std.Tactic.BVDecide

open BitVec

set_option bv.ac_nf false

theorem arith_unit_1 (x y : BitVec 64) : x + y = y + x := by
bv_decide

Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/bv_axiom_check.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Std.Tactic.BVDecide

open BitVec

set_option bv.ac_nf false

theorem bv_axiomCheck (x y : BitVec 1) : x + y = y + x := by
bv_decide

Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/bv_bitblast_stress.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Std.Tactic.BVDecide

open BitVec

set_option exponentiation.threshold 4096

theorem t1 {x y : BitVec 64} (h : x = y) : (~~~x) &&& y = (~~~y) &&& x := by
bv_decide

Expand Down
2 changes: 2 additions & 0 deletions tests/lean/run/bv_bitwise.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Std.Tactic.BVDecide

open BitVec

set_option bv.ac_nf false

theorem bitwise_unit_1 {x y : BitVec 64} : ~~~(x &&& y) = (~~~x ||| ~~~y) := by
bv_decide

Expand Down
8 changes: 7 additions & 1 deletion tests/lean/run/bv_decide_rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ example :
example (x y z : BitVec 8) (h1 : x = z → False) (h2 : x = y) (h3 : y = z) : False := by
bv_decide

example (x y : BitVec 256) : x * y = y * x := by
bv_decide

example {x y z : BitVec 64} : ~~~(x &&& (y * z)) = (~~~x ||| ~~~(z * y)) := by
bv_decide

def mem_subset (a1 a2 b1 b2 : BitVec 64) : Bool :=
(b2 - b1 = BitVec.ofNat 64 (2^64 - 1)) ||
((a2 - b1 <= b2 - b1 && a1 - b1 <= a2 - b1))
Expand All @@ -39,7 +45,7 @@ example {x : BitVec 16} : x * 1 = x := by bv_normalize
example {x : BitVec 16} : ~~~(~~~x) = x := by bv_normalize
example {x : BitVec 16} : x &&& 0 = 0 := by bv_normalize
example {x : BitVec 16} : 0 &&& x = 0 := by bv_normalize
example {x : BitVec 16} : (-1#16) &&& x = x := by bv_normalize
example {x : BitVec 16} : (-1#16) &&& x = x := by bv_normalize
example {x : BitVec 16} : x &&& (-1#16) = x := by bv_normalize
example {x : BitVec 16} : x &&& x = x := by bv_normalize
example {x : BitVec 16} : x &&& ~~~x = 0 := by bv_normalize
Expand Down
Loading

0 comments on commit c0617da

Please sign in to comment.