Skip to content
This repository has been archived by the owner on Aug 19, 2024. It is now read-only.

Add k-induction for SmtModelCheckers #713

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/main/scala/chiseltest/formal/Formal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import firrtl2.transforms.formal.DontAssertSubmoduleAssumptionsAnnotation

sealed trait FormalOp extends NoTargetAnnotation
case class BoundedCheck(kMax: Int = -1) extends FormalOp
case class InductionCheck(kMax: Int = -1) extends FormalOp

/** Specifies how many cycles the circuit should be reset for. */
case class ResetOption(cycles: Int = 1) extends NoTargetAnnotation {
Expand All @@ -27,6 +28,14 @@ private[chiseltest] object FailedBoundedCheckException {
}
}

class FailedInductionCheckException(val message: String, val failAt: Int) extends Exception(message)
private[chiseltest] object FailedInductionCheckException {
def apply(module: String, failAt: Int): FailedInductionCheckException = {
val msg = s"[$module] found an assertion violation after $failAt steps!"
new FailedInductionCheckException(msg, failAt)
}
}

/** Adds the `verify` command for formal checks to a ChiselScalatestTester */
trait Formal { this: HasTestName =>
def verify[T <: Module](dutGen: => T, annos: AnnotationSeq, chiselAnnos: firrtl.AnnotationSeq = Seq()): Unit = {
Expand Down Expand Up @@ -79,5 +88,7 @@ private object Formal {
def executeOp(state: CircuitState, resetLength: Int, op: FormalOp): Unit = op match {
case BoundedCheck(kMax) =>
backends.Maltese.bmc(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength)
case InductionCheck(kMax) =>
backends.Maltese.induction(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ private[chiseltest] case class ModelCheckSuccess() extends ModelCheckResult { ov
private[chiseltest] case class ModelCheckFail(witness: Witness) extends ModelCheckResult {
override def isFail: Boolean = true
}
private[chiseltest] case class ModelCheckFailInduction(witness: Witness) extends ModelCheckResult {
override def isFail: Boolean = true
}

private[chiseltest] trait IsModelChecker {
def name: String
val prefix: String
val fileExtension: String
def check(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult
def checkBounded(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult
def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult
}

private[chiseltest] case class Witness(
Expand Down
65 changes: 54 additions & 11 deletions src/main/scala/chiseltest/formal/backends/Maltese.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@ package chiseltest.formal.backends

import chiseltest.formal.backends.btor.BtormcModelChecker
import chiseltest.formal.backends.smt._
import chiseltest.formal.{DoNotModelUndef, DoNotOptimizeFormal, FailedBoundedCheckException}
import chiseltest.formal.{
DoNotModelUndef,
DoNotOptimizeFormal,
FailedBoundedCheckException,
FailedInductionCheckException
}
import firrtl2._
import firrtl2.annotations._
import firrtl2.stage._
import firrtl2.backends.experimental.smt.random._
import firrtl2.backends.experimental.smt._
import chiseltest.simulator._
import firrtl2.options.Dependency
import os.Path

sealed trait FormalEngineAnnotation extends NoTargetAnnotation

Expand Down Expand Up @@ -62,6 +68,26 @@ private[chiseltest] object Maltese {
require(kMax > 0)
require(resetLength >= 0)

val checkFn = (checker: IsModelChecker, sys: TransitionSystem) =>
checker.checkBounded(sys, kMax = kMax + resetLength);
check(circuit, annos, checkFn, resetLength);
}

def induction(circuit: ir.Circuit, annos: AnnotationSeq, kMax: Int, resetLength: Int = 0): Unit = {
require(kMax > 0)
require(resetLength >= 0)

val checkFn = (checker: IsModelChecker, sys: TransitionSystem) =>
checker.checkInduction(sys, resetLength, kMax = kMax);
check(circuit, annos, checkFn, resetLength);
}

def check(
circuit: ir.Circuit,
annos: AnnotationSeq,
checkFn: (IsModelChecker, TransitionSystem) => ModelCheckResult,
resetLength: Int
): Unit = {
// convert to transition system
val targetDir = Compiler.requireTargetDir(annos)
val modelUndef = !annos.contains(DoNotModelUndef)
Expand All @@ -77,19 +103,15 @@ private[chiseltest] object Maltese {
// perform check
val checkers = makeCheckers(annos, targetDir)
assert(checkers.size == 1, "Parallel checking not supported atm!")
checkers.head.check(sysInfo.sys, kMax = kMax + resetLength) match {
checkFn(checkers.head, sysInfo.sys) match {
case ModelCheckFail(witness) =>
val writeVcd = annos.contains(WriteVcdAnnotation)
if (writeVcd) {
val sim = new TransitionSystemSimulator(sysInfo.sys)
sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.bmc.vcd").toString))
val trace = witnessToTrace(sysInfo, witness)
val treadleState = prepTreadle(circuit, annos, modelUndef)
val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState)
Trace.replayOnSim(trace, treadleDut)
}
processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "bmc")
val failSteps = witness.inputs.length - 1 - resetLength
throw FailedBoundedCheckException(circuit.main, failSteps)
case ModelCheckFailInduction(witness) =>
processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "induction")
val failSteps = witness.inputs.length - 1
throw FailedInductionCheckException(circuit.main, failSteps)
case ModelCheckSuccess() => // good!
}
}
Expand All @@ -110,6 +132,27 @@ private[chiseltest] object Maltese {
}
}

// Produces a vcd file based on the witness is @annos contains WriteVcdAnnotation
private def processWitness(
circuit: ir.Circuit,
sysInfo: SysInfo,
annos: AnnotationSeq,
witness: Witness,
modelUndef: Boolean,
targetDir: Path,
vcdSuffix: String
) = {
val writeVcd = annos.contains(WriteVcdAnnotation)
if (writeVcd) {
val sim = new TransitionSystemSimulator(sysInfo.sys)
sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.${vcdSuffix}.vcd").toString))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the simulator is run, the _resetActive constraint is violated in the induction case as we start from an arbitrary step. The message ERROR: Constraint #_resetActive was violated! will be printed every time an induction check fails. I suppose this ok, but I'm not sure we can prevent it from noticing the constraints are violated unless we remove the _resetActive constraint. Maybe it could be classed as something entirely different.

val trace = witnessToTrace(sysInfo, witness)
val treadleState = prepTreadle(circuit, annos, modelUndef)
val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState)
Trace.replayOnSim(trace, treadleDut)
}
}

private val LoweringAnnos: AnnotationSeq = Seq(
// we need to flatten the whole circuit
RunFirrtlTransformAnnotation(Dependency(FlattenPass)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ class BtormcModelChecker(targetDir: os.Path) extends IsModelChecker {
override val name: String = "btormc"
override val prefix: String = "btormc"

override def check(sys: TransitionSystem, kMax: Int): ModelCheckResult = {
override def checkInduction(sys: TransitionSystem, resetLenght: Int, kMax: Int = -1): ModelCheckResult = {
throw new RuntimeException(s"Induction unsupported for btormc");
}

override def checkBounded(sys: TransitionSystem, kMax: Int): ModelCheckResult = {
// serialize the system to btor2
val filename = sys.name + ".btor"
// btromc isn't happy if we include output nodes, so we skip them during serialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ class CompactSmtEncoding(sys: TransitionSystem) extends TransitionSystemSmtEncod
s
}

def init(ctx: SolverContext): Unit = {
def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit = {
assert(states.isEmpty)
val s0 = appendState(ctx)
ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1))
if (!isArbitraryStep) {
ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1))
}
}

def unroll(ctx: SolverContext): Unit = {
Expand Down
169 changes: 111 additions & 58 deletions src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,80 @@ class SMTModelChecker(
override val prefix: String = solver.name
override val fileExtension: String = ".smt2"

override def check(
override def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult = {
require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax")
// Check BMC first
checkBounded(sys, kMax + resetLength) match {
case ModelCheckFail(w) => return ModelCheckFail(w)
case _ =>
}

val (ctx, enc) = checkInit(sys)
// Initialise transition system at an arbitrary step
enc.init(ctx, true)

val constraints = sys.signals.filter(s => s.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)

(0 to kMax).foreach { k =>
// Assume all constraints hold for each k
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))
// Assume All assertions up to k
assertions.foreach(c => ctx.assert(enc.getAssertion(c)))
// Advance
enc.unroll(ctx)
}
// Assume constraints one last time
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))

val modelResult =
checkAssertions(sys, ctx, enc, assertions, kMax).map(ModelCheckFailInduction(_)).getOrElse(ModelCheckSuccess())
checkFini(ctx)
modelResult
}

override def checkBounded(
sys: TransitionSystem,
kMax: Int
): ModelCheckResult = {
require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax")

val (ctx, enc) = checkInit(sys)
// Initialise transition system at reset
enc.init(ctx, false)

val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)

(0 to kMax).foreach { k =>
if (printProgress) println(s"Step #$k")

// assume all constraints hold in this step
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))

// make sure the constraints are not contradictory
if (options.checkConstraints) {
val res = ctx.check(produceModel = false)
assert(res.isSat, s"Found unsatisfiable constraints in cycle $k")
}

checkAssertions(sys, ctx, enc, assertions, k) match {
case Some(w) => {
checkFini(ctx)
return ModelCheckFail(w)
}
case _ => {}
}
// advance
enc.unroll(ctx)
}

checkFini(ctx)
ModelCheckSuccess()
}

// Initialise solver context and transition system
private def checkInit(sys: TransitionSystem): (SolverContext, TransitionSystemSmtEncoding) = {
val ctx = solver.createContext()
// z3 only supports the non-standard as-const array syntax when the logic is set to ALL
val logic = if (solver.name.contains("z3")) { "ALL" }
Expand All @@ -47,73 +115,58 @@ class SMTModelChecker(
new UnrollSmtEncoding(sys)
}
enc.defineHeader(ctx)
enc.init(ctx)

val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)
(ctx, enc)
}

(0 to kMax).foreach { k =>
if (printProgress) println(s"Step #$k")
private def checkFini(ctx: SolverContext) = {
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
}

// assume all constraints hold in this step
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))
private def checkAssertions(
sys: TransitionSystem,
ctx: SolverContext,
enc: TransitionSystemSmtEncoding,
assertions: List[String],
k: Int
): Option[Witness] = {
if (options.checkBadStatesIndividually) {
// check each bad state individually
assertions.zipWithIndex.foreach { case (b, bi) =>
if (printProgress) print(s"- b$bi? ")

// make sure the constraints are not contradictory
if (options.checkConstraints) {
val res = ctx.check(produceModel = false)
assert(res.isSat, s"Found unsatisfiable constraints in cycle $k")
}

if (options.checkBadStatesIndividually) {
// check each bad state individually
assertions.zipWithIndex.foreach { case (b, bi) =>
if (printProgress) print(s"- b$bi? ")

ctx.push()
ctx.assert(BVNot(enc.getAssertion(b)))
val res = ctx.check(produceModel = false)

// did we find an assignment for which the bad state is true?
if (res.isSat) {
if (printProgress) println("❌")
val w = getWitness(ctx, sys, enc, k, Seq(b))
ctx.pop()
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
return ModelCheckFail(w)
} else {
if (printProgress) println("✅")
}
ctx.pop()
}
} else {
val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion)))
ctx.push()
ctx.assert(anyBad)
ctx.assert(BVNot(enc.getAssertion(b)))
val res = ctx.check(produceModel = false)

// did we find an assignment for which at least one bad state is true?
// did we find an assignment for which the bad state is true?
if (res.isSat) {
val w = getWitness(ctx, sys, enc, k)
ctx.pop()
if (printProgress) println("❌")
val w = getWitness(ctx, sys, enc, k, Seq(b))
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
return ModelCheckFail(w)
return Some(w)
} else {
if (printProgress) println("✅")
}
ctx.pop()
}

// advance
enc.unroll(ctx)
} else {
val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion)))
ctx.push()
ctx.assert(anyBad)
val res = ctx.check(produceModel = false)

// did we find an assignment for which at least one bad state is true?
if (res.isSat) {
val w = getWitness(ctx, sys, enc, k)
ctx.pop()
return Some(w)
}
ctx.pop()
}

// clean up
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
ModelCheckSuccess()
None
}

private def getWitness(
Expand Down Expand Up @@ -151,10 +204,10 @@ class SMTModelChecker(

trait TransitionSystemSmtEncoding {
def defineHeader(ctx: SolverContext): Unit
def init(ctx: SolverContext): Unit
def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit
def unroll(ctx: SolverContext): Unit
def getConstraint(name: String): BVExpr
def getAssertion(name: String): BVExpr
def getSignalAt(sym: BVSymbol, k: Int): BVExpr
def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr
def getSignalAt(sym: BVSymbol, k: Int): BVExpr
def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr
}
Loading
Loading