From 275085251cb263882ab226d7a8d124fc697d2959 Mon Sep 17 00:00:00 2001 From: Liam Gallagher Date: Tue, 6 Feb 2024 21:11:57 +0000 Subject: [PATCH 1/4] K-induction: Initial codepath. --- src/main/scala/chiseltest/formal/Formal.scala | 3 ++ .../chiseltest/formal/backends/Maltese.scala | 20 ++++++++++- .../chiseltest/formal/examples/Counter.scala | 34 +++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 src/test/scala/chiseltest/formal/examples/Counter.scala diff --git a/src/main/scala/chiseltest/formal/Formal.scala b/src/main/scala/chiseltest/formal/Formal.scala index 4a0051b04..e636be8a8 100644 --- a/src/main/scala/chiseltest/formal/Formal.scala +++ b/src/main/scala/chiseltest/formal/Formal.scala @@ -13,6 +13,7 @@ import firrtl2.transforms.formal.DontAssertSubmoduleAssumptionsAnnotation sealed trait FormalOp extends NoTargetAnnotation case class BoundedCheck(kMax: Int = -1) extends FormalOp +case class InductionCheck(kMax: Int = -1) extends FormalOp /** Specifies how many cycles the circuit should be reset for. */ case class ResetOption(cycles: Int = 1) extends NoTargetAnnotation { @@ -79,5 +80,7 @@ private object Formal { def executeOp(state: CircuitState, resetLength: Int, op: FormalOp): Unit = op match { case BoundedCheck(kMax) => backends.Maltese.bmc(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength) + case InductionCheck(kMax) => + backends.Maltese.induction(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength) } } diff --git a/src/main/scala/chiseltest/formal/backends/Maltese.scala b/src/main/scala/chiseltest/formal/backends/Maltese.scala index d86948384..fe1ef882d 100644 --- a/src/main/scala/chiseltest/formal/backends/Maltese.scala +++ b/src/main/scala/chiseltest/formal/backends/Maltese.scala @@ -62,6 +62,24 @@ private[chiseltest] object Maltese { require(kMax > 0) require(resetLength >= 0) + val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => checker.check(sys, kMax = kMax + resetLength); + check(circuit, annos, checkFn, resetLength); + } + + def induction(circuit: ir.Circuit, annos: AnnotationSeq, kMax: Int, resetLength: Int = 0): Unit = { + require(kMax > 0) + require(resetLength >= 0) + + val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => checker.check(sys, kMax = kMax + resetLength); + check(circuit, annos, checkFn, resetLength); + } + + def check( + circuit: ir.Circuit, + annos: AnnotationSeq, + checkFn: (IsModelChecker, TransitionSystem) => ModelCheckResult, + resetLength: Int + ): Unit = { // convert to transition system val targetDir = Compiler.requireTargetDir(annos) val modelUndef = !annos.contains(DoNotModelUndef) @@ -77,7 +95,7 @@ private[chiseltest] object Maltese { // perform check val checkers = makeCheckers(annos, targetDir) assert(checkers.size == 1, "Parallel checking not supported atm!") - checkers.head.check(sysInfo.sys, kMax = kMax + resetLength) match { + checkFn(checkers.head, sysInfo.sys) match { case ModelCheckFail(witness) => val writeVcd = annos.contains(WriteVcdAnnotation) if (writeVcd) { diff --git a/src/test/scala/chiseltest/formal/examples/Counter.scala b/src/test/scala/chiseltest/formal/examples/Counter.scala new file mode 100644 index 000000000..048cdbb21 --- /dev/null +++ b/src/test/scala/chiseltest/formal/examples/Counter.scala @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +package chiseltest.formal.examples + +import chisel3._ +import chiseltest._ +import chiseltest.formal._ +import org.scalatest.flatspec.AnyFlatSpec + +class CounterVerify extends AnyFlatSpec with ChiselScalatestTester with Formal with FormalBackendOption { + "Counter" should "succeed BMC" taggedAs FormalTag in { + verify(new Counter(65000, 60000), Seq(BoundedCheck(10), DefaultBackend)) + } + + "Counter" should "Fail induction" taggedAs FormalTag in { + verify(new Counter(65000, 50), Seq(InductionCheck(10), DefaultBackend)) + } +} + +class Counter(to: Int, assert_bound: Int) extends Module { + val en = IO(Input(Bool())) + val out = IO(Output(UInt(16.W))) + + val cnt = RegInit(0.U(16.W)) + out := cnt + when(en) { + when(cnt === to.U) { + cnt := 0.U + }.otherwise { + cnt := cnt + 1.U; + } + } + + assert(cnt <= assert_bound.U(16.W)) +} From 6770f9d377cd3d8a48156db9ac13d528ed07f7d8 Mon Sep 17 00:00:00 2001 From: Liam Gallagher Date: Mon, 19 Feb 2024 19:42:19 +0000 Subject: [PATCH 2/4] Add k-induction for SmtModelCheckers Reused much of the BMC code. - Perform BMC for cycles [0..k-1] - Asserted constraints for cycles [n..n + k] - Asserted assumptions for cycles [n..n + k-1] - Checked the negation of the assertions for cycle n + k --- src/main/scala/chiseltest/formal/Formal.scala | 8 + .../formal/backends/IsModelChecker.scala | 6 +- .../chiseltest/formal/backends/Maltese.scala | 46 +++-- .../backends/btor/Btor2ModelChecker.scala | 4 + .../formal/backends/smt/SMTModelChecker.scala | 158 ++++++++++++------ .../chiseltest/formal/examples/Counter.scala | 38 ++++- 6 files changed, 191 insertions(+), 69 deletions(-) diff --git a/src/main/scala/chiseltest/formal/Formal.scala b/src/main/scala/chiseltest/formal/Formal.scala index e636be8a8..d022e87dc 100644 --- a/src/main/scala/chiseltest/formal/Formal.scala +++ b/src/main/scala/chiseltest/formal/Formal.scala @@ -28,6 +28,14 @@ private[chiseltest] object FailedBoundedCheckException { } } +class FailedInductionCheckException(val message: String, val failAt: Int) extends Exception(message) +private[chiseltest] object FailedInductionCheckException { + def apply(module: String, failAt: Int): FailedInductionCheckException = { + val msg = s"[$module] found an assertion violation after $failAt steps!" + new FailedInductionCheckException(msg, failAt) + } +} + /** Adds the `verify` command for formal checks to a ChiselScalatestTester */ trait Formal { this: HasTestName => def verify[T <: Module](dutGen: => T, annos: AnnotationSeq, chiselAnnos: firrtl.AnnotationSeq = Seq()): Unit = { diff --git a/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala b/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala index 1b9723601..8d36e1306 100644 --- a/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala @@ -12,12 +12,16 @@ private[chiseltest] case class ModelCheckSuccess() extends ModelCheckResult { ov private[chiseltest] case class ModelCheckFail(witness: Witness) extends ModelCheckResult { override def isFail: Boolean = true } +private[chiseltest] case class ModelCheckFailInduction(witness: Witness) extends ModelCheckResult { + override def isFail: Boolean = true +} private[chiseltest] trait IsModelChecker { def name: String val prefix: String val fileExtension: String - def check(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult + def check(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult + def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult } private[chiseltest] case class Witness( diff --git a/src/main/scala/chiseltest/formal/backends/Maltese.scala b/src/main/scala/chiseltest/formal/backends/Maltese.scala index fe1ef882d..293d61ae1 100644 --- a/src/main/scala/chiseltest/formal/backends/Maltese.scala +++ b/src/main/scala/chiseltest/formal/backends/Maltese.scala @@ -4,7 +4,12 @@ package chiseltest.formal.backends import chiseltest.formal.backends.btor.BtormcModelChecker import chiseltest.formal.backends.smt._ -import chiseltest.formal.{DoNotModelUndef, DoNotOptimizeFormal, FailedBoundedCheckException} +import chiseltest.formal.{ + DoNotModelUndef, + DoNotOptimizeFormal, + FailedBoundedCheckException, + FailedInductionCheckException +} import firrtl2._ import firrtl2.annotations._ import firrtl2.stage._ @@ -12,6 +17,7 @@ import firrtl2.backends.experimental.smt.random._ import firrtl2.backends.experimental.smt._ import chiseltest.simulator._ import firrtl2.options.Dependency +import os.Path sealed trait FormalEngineAnnotation extends NoTargetAnnotation @@ -70,7 +76,8 @@ private[chiseltest] object Maltese { require(kMax > 0) require(resetLength >= 0) - val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => checker.check(sys, kMax = kMax + resetLength); + val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => + checker.checkInduction(sys, resetLength, kMax = kMax); check(circuit, annos, checkFn, resetLength); } @@ -97,17 +104,13 @@ private[chiseltest] object Maltese { assert(checkers.size == 1, "Parallel checking not supported atm!") checkFn(checkers.head, sysInfo.sys) match { case ModelCheckFail(witness) => - val writeVcd = annos.contains(WriteVcdAnnotation) - if (writeVcd) { - val sim = new TransitionSystemSimulator(sysInfo.sys) - sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.bmc.vcd").toString)) - val trace = witnessToTrace(sysInfo, witness) - val treadleState = prepTreadle(circuit, annos, modelUndef) - val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState) - Trace.replayOnSim(trace, treadleDut) - } + processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "bmc") val failSteps = witness.inputs.length - 1 - resetLength throw FailedBoundedCheckException(circuit.main, failSteps) + case ModelCheckFailInduction(witness) => + processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "induction") + val failSteps = witness.inputs.length - 1 + throw FailedInductionCheckException(circuit.main, failSteps) case ModelCheckSuccess() => // good! } } @@ -128,6 +131,27 @@ private[chiseltest] object Maltese { } } + // Produces a vcd file based on the witness is @annos contains WriteVcdAnnotation + private def processWitness( + circuit: ir.Circuit, + sysInfo: SysInfo, + annos: AnnotationSeq, + witness: Witness, + modelUndef: Boolean, + targetDir: Path, + vcdSuffix: String + ) = { + val writeVcd = annos.contains(WriteVcdAnnotation) + if (writeVcd) { + val sim = new TransitionSystemSimulator(sysInfo.sys) + sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.${vcdSuffix}.vcd").toString)) + val trace = witnessToTrace(sysInfo, witness) + val treadleState = prepTreadle(circuit, annos, modelUndef) + val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState) + Trace.replayOnSim(trace, treadleDut) + } + } + private val LoweringAnnos: AnnotationSeq = Seq( // we need to flatten the whole circuit RunFirrtlTransformAnnotation(Dependency(FlattenPass)), diff --git a/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala b/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala index 2548e0826..e1e40edc1 100644 --- a/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala @@ -10,6 +10,10 @@ class BtormcModelChecker(targetDir: os.Path) extends IsModelChecker { override val name: String = "btormc" override val prefix: String = "btormc" + override def checkInduction(sys: TransitionSystem, resetLenght: Int, kMax: Int = -1): ModelCheckResult = { + throw new RuntimeException(s"Induction unsupported for btormc"); + } + override def check(sys: TransitionSystem, kMax: Int): ModelCheckResult = { // serialize the system to btor2 val filename = sys.name + ".btor" diff --git a/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala b/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala index 0682a26a1..0565e1d46 100644 --- a/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala @@ -24,12 +24,77 @@ class SMTModelChecker( override val prefix: String = solver.name override val fileExtension: String = ".smt2" + override def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult = { + require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax") + // Check BMC first + check(sys, kMax + resetLength) match { + case ModelCheckFail(w) => return ModelCheckFail(w) + case _ => + } + + val (ctx, enc) = checkInit(sys) + + // TODO: remove hardcoding of "_resetActive" + val constraints = sys.signals.filter(s => s.lbl == IsConstraint && s.name != "_resetActive").map(_.name) + val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name) + + (0 to kMax).foreach { k => + // Assume all constraints hold for each k + constraints.foreach(c => ctx.assert(enc.getConstraint(c))) + // Assume All assertions up to k + assertions.foreach(c => ctx.assert(enc.getAssertion(c))) + // Advance + enc.unroll(ctx) + } + // Assume constraints one last time + constraints.foreach(c => ctx.assert(enc.getConstraint(c))) + + val modelResult = + checkAssertions(sys, ctx, enc, assertions, kMax).map(ModelCheckFailInduction(_)).getOrElse(ModelCheckSuccess()) + checkFini(ctx) + modelResult + } + override def check( sys: TransitionSystem, kMax: Int ): ModelCheckResult = { require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax") + val (ctx, enc) = checkInit(sys) + + val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name) + val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name) + + (0 to kMax).foreach { k => + if (printProgress) println(s"Step #$k") + + // assume all constraints hold in this step + constraints.foreach(c => ctx.assert(enc.getConstraint(c))) + + // make sure the constraints are not contradictory + if (options.checkConstraints) { + val res = ctx.check(produceModel = false) + assert(res.isSat, s"Found unsatisfiable constraints in cycle $k") + } + + checkAssertions(sys, ctx, enc, assertions, k) match { + case Some(w) => { + checkFini(ctx) + return ModelCheckFail(w) + } + case _ => {} + } + // advance + enc.unroll(ctx) + } + + checkFini(ctx) + ModelCheckSuccess() + } + + // Initialise solver context and transition system + private def checkInit(sys: TransitionSystem): (SolverContext, TransitionSystemSmtEncoding) = { val ctx = solver.createContext() // z3 only supports the non-standard as-const array syntax when the logic is set to ALL val logic = if (solver.name.contains("z3")) { "ALL" } @@ -49,71 +114,58 @@ class SMTModelChecker( enc.defineHeader(ctx) enc.init(ctx) - val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name) - val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name) + (ctx, enc) + } - (0 to kMax).foreach { k => - if (printProgress) println(s"Step #$k") + private def checkFini(ctx: SolverContext) = { + ctx.pop() + assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}") + ctx.close() + } - // assume all constraints hold in this step - constraints.foreach(c => ctx.assert(enc.getConstraint(c))) + // Returns Some(witness) if an assertion failed, otherwise returns None + private def checkAssertions( + sys: TransitionSystem, + ctx: SolverContext, + enc: TransitionSystemSmtEncoding, + assertions: List[String], + k: Int + ): Option[Witness] = { + if (options.checkBadStatesIndividually) { + // check each bad state individually + assertions.zipWithIndex.foreach { case (b, bi) => + if (printProgress) print(s"- b$bi? ") - // make sure the constraints are not contradictory - if (options.checkConstraints) { - val res = ctx.check(produceModel = false) - assert(res.isSat, s"Found unsatisfiable constraints in cycle $k") - } - - if (options.checkBadStatesIndividually) { - // check each bad state individually - assertions.zipWithIndex.foreach { case (b, bi) => - if (printProgress) print(s"- b$bi? ") - - ctx.push() - ctx.assert(BVNot(enc.getAssertion(b))) - val res = ctx.check(produceModel = false) - - // did we find an assignment for which the bad state is true? - if (res.isSat) { - if (printProgress) println("❌") - val w = getWitness(ctx, sys, enc, k, Seq(b)) - ctx.pop() - ctx.pop() - assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}") - ctx.close() - return ModelCheckFail(w) - } else { - if (printProgress) println("✅") - } - ctx.pop() - } - } else { - val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion))) ctx.push() - ctx.assert(anyBad) + ctx.assert(BVNot(enc.getAssertion(b))) val res = ctx.check(produceModel = false) - // did we find an assignment for which at least one bad state is true? + // did we find an assignment for which the bad state is true? if (res.isSat) { - val w = getWitness(ctx, sys, enc, k) - ctx.pop() + if (printProgress) println("❌") + val w = getWitness(ctx, sys, enc, k, Seq(b)) ctx.pop() - assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}") - ctx.close() - return ModelCheckFail(w) + return Some(w) + } else { + if (printProgress) println("✅") } ctx.pop() } - - // advance - enc.unroll(ctx) + } else { + val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion))) + ctx.push() + ctx.assert(anyBad) + val res = ctx.check(produceModel = false) + + // did we find an assignment for which at least one bad state is true? + if (res.isSat) { + val w = getWitness(ctx, sys, enc, k) + ctx.pop() + return Some(w) + } + ctx.pop() } - - // clean up - ctx.pop() - assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}") - ctx.close() - ModelCheckSuccess() + None } private def getWitness( diff --git a/src/test/scala/chiseltest/formal/examples/Counter.scala b/src/test/scala/chiseltest/formal/examples/Counter.scala index 048cdbb21..aa7061e96 100644 --- a/src/test/scala/chiseltest/formal/examples/Counter.scala +++ b/src/test/scala/chiseltest/formal/examples/Counter.scala @@ -7,12 +7,42 @@ import chiseltest.formal._ import org.scalatest.flatspec.AnyFlatSpec class CounterVerify extends AnyFlatSpec with ChiselScalatestTester with Formal with FormalBackendOption { - "Counter" should "succeed BMC" taggedAs FormalTag in { - verify(new Counter(65000, 60000), Seq(BoundedCheck(10), DefaultBackend)) + "Counter" should "pass BMC" taggedAs FormalTag in { + verify(new Counter(65000, 60000), Seq(BoundedCheck(3), DefaultBackend)) } - "Counter" should "Fail induction" taggedAs FormalTag in { - verify(new Counter(65000, 50), Seq(InductionCheck(10), DefaultBackend)) + "Counter" should "fail induction" taggedAs FormalTag in { + // btormc induction is unsupported + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + val e = intercept[FailedInductionCheckException] { + verify(new Counter(65000, 64999), Seq(InductionCheck(1), anno)) + } + assert(e.failAt == 1) + } + } + } + + "Counter" should "pass induction" taggedAs FormalTag in { + // btormc induction is unsupported + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => verify(new Counter(65000, 65000), Seq(InductionCheck(1), anno)) + } + } + + "Counter" should "Fail BMC step of induction" taggedAs FormalTag in { + // btormc induction is unsupported + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + val e = intercept[FailedBoundedCheckException] { + verify(new Counter(65000, 4), Seq(InductionCheck(8), anno)) + } + assert(e.failAt == 5) + } + } } } From c6ee282e5691b3a6cf7c2c020fd9db7dfe0c9bfd Mon Sep 17 00:00:00 2001 From: Liam Gallagher Date: Sun, 25 Feb 2024 15:10:37 +0000 Subject: [PATCH 3/4] K-Induction: Initialize transition system at arbitrary step In doing so there is no need to filter out the _resetActive constraint. --- .../formal/backends/IsModelChecker.scala | 2 +- .../chiseltest/formal/backends/Maltese.scala | 3 ++- .../backends/btor/Btor2ModelChecker.scala | 2 +- .../backends/smt/CompactSmtEncoding.scala | 6 ++++-- .../formal/backends/smt/SMTModelChecker.scala | 19 ++++++++++--------- .../backends/smt/UnrollSmtEncoding.scala | 8 ++++++-- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala b/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala index 8d36e1306..2dff2cf02 100644 --- a/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/IsModelChecker.scala @@ -20,7 +20,7 @@ private[chiseltest] trait IsModelChecker { def name: String val prefix: String val fileExtension: String - def check(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult + def checkBounded(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult } diff --git a/src/main/scala/chiseltest/formal/backends/Maltese.scala b/src/main/scala/chiseltest/formal/backends/Maltese.scala index 293d61ae1..89340f7e3 100644 --- a/src/main/scala/chiseltest/formal/backends/Maltese.scala +++ b/src/main/scala/chiseltest/formal/backends/Maltese.scala @@ -68,7 +68,8 @@ private[chiseltest] object Maltese { require(kMax > 0) require(resetLength >= 0) - val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => checker.check(sys, kMax = kMax + resetLength); + val checkFn = (checker: IsModelChecker, sys: TransitionSystem) => + checker.checkBounded(sys, kMax = kMax + resetLength); check(circuit, annos, checkFn, resetLength); } diff --git a/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala b/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala index e1e40edc1..990223935 100644 --- a/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/btor/Btor2ModelChecker.scala @@ -14,7 +14,7 @@ class BtormcModelChecker(targetDir: os.Path) extends IsModelChecker { throw new RuntimeException(s"Induction unsupported for btormc"); } - override def check(sys: TransitionSystem, kMax: Int): ModelCheckResult = { + override def checkBounded(sys: TransitionSystem, kMax: Int): ModelCheckResult = { // serialize the system to btor2 val filename = sys.name + ".btor" // btromc isn't happy if we include output nodes, so we skip them during serialization diff --git a/src/main/scala/chiseltest/formal/backends/smt/CompactSmtEncoding.scala b/src/main/scala/chiseltest/formal/backends/smt/CompactSmtEncoding.scala index 590c290ff..34bdba4d8 100644 --- a/src/main/scala/chiseltest/formal/backends/smt/CompactSmtEncoding.scala +++ b/src/main/scala/chiseltest/formal/backends/smt/CompactSmtEncoding.scala @@ -24,10 +24,12 @@ class CompactSmtEncoding(sys: TransitionSystem) extends TransitionSystemSmtEncod s } - def init(ctx: SolverContext): Unit = { + def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit = { assert(states.isEmpty) val s0 = appendState(ctx) - ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1)) + if (!isArbitraryStep) { + ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1)) + } } def unroll(ctx: SolverContext): Unit = { diff --git a/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala b/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala index 0565e1d46..89bbb98c0 100644 --- a/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala +++ b/src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala @@ -27,15 +27,16 @@ class SMTModelChecker( override def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult = { require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax") // Check BMC first - check(sys, kMax + resetLength) match { + checkBounded(sys, kMax + resetLength) match { case ModelCheckFail(w) => return ModelCheckFail(w) case _ => } val (ctx, enc) = checkInit(sys) + // Initialise transition system at an arbitrary step + enc.init(ctx, true) - // TODO: remove hardcoding of "_resetActive" - val constraints = sys.signals.filter(s => s.lbl == IsConstraint && s.name != "_resetActive").map(_.name) + val constraints = sys.signals.filter(s => s.lbl == IsConstraint).map(_.name) val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name) (0 to kMax).foreach { k => @@ -55,13 +56,15 @@ class SMTModelChecker( modelResult } - override def check( + override def checkBounded( sys: TransitionSystem, kMax: Int ): ModelCheckResult = { require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax") val (ctx, enc) = checkInit(sys) + // Initialise transition system at reset + enc.init(ctx, false) val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name) val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name) @@ -112,7 +115,6 @@ class SMTModelChecker( new UnrollSmtEncoding(sys) } enc.defineHeader(ctx) - enc.init(ctx) (ctx, enc) } @@ -123,7 +125,6 @@ class SMTModelChecker( ctx.close() } - // Returns Some(witness) if an assertion failed, otherwise returns None private def checkAssertions( sys: TransitionSystem, ctx: SolverContext, @@ -203,10 +204,10 @@ class SMTModelChecker( trait TransitionSystemSmtEncoding { def defineHeader(ctx: SolverContext): Unit - def init(ctx: SolverContext): Unit + def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit def unroll(ctx: SolverContext): Unit def getConstraint(name: String): BVExpr def getAssertion(name: String): BVExpr - def getSignalAt(sym: BVSymbol, k: Int): BVExpr - def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr + def getSignalAt(sym: BVSymbol, k: Int): BVExpr + def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr } diff --git a/src/main/scala/chiseltest/formal/backends/smt/UnrollSmtEncoding.scala b/src/main/scala/chiseltest/formal/backends/smt/UnrollSmtEncoding.scala index 4d0c029c7..252366ae2 100644 --- a/src/main/scala/chiseltest/formal/backends/smt/UnrollSmtEncoding.scala +++ b/src/main/scala/chiseltest/formal/backends/smt/UnrollSmtEncoding.scala @@ -15,14 +15,18 @@ class UnrollSmtEncoding(sys: TransitionSystem) extends TransitionSystemSmtEncodi // nothing to do in this encoding } - override def init(ctx: SolverContext): Unit = { + override def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit = { require(currentStep == -1) currentStep = 0 // declare initial states sys.states.foreach { state => state.init match { case Some(value) => - ctx.runCommand(DefineFunction(at(state.name, 0), Seq(), signalsAndStatesAt(value, 0))) + if (isArbitraryStep) { + ctx.runCommand(DeclareFunction(at(state.sym, 0), Seq())) + } else { + ctx.runCommand(DefineFunction(at(state.name, 0), Seq(), signalsAndStatesAt(value, 0))) + } case None => ctx.runCommand(DeclareFunction(at(state.sym, 0), Seq())) } From 6af02f4bcbdc98985a4200d157221af1ba44972d Mon Sep 17 00:00:00 2001 From: Liam Gallagher Date: Sun, 25 Feb 2024 19:46:01 +0000 Subject: [PATCH 4/4] K-Induction: Add zipcpu k-induction tests --- .../formal/examples/ZipCpuQuizzes.scala | 108 +++++++++++++++++- 1 file changed, 103 insertions(+), 5 deletions(-) diff --git a/src/test/scala/chiseltest/formal/examples/ZipCpuQuizzes.scala b/src/test/scala/chiseltest/formal/examples/ZipCpuQuizzes.scala index ffe772ea5..babf45345 100644 --- a/src/test/scala/chiseltest/formal/examples/ZipCpuQuizzes.scala +++ b/src/test/scala/chiseltest/formal/examples/ZipCpuQuizzes.scala @@ -6,6 +6,7 @@ import chisel3._ import chiseltest._ import chiseltest.formal._ import org.scalatest.flatspec.AnyFlatSpec +import chisel3.util._ /** Chisel versions of the quizzes from ZipCPU: http://zipcpu.com/quiz/quizzes.html */ @@ -30,6 +31,18 @@ class ZipCpuQuizzes extends AnyFlatSpec with ChiselScalatestTester with Formal w verify(new Quiz2(true), Seq(BoundedCheck(5), DefaultBackend)) } + "Quiz3" should "fail induction check" taggedAs FormalTag in { + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + val e = intercept[FailedInductionCheckException] { + verify(new Quiz3(), Seq(InductionCheck(1), anno)) + } + assert(e.failAt == 1) + } + } + } + "Quiz4" should "fail when using a RegNext to delay the signal" taggedAs FormalTag in { val e = intercept[FailedBoundedCheckException] { verify(new Quiz4(0), Seq(BoundedCheck(5), DefaultBackend)) @@ -52,7 +65,25 @@ class ZipCpuQuizzes extends AnyFlatSpec with ChiselScalatestTester with Formal w "Quiz7" should "pass when using the Chisel past function" taggedAs FormalTag in { verify(new Quiz7(true), Seq(BoundedCheck(5), DefaultBackend)) } - + "Quiz11" should "fail induction" taggedAs FormalTag in { + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + val e = intercept[FailedInductionCheckException] { + verify(new Quiz11(false), Seq(InductionCheck(3), anno)) + } + assert(e.failAt == 3) + } + } + } + "Quiz11" should "pass induction" taggedAs FormalTag in { + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + verify(new Quiz11(true), Seq(InductionCheck(4), anno)) + } + } + } "Quiz13" should "pass when x is 1 wide" taggedAs FormalTag in { verify(new Quiz13(1), Seq(BoundedCheck(4), DefaultBackend)) } @@ -75,6 +106,24 @@ class ZipCpuQuizzes extends AnyFlatSpec with ChiselScalatestTester with Formal w "Quiz15" should "pass when using WriteFirst" taggedAs FormalTag in { verify(new Quiz15(WriteFirst), Seq(BoundedCheck(5), DefaultBackend)) } + "Quiz17" should "fail induction" taggedAs FormalTag in { + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + val e = intercept[FailedInductionCheckException] { + verify(new Quiz17(22, false), Seq(InductionCheck(4), anno)) + } + } + } + } + "Quiz17" should "pass induction" taggedAs FormalTag in { + DefaultBackend match { + case BtormcEngineAnnotation => {} + case anno => { + verify(new Quiz17(22, true), Seq(InductionCheck(4), anno)) + } + } + } } /** http://zipcpu.com/quiz/2019/08/03/quiz01.html */ @@ -105,7 +154,16 @@ class Quiz2(withInit: Boolean) extends Module { } /** http://zipcpu.com/quiz/2019/08/19/quiz03.html */ -// TODO: re-visit once we support k-induction +class Quiz3() extends Module { + val counter = RegInit(0.U(16.W)) + + when(counter === 22.U) { + counter := 0.U + }.otherwise { + counter := counter + 1.U + } + assert(counter =/= 500.U) +} /** http://zipcpu.com/quiz/2019/08/24/quiz04.html */ class Quiz4(style: Int) extends Module { @@ -163,7 +221,7 @@ class Quiz7(fixed: Boolean) extends Module { } /** http://zipcpu.com/quiz/2019/11/29/quiz08.html */ -// TODO: consider implementing once we support k-induction +// A very similar example is in formal/examples/Counter.scala /** http://zipcpu.com/quiz/2019/12/12/quiz09.html */ // this one is hard to translate to Chisel since we do not have blocking assignment @@ -183,7 +241,24 @@ class Quiz10 extends Module { } /** http://zipcpu.com/quiz/2020/01/23/quiz11.html */ -// TODO: consider implementing once we support k-induction +class Quiz11(shouldPass: Boolean) extends Module { + val i_ce = IO(Input(Bool())) + val i_bit = IO(Input(Bool())) + + val sa = RegInit(0.U(4.W)) + val sb = RegInit(0.U(4.W)) + + when(i_ce) { + sa := Cat(sa(2, 0), i_bit) + sb := Cat(i_bit, sb(3, 1)) + } + + if (shouldPass) { + assert(sa === Reverse(sb)) + } else { + assert(sa(3) === sb(0)) + } +} /** http://zipcpu.com/quiz/2020/09/14/quiz13.html */ class Quiz13(xWidth: Int) extends Module { @@ -227,4 +302,27 @@ class Quiz15(readUnderWrite: ReadUnderWrite) extends Module { // TODO: consider implementing once we have good async reset support /** https://zipcpu.com/quiz/2021/08/05/quiz17.html */ -// TODO: consider implementing once we support k-induction +class Quiz17(maxVal: Int, makePass: Boolean) extends Module { + val i_start = IO(Input(Bool())) + + val counter = RegInit(0.U(16.W)) + val zero_counter = RegInit(true.B) + + when(counter > 0.U) { + counter := counter - 1.U + when(counter === 1.U) { + zero_counter := true.B + } + }.elsewhen(i_start) { + counter := maxVal.U + zero_counter := false.B + } + + when(past(counter === 1.U)) { + assert(rose(zero_counter)) + } + + if (makePass) { + assert(zero_counter === (counter === 0.U)) + } +}