diff --git a/.gitignore b/.gitignore index 484972ecf..f5b639dc5 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,5 @@ out-classes project/.bloop/ project/metals.sbt project/project/ + +*.semanticdb \ No newline at end of file diff --git a/build.sbt b/build.sbt index e7c9f6dfa..9d3d9dc3f 100644 --- a/build.sbt +++ b/build.sbt @@ -33,13 +33,21 @@ Compile / unmanagedJars += { resolvers ++= Seq( "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", "Sonatype OSS Releases" at "https://oss.sonatype.org/content/repositories/releases", - ("uuverifiers" at "http://logicrunch.research.it.uu.se/maven").withAllowInsecureProtocol(true) + "uuverifiers" at "https://eldarica.org/maven" ) libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % "3.2.9" % "test;it", "org.apache.commons" % "commons-lang3" % "3.4", - ("org.scala-lang.modules" %% "scala-parser-combinators" % "1.1.2").cross(CrossVersion.for3Use2_13) + ("uuverifiers" %% "eldarica" % "nightly-SNAPSHOT").cross(CrossVersion.for3Use2_13), + ("uuverifiers" %% "princess" % "nightly-SNAPSHOT").cross(CrossVersion.for3Use2_13), + "org.scala-lang.modules" %% "scala-parser-combinators" % "2.3.0" +) + +excludeDependencies ++= Seq( + "org.scala-lang.modules" % "scala-parser-combinators_2.13", + "org.scala-lang.modules" % "scala-xml_2.13", + "org.scalactic" % "scalactic_2.13", ) lazy val nTestParallelism = { @@ -60,9 +68,6 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi // lazy val smtlib = RootProject(file("../scala-smtlib")) // If you have a local copy of Scala-SMTLIB and would like to do some changes lazy val smtlib = ghProject("https://github.com/epfl-lara/scala-smtlib.git", "51a44878858b427f1a4e5a5eb01d8f796898d812") -// lazy val princess = RootProject(file("../princess")) // If you have a local copy of Princess and would like to do some changes -lazy val princess = ghProject("https://github.com/uuverifiers/princess.git", "93cbff11d7b02903e532c7b64207bc12f19b79c7") - lazy val scriptName = settingKey[String]("Name of the generated 'inox' script") scriptName := "inox" @@ -153,7 +158,7 @@ lazy val root = (project in file(".")) )) : _*) .settings(compile := ((Compile / compile) dependsOn script).value) .settings(Compile / packageDoc / mappings := Seq()) - .dependsOn(smtlib, princess) + .dependsOn(smtlib) Global / concurrentRestrictions := Seq( Tags.limit(Tags.Test, nTestParallelism) diff --git a/src/main/scala/inox/ast/TypeOps.scala b/src/main/scala/inox/ast/TypeOps.scala index f745b7674..4d3fcf140 100644 --- a/src/main/scala/inox/ast/TypeOps.scala +++ b/src/main/scala/inox/ast/TypeOps.scala @@ -42,7 +42,8 @@ trait TypeOps { def greatestLowerBound(tps: Seq[Type]): Type = if (tps.isEmpty) Untyped else tps.reduceLeft(greatestLowerBound) - def isSubtypeOf(t1: Type, t2: Type): Boolean = t1.getType == t2.getType + def isSubtypeOf(t1: Type, t2: Type): Boolean = + t1.getType == Untyped || t2.getType == Untyped || t1.getType == t2.getType private type Instantiation = Map[TypeParameter, Type] def instantiation(from: Type, to: Type): Option[Instantiation] = { diff --git a/src/main/scala/inox/solvers/SolverFactory.scala b/src/main/scala/inox/solvers/SolverFactory.scala index 80d34be80..fb48ba198 100644 --- a/src/main/scala/inox/solvers/SolverFactory.scala +++ b/src/main/scala/inox/solvers/SolverFactory.scala @@ -5,6 +5,8 @@ package solvers import transformers._ import inox.solvers.evaluating.EvaluatingSolver +import inox.solvers.invariant.AbstractInvariantSolver +import inox.solvers.invariant.InvariantSolver trait SolverFactory { val program: Program @@ -80,17 +82,21 @@ object SolverFactory { "smt-z3-opt" -> "Z3 optimizer through SMT-LIB", "smt-z3:" -> "Z3 through SMT-LIB with custom executable name", "princess" -> "Princess with inox unrolling", - "eval" -> "Internal evaluator to discharge ground assertions" + "eval" -> "Internal evaluator to discharge ground assertions", + "inv-z3" -> "Horn solver using Z3 / Spacer", + "inv-eld" -> "Horn solver using Eldarica" ) private val fallbacks = Map( "nativez3" -> (() => hasNativeZ3, Seq("smt-z3", "smt-cvc4", "smt-cvc5", "princess"), "Z3 native interface"), "nativez3-opt" -> (() => hasNativeZ3, Seq("smt-z3-opt"), "Z3 native interface"), "unrollz3" -> (() => hasNativeZ3, Seq("smt-z3", "smt-cvc4", "smt-cvc5", "princess"), "Z3 native interface"), + "inv-z3" -> (() => hasZ3, Seq("smt-z3", "smt-cvc4", "smt-cvc5", "princess"), "Z3 native interface"), "smt-cvc4" -> (() => hasCVC4, Seq("nativez3", "smt-z3", "princess"), "'cvc4' binary"), "smt-cvc5" -> (() => hasCVC5, Seq("nativez3", "smt-z3", "princess"), "'cvc5' binary"), "smt-z3" -> (() => hasZ3, Seq("nativez3", "smt-cvc4", "smt-cvc5", "princess"), "'z3' binary"), "smt-z3-opt" -> (() => hasZ3, Seq("nativez3-opt"), "'z3' binary"), + "inv-eld" -> (() => true, Seq(), "Eldarica solver"), "princess" -> (() => true, Seq(), "Princess solver"), "eval" -> (() => true, Seq(), "Internal evaluator") ) @@ -257,6 +263,100 @@ object SolverFactory { () => new SMTZ3OptImpl(p) }) + case "inv-z3" => create(p)(finalName, { + val emptyEnc = ProgramEncoder.empty(enc.targetProgram) + val chooses = ChooseEncoder(enc.targetProgram)(emptyEnc) + class SMTZ3Impl(override val program: enc.targetProgram.type) + extends AbstractInvariantSolver (program, ctx)(program)(emptyEnc)(emptyEnc)(using program.getSemantics, emptyEnc.targetProgram.getSemantics) + with InvariantSolver + with TimeoutSolver + with tip.TipDebugger { + + override val name = "inv-z3" + + class Underlying(override val program: targetProgram.type) + extends smtlib.SMTLIBSolver(program, context) + with smtlib.Z3Solver { + override def targetName = "z3" + import _root_.smtlib.trees.Terms + import _root_.smtlib.trees.CommandsResponses._ + import _root_.smtlib.trees.Commands._ + import _root_.smtlib.Interpreter + import _root_.smtlib.printer.Printer + import _root_.smtlib.printer.RecursivePrinter + import java.io.BufferedReader + import _root_.smtlib.interpreters.ProcessInterpreter + import _root_.smtlib.parser.Parser + import _root_.smtlib.extensions.tip.Lexer + + class HornZ3Interpreter(executable: String, + args: Array[String], + printer: Printer = RecursivePrinter, + parserCtor: BufferedReader => Parser = out => new Parser(new Lexer(out))) + extends ProcessInterpreter (executable, args, printer, parserCtor): + printer.printCommand(SetOption(PrintSuccess(true)), in) + in.write("\n") + in.flush() + parser.parseGenResponse + in.write("(set-logic HORN)\n") + in.flush() + parser.parseGenResponse + + override def eval(cmd: Terms.SExpr): Terms.SExpr = + super.eval(cmd) + + override protected val interpreter = { + val opts = interpreterOpts + // reporter.debug("Invoking solver "+targetName+" with "+opts.mkString(" ")) + new HornZ3Interpreter(targetName, opts.toArray, parserCtor = out => new Z3Parser(new Lexer(out))) + } + } + + override protected val underlyingHorn = Underlying(targetProgram) + + // encoder is from TipDebugger and enc from AbstractUnrollingSolver + override protected val encoder = emptyEnc + } + + class EncodedImpl(program: p.type, enc: transformers.ProgramTransformer { + val sourceProgram: program.type + val targetProgram: Program { val trees: inox.trees.type } + }, underlying: Solver {val program: enc.targetProgram.type}) + extends EncodingSolver(program, enc, underlying) with TimeoutSolver + + () => new EncodedImpl(p, enc, SMTZ3Impl(enc.targetProgram)) + }) + + case "inv-eld" => create(p)(finalName, { + val emptyEnc = ProgramEncoder.empty(enc.targetProgram) + val chooses = ChooseEncoder(enc.targetProgram)(emptyEnc) + class SMTEldaricaImpl(override val program: enc.targetProgram.type) + extends AbstractInvariantSolver (program, ctx)(program)(emptyEnc)(emptyEnc)(using program.getSemantics, emptyEnc.targetProgram.getSemantics) + with InvariantSolver + with TimeoutSolver + with tip.TipDebugger { + + override val name = "inv-eld" + + class Underlying(override val program: targetProgram.type) + extends smtlib.SMTLIBSolver(program, context) + with smtlib.EldaricaSolver + + override protected val underlyingHorn = Underlying(targetProgram) + + // encoder is from TipDebugger and enc from AbstractUnrollingSolver + override protected val encoder = emptyEnc + } + + class EncodedImpl(program: p.type, enc: transformers.ProgramTransformer { + val sourceProgram: program.type + val targetProgram: Program { val trees: inox.trees.type } + }, underlying: Solver {val program: enc.targetProgram.type}) + extends EncodingSolver(program, enc, underlying) with TimeoutSolver + + () => new EncodedImpl(p, enc, SMTEldaricaImpl(enc.targetProgram)) + }) + case _ if finalName == "smt-z3" || finalName.startsWith("smt-z3:") => create(p)(finalName, { class SMTZ3Impl(override val program: p.type) (override val enc: transformers.ProgramTransformer { diff --git a/src/main/scala/inox/solvers/invariant/InvariantSolver.scala b/src/main/scala/inox/solvers/invariant/InvariantSolver.scala new file mode 100644 index 000000000..5674d9937 --- /dev/null +++ b/src/main/scala/inox/solvers/invariant/InvariantSolver.scala @@ -0,0 +1,859 @@ +package inox +package solvers +package invariant + +import inox.solvers.SolverResponses.CheckConfiguration + +import inox.solvers.SolverResponses.Configuration +import scala.collection.mutable.ListBuffer +import inox.utils.IncrementalSeq +import inox.solvers.smtlib.SMTLIBSolver +import scala.collection.immutable.LazyList.cons +import inox.utils.IncrementalBijection + +trait InvariantSolver extends Solver: + import program.trees.* + +object AbstractInvariantSolver: + +end AbstractInvariantSolver + +abstract class AbstractInvariantSolver(override val program: Program, + override val context: Context) + // Alias for `program`, as there are some places where `program` is shadowed. + (val prog: program.type) + (val enc: transformers.ProgramTransformer {val sourceProgram: program.type}) + (val programEncoder: transformers.ProgramTransformer { + val sourceProgram: program.type + val targetProgram: Program { val trees: enc.targetProgram.trees.type } + }) + (using val semantics: program.Semantics, + val targetSemantics: programEncoder.targetProgram.Semantics) + extends Solver with InvariantSolver: + + /* Internal imports, types, and aliases */ + + protected final val s: programEncoder.sourceProgram.trees.type = programEncoder.sourceProgram.trees + protected final val t: programEncoder.targetProgram.trees.type = programEncoder.targetProgram.trees + protected final val targetProgram: programEncoder.targetProgram.type = programEncoder.targetProgram + + type Source = s.Expr + type Encoded = t.Expr + + import targetProgram._ + import targetProgram.trees._ + import targetProgram.symbols.{given, _} + + /* Formula representation and transformation */ + + protected object HornClauses: + /** + * Representation of a Horn clause. + * + * @param head conclusion (possibly False, implying empty head) + * @param body possibly empty sequence of premises + */ + case class Clause(head: Encoded, body: Seq[Encoded]): + /** + * Collapse this clause to a single Inox implication. + */ + def collapse: Encoded = + if body.isEmpty then head + else if body.tail.isEmpty then Implies(body.head, head) + else Implies(And(body), head) + /** + * Fresh clause with an additional guard / premise. + */ + def withGuard(expr: Encoded): Clause = + Clause(head, body :+ expr) + /** + * Fresh clause with additional guards / premises. + */ + def withGuards(exprs: Iterable[Encoded]): Clause = + Clause(head, body ++ exprs) + + + extension (expr: Encoded) + /** + * Implied-by relation. Prolog syntax for Horn clauses. + * + * @example `head :- (body0, body1, body2)` + */ + infix def :- (body: Encoded*): Clause = Clause(expr, body) + /** + * Implied-by relation. Prolog syntax for Horn clauses. + */ + @annotation.targetName("impliedBySeq") // disambiguate from the Encoded* version + infix def :- (body: Iterable[Encoded]): Clause = Clause(expr, body.toSeq) + + end HornClauses + + import HornClauses.* + + /* Interface */ + + /** + * Checks whether the given set of assumptions is satisfiable together with + * the current state. Does not permanently add them to the constraint set, + * unlike [[assertCnstr]] + [[check]]. + * + * @param config expected response configuration + * @param assumptions set of assumptions to check + * @return a response containing satisfiability result, and a model if the + * config requires it and one is available + */ + override def checkAssumptions(config: Configuration)(assumptions: Set[Source]): config.Response[Model, Assumptions] = + checkAssumptions_(config)(assumptions) + + override def declare(vd: program.trees.ValDef): Unit = + val evd = encode(vd) + + context.timers.solvers.declare.sanity.run: + assert(evd.getType.isTyped) + + // Multiple calls to registerForInterrupts are (almost) idempotent and acceptable + context.interruptManager.registerForInterrupts(this) + + registerValDef(evd) + + override def interrupt(): Unit = + abort = true + underlyingHorn.interrupt() + + override def reset(): Unit = + abort = false + failures.clear() + constraints.reset() + + override def push(): Unit = + constraints.push() + + override def pop(): Unit = + constraints.pop() + + override def free(): Unit = + failures.clear() + constraints.clear() + context.interruptManager.unregisterForInterrupts(this) + + /** + * Checks the satisfiability of the currently asserted set of constraints. + * + * @param config expected response configuration + * @return a response containing satisfiability result, and a model if the + * config requires it and one is available + */ + override def check(config: CheckConfiguration): config.Response[Model, Assumptions] = + checkAssumptions(config)(Set.empty) + + /** + * Asserts a constraint to the solver. A call to [[check]] checks the + * satisfiability of the conjunction of all asserted constraints. + * + * @param expression constraint to assert + */ + override def assertCnstr(expression: Source): Unit = + constraints += expression + + override def name: String = + "inv-" + underlyingHorn.name + + /* Internal state */ + + /** + * Underlying Horn solver to use for invariant inference. + */ + protected val underlyingHorn: SMTLIBSolver & AbstractSolver { + val program: targetProgram.type + type Trees = Encoded + } + + /** + * Predicate variables introduced during this run. Should not be quantifed, + * should not be treated as HOF. + */ + protected val predicates: IncrementalSeq[Variable] = new IncrementalSeq() + + /** + * Top-level variables declared during this run. Should not be quantified. + */ + protected val freeVariables: IncrementalSeq[Variable] = new IncrementalSeq() + + /** + * Function replacements for function invocations. + */ + protected val funReplacements: Map[Identifier, Variable] = + program.symbols.functions.map: (id, fd) => + val tpe = FunctionType(fd.params.map(p => encode(p.tpe)) :+ encode(fd.returnType), BooleanType()) + val fresh = Variable.fresh(id.name, tpe) + registerPredicate(fresh) + id -> fresh + + // run state + + /** + * List of exceptions caught during this run. Thrown when the solver is asked + * to check constraints next. + */ + protected val failures: ListBuffer[Throwable] = new ListBuffer + + /** + * Whether the solver has been externally interrupted. See [[interrupt]]. + */ + private var abort: Boolean = false + + /** + * Stack of constraints to check. + */ + protected val constraints: IncrementalSeq[Source] = new IncrementalSeq() + + /* Translation of expressions */ + + // encoding and decoding trees + + protected final def encode(vd: s.ValDef): t.ValDef = programEncoder.encode(vd) + protected final def decode(vd: t.ValDef): s.ValDef = programEncoder.decode(vd) + + protected final def encode(v: s.Variable): t.Variable = programEncoder.encode(v) + protected final def decode(v: t.Variable): s.Variable = programEncoder.decode(v) + + protected final def encode(e: s.Expr): t.Expr = programEncoder.encode(e) + protected final def decode(e: t.Expr): s.Expr = programEncoder.decode(e) + + protected final def encode(tpe: s.Type): t.Type = programEncoder.encode(tpe) + protected final def decode(tpe: t.Type): s.Type = programEncoder.decode(tpe) + + // generating Horn clauses + + type Guards = ListBuffer[Expr] + type Clauses = ListBuffer[Clause] + + case class UnexpectedMatchError(expr: Expr) extends Exception(s"Unexpected match encountered on $expr.") + + private def canonicalType(tpe: Type): Type = + // partial map for canonicalization + def typeMap(tpe: Type): Option[Type] = + tpe match + // choose HOF representation + case FunctionType(args, ret) => + Some(args.foldRight(ret) {(next, acc) => t.MapType(next, acc)}) + case _ => None + + // replace, then traverse + typeOps.preMap(typeMap)(tpe) + + inline protected def registerPredicate(pred: Variable): Unit = + predicates += pred + + inline protected def registerValDef(vd: ValDef): Unit = + freeVariables += vd.toVariable + + /** + * is this a "pure" value expression + * + * i.e. no language constructs that require their own predicates or guards. + * so, no lets, no ifs, etc. + */ + private def isSimpleValue(expr: Expr): Boolean = + expr match + case Variable(_, tpe, _) => !tpe.isInstanceOf[FunctionType] + // anything that required the introduction of a predicate + case Assume(_, _) => false + case Let(_, _, _) => false + case Application(_, _) => false + case Lambda(_, _) => false + case Forall(_, _) => false + case Choose(_, _) => false + case IfExpr(_, _, _) => false + case FunctionInvocation(_, _, args) => false // functions need to be replaced! + // everything else, including theory operators && || ! ==> + - > < >= <= == + case Operator(args, _) => args.forall(isSimpleValue) + case _ => throw UnexpectedMatchError(expr) + + object Context: + import collection.mutable.{Map => MMap} + + // guards corresponding to variables + private val variableGuards: MMap[Variable, List[Expr]] = MMap.empty + private val functionReplacements: IncrementalBijection[TypedFunDef, Expr] = IncrementalBijection() + private val functionClauses: MMap[TypedFunDef, Set[Expr]] = MMap.empty + private val callReplacements: IncrementalBijection[Variable, FunctionInvocation] = IncrementalBijection() + + def functionOf(v: Variable): Option[FunctionInvocation] = + callReplacements.getB(v) + + def isFunctionalPredicate(v: Variable): Boolean = + functionReplacements.containsB(v) + + def predicateOf(tfd: TypedFunDef): Expr = + functionReplacements.cachedB(tfd) { + // this is an unseen typed def, generate a new predicate and clause set + // also populate the defining clauses + val funTpe = tfd.functionType + val pred = generateFreshPredicate(tfd.id.name, (funTpe.from :+ funTpe.to).map(canonicalType)) + pred + } + + def definingClauses(tfd: TypedFunDef): Set[Expr] = + if functionClauses.contains(tfd) then + functionClauses(tfd) + else + // unseen def. Unusual but ok. + generateClauses(tfd) // force generation + definingClauses(tfd) + + private def generateClauses(tfd: TypedFunDef): Unit = + val clauses = encodeFunction(tfd) + functionClauses(tfd) = clauses + + def addGuard(v: Variable, guard: Expr): Unit = + variableGuards(v) = guard :: variableGuards.getOrElse(v, Nil) + + def addCallGuard(v: Variable, call: FunctionInvocation): Unit = + callReplacements += v -> call + + def guardFor(v: Variable): List[Expr] = + val call = callReplacements.getB(v) + val callGuard = + if call.isDefined then + val tfd = targetProgram.symbols.getFunction(call.get.id).typed(call.get.tps) + val pred = predicateOf(tfd) + List(Application(pred, call.get.args :+ v)) + else + Nil + callGuard ++ variableGuards.getOrElse(v, Nil) + + def transitiveVarsFor(v: Variable): Set[Variable] = + def rec(vs: Set[Variable]) = + vs.flatMap(guardFor) + .flatMap(exprOps.variablesOf) + .filterNot(isFree) + utils.fixpoint(rec)(Set(v)) + + def transitiveGuardsFor(v: Variable): Set[Expr] = + def rec(guards: Set[Expr]) = + guards ++ + guards + .flatMap(exprOps.variablesOf) + .flatMap(guardFor) + + utils.fixpoint(rec)(guardFor(v).toSet) + + def toReplace(id: Identifier): Boolean = + targetProgram.symbols.functions.contains(id) + + def freshFunctionReplacement(fn: FunctionInvocation): Variable = + callReplacements.cachedA(fn){ + val tpe = fn.getType + val fresh = Variable.fresh(s"${fn.id}_call", tpe, true) + addCallGuard(fresh, fn) + fresh + } + + def insertGuards(clause: Clause): Clause = + val exprs = clause.body.filter(isSimpleValue) + val variables = exprs + .flatMap(exprOps.variablesOf) + .filterNot(isFree) + val newGuards = variables.flatMap(transitiveGuardsFor) + clause.withGuards(newGuards) + + /** + * "Purify" an expression by replacing function invocations with fresh + * variables and registering guards for the variables as predicate + * invocations. + * + * @param expr expression to purify + */ + private def purify(expr: Expr): Expr = + def _purify(expr: Expr): Option[Expr] = + expr match + case fn @ FunctionInvocation(id, _, _) if Context.toReplace(id) => + Some(Context.freshFunctionReplacement(fn)) + case _ => None + + exprOps.postMap(_purify)(expr) + + private case class ClauseResult(clauses: Clauses, guards: Guards, assertions: ListBuffer[List[Expr]], body: Expr => Expr) + + private inline def conjunction(exprs: Iterable[Expr]): Expr = + if exprs.isEmpty then BooleanLiteral(true) + else if exprs.tail.isEmpty then exprs.head + else And(exprs.toSeq) + + private def freeVariablesOf(expr: Expr): Seq[Variable] = + val baseFrees = exprOps.variablesOf(expr) + val guardFrees = baseFrees + .flatMap(Context.transitiveVarsFor) + .flatMap(exprOps.variablesOf) + .filterNot(isFree) + (baseFrees ++ guardFrees).toSeq + + /** + * Given an expression, construct a new predicate P and relevant clauses such that + * + * res == expr ==> P(res) (+ free variables) + * + * Clauses may need recursive exploration and construction. + * + * @return defining clauses, empty guards, and P(res) + */ + private def generatePredicate(name: String, expr: Expr, pathCondition: List[Expr]): ClauseResult = + val res = Variable.fresh("res", expr.getType) + val frees = freeVariablesOf(expr) + val args = frees :+ res + val inputType = args.map(_.tpe).map(canonicalType) + val pred = generateFreshPredicate(name, inputType) + val inner = encodeClauses(expr, pathCondition) + val aply = (x: Expr) => Application(pred, frees :+ x) + + // construct constraints + val predClause = aply(res) :- (inner.body(res) +: inner.guards) + + // guards do not escape predicate definitions + ClauseResult(inner.clauses :+ predClause, ListBuffer.empty, inner.assertions, aply) + + private def encodeClauses(expr: Expr, pathCondition: List[Expr]): ClauseResult = + val clauses = ListBuffer.empty[Clause] + val guards = ListBuffer.empty[Expr] + val assertions = ListBuffer.empty[List[Expr]] + val freeVars = freeVariablesOf(expr) + + inline def ret(body: Expr => Expr) = ClauseResult(clauses, guards, assertions, body) + inline def subsume(inner: ClauseResult): Unit = + clauses ++= inner.clauses + guards ++= inner.guards + assertions ++= inner.assertions + + if isSimpleValue(expr) then + // notably, only HOFs or simple expressions should introduce Equals + return ret(Equals(_, expr)) + + // otherwise, need to elaborate + expr match + case v @ Variable(id, FunctionType(_, _), _) => + // other variable cases are simple expression, and covered above + // return the HOF representation variable corresponding to this + ret(Equals(_, makeHOFVariable(v))) + + case Assume(cond, body) => + // expr == Assume(cond, body) + // iff expr = body, under cond as a guard + + // condition appears in the guards, so we have to decide whether to elaborate or leave it + val newCond = + if isSimpleValue(cond) then + // condition is a theory expression, leave as is + cond + else + // elaborate + val innerCond = generatePredicate("AssumeCond", cond, pathCondition) + subsume(innerCond) + innerCond.body(BooleanLiteral(true)) + + val invCond = + val ncond = Not(cond) + if isSimpleValue(ncond) then + // condition is a theory expression, leave as is + ncond + else + // elaborate + val innerCond = generatePredicate("NegatedAssumeCond", ncond, pathCondition) + subsume(innerCond) + innerCond.body(BooleanLiteral(true)) + + guards += newCond + assertions += invCond :: pathCondition + + val newBody = + // body is the target expression, it must always be elaborated + val bodyRes = encodeClauses(body, newCond :: pathCondition) + subsume(bodyRes) + bodyRes.body + + val res = Variable.fresh("AssumeExpr", expr.getType) + val pred = generateFreshPredicate("AssumeExpr", (freeVars.map(_.tpe) :+ res.getType).map(canonicalType)) + val topClause = Application(pred, freeVars :+ res) .:- (newCond, newBody(res)) + clauses += topClause + + ret(res => Application(pred, freeVars :+ res)) + + case Choose(res, pred) => + // unexpected, what do we do regarding chooses? TODO: simply evaluate + // under uninterpreted values or attempt to synthesize something in + // their place? + ??? + + case Let(vd, value, body) => + // expr == let (vd = value) in body + // iff expr = body under the guard that vd = value + // vd = value becomes conditionOf(value)(vd) + + val cond = + val inner = generatePredicate("LetInnerValue", value, pathCondition) + subsume(inner) + inner.body + + Context.addGuard(vd.toVariable, cond(vd.toVariable)) + + val bodyPred = + val inner = encodeClauses(body, pathCondition) + subsume(inner) + inner.body + + ret(bodyPred) + + case IfExpr(cond, thenn, elze) => + // expr == if cond then e1 else e2 + // iff expr == e1 under guard cond + // and expr == e2 under guard !cond + + val tpe = thenn.getType + + val condHolds = + // split on whether to elaborate cond + if isSimpleValue(cond) then + cond + else + val inner = generatePredicate("IfCond", cond, pathCondition) + subsume(inner) + inner.body(BooleanLiteral(true)) + + val condInv = Not(cond) + + // the encoding of the inverse of the condition can (should) involve a different predicate + val condInvHolds = + // split on whether to elaborate cond + if isSimpleValue(condInv) then + condInv + else + val inner = generatePredicate("NegatedIfCond", condInv, pathCondition) + subsume(inner) + inner.body(BooleanLiteral(true)) + + val condLabel = Variable.fresh("IfCondition", tpe) + + val thennResult = generatePredicate("ThenBranch", thenn, condHolds :: pathCondition) + val elzeResult = generatePredicate("ElseBranch", elze, condInvHolds :: pathCondition) + + subsume(thennResult) + subsume(elzeResult) + + // create a new predicate for this branch + val newPred = generateFreshPredicate("IfNode", (freeVars.map(_.tpe) :+ tpe).map(canonicalType)) + + val branch: Expr => Expr = (res: Expr) => Application(newPred, freeVars :+ res) + + val positiveClause = branch(condLabel) .:- (thennResult.body(condLabel), condHolds) + val negativeClause = branch(condLabel) .:- (elzeResult.body(condLabel), condInvHolds) + + clauses += positiveClause + clauses += negativeClause + + ret(branch) + + case Application(l @ Lambda(params, body), args) => + val fun = + val inner = encodeClauses(l, pathCondition) + subsume(inner) + inner.body + + val funVar = Variable.fresh("fun", canonicalType(l.getType)) + Context.addGuard(funVar, fun(funVar)) + + val newArgs = args.map: a => + if isSimpleValue(a) then + a + else + // register a new guard + val newArg = Variable.fresh("arg", a.getType) + val inner = generatePredicate("LambdaArg", a, pathCondition) + subsume(inner) + Context.addGuard(newArg, inner.body(newArg)) + newArg + + val applied = newArgs.foldLeft(funVar: Expr)((acc, next) => MapApply(acc, next)) + + ret(Equals(_, applied)) + + case l @ Lambda(args, body) => + // eliminate it into an integer, and generate an applicative + // predicate for it, if not already done + val (definingClauses, identifier) = makeLambda(l) + clauses ++= definingClauses + ret(Equals(_, l)) + + case Forall(args, body) => + // Quantifier? Unexpected + ??? + + case FunctionInvocation(id, tps, args) if funReplacements.contains(id) => + // this expression should have been purified + throw UnexpectedMatchError(expr) + + case Application(l : Variable, args) if l.tpe.isInstanceOf[FunctionType] => + // translate HOF applications to array selections + val repr = makeHOFVariable(l) + + val newArgs = args.map: a => + if isSimpleValue(a) then + a + else + // register a new guard + val newArg = Variable.fresh("arg", a.getType) + val inner = generatePredicate("HOFArg", a, pathCondition) + subsume(inner) + Context.addGuard(newArg, inner.body(newArg)) + newArg + + val applied = newArgs.foldLeft(repr: Expr)((acc, next) => MapApply(acc, next)) + + ret(Equals(_, applied)) + + // handle booleans separately + case Not(inner) => + val newRes = Variable.fresh("NotExpr", BooleanType()) + val innerResult = encodeClauses(inner, pathCondition) + subsume(innerResult) + + val newPred = generateFreshPredicate("NotExpr", (freeVars.map(_.tpe) :+ BooleanType()).map(canonicalType)) + clauses += Application(newPred, freeVars :+ newRes) :- (innerResult.body(Not(newRes))) + + ret(res => Application(newPred, freeVars :+ res)) + + case Or(inners) => + val newPred = generateFreshPredicate("OrExpr", (freeVars.map(_.tpe) :+ BooleanType()).map(canonicalType)) + inners + .map(encodeClauses(_, pathCondition)) + .map: inner => + subsume(inner) + clauses += Application(newPred, freeVars :+ BooleanLiteral(true)) :- (inner.body(BooleanLiteral(true)) +: inner.guards) + + ret(res => Application(newPred, freeVars :+ res)) + + case And(inners) => + val newPred = generateFreshPredicate("AndExpr", (freeVars.map(_.tpe) :+ BooleanType()).map(canonicalType)) + val lhs = inners + .map(encodeClauses(_, pathCondition)) + .map: inner => + subsume(inner) + inner.body(BooleanLiteral(true)) + + clauses += Application(newPred, freeVars :+ BooleanLiteral(true)) :- (lhs ++ guards) + + ret(res => Application(newPred, freeVars :+ res)) + + // other operator + // deconstruct arguments similar to a function call + case Operator(args, recons) => + val newArgs = args.map: a => + if isSimpleValue(a) then + a + else + // register a new guard + val newArg = Variable.fresh("arg", canonicalType(a.getType), true) + val inner = generatePredicate("OperatorArg", a, pathCondition) + subsume(inner) + Context.addGuard(newArg, inner.body(newArg)) + newArg + + val newExpr = recons(newArgs) + val body = Equals(_, newExpr) + + ret(body) + + case _ => throw UnexpectedMatchError(expr) + + /** + * Generate a fresh predicate variable of given name and input type, and + * register it as a predicate for the underlying solver. + * @return fresh predicate variable + **/ + private def generateFreshPredicate(name: String, inputType: Seq[Type]): Variable = + val tpe = FunctionType(inputType, BooleanType()) + val pred = Variable.fresh(name, tpe, true) + registerPredicate(pred) + pred + + // TODO: Move + + private val HOVarLookup: collection.mutable.Map[Variable, Expr] = collection.mutable.Map.empty + private val LambdaLookup: collection.mutable.Map[Lambda, (Clauses, Variable)] = collection.mutable.Map.empty + + /** + * Recall or register the representation for an HOF variable + * + * @param v variable to replace + * @return known repr. if available, or new registered repr. otherwise + */ + private def makeHOFVariable(v: Variable): Expr = + require(v.tpe.isInstanceOf[FunctionType]) + HOVarLookup.getOrElseUpdate(v, Variable.fresh("HOF", canonicalType(v.tpe))) + + /** + * Recall or register a lambda expression, clauses for its evaluation + * + * @param l + * @return + */ + private def makeLambda(l: Lambda): (Clauses, Variable) = + LambdaLookup.getOrElseUpdate(l, { + val Lambda(args, body) = l + val res = Variable.fresh("LambdaRes", canonicalType(l.getType)) + val applied = args.foldLeft(res: Expr)((acc, next) => MapApply(acc, next.toVariable)) + val inner = generatePredicate("LambdaBody", body, List.empty) // lambdas should not be conditionally valid? + val topClause = inner.body(res) :- Equals(res, applied) + + (inner.clauses :+ topClause, res) + }) + + private def extractModel(model: underlyingHorn.Model): Model = + ??? + + private def reportInvariants(model: underlyingHorn.Model): Unit = + // context.reporter.info( + // s"Discovered Invariant for: $model" + // ) + () + + protected def encodeFunction(tfd: TypedFunDef): Set[Expr] = + // collect data + val body = purify(tfd.fullBody) + val pred = Context.predicateOf(tfd) + val outType = tfd.returnType + val args = tfd.params.map(_.toVariable) + + val res = Variable.fresh("res", outType) + + // actually generate clauses + val ClauseResult(clauses, guards, assertions, inner) = encodeClauses(body, List.empty) + + val appliedFun: Encoded = Application(pred, args :+ res) + + val topClause = appliedFun :- inner(res) + + // add goal clauses from top-level assertions + assertions.foreach: as => + clauses += (BooleanLiteral(false) :- (appliedFun +: as)) + + (topClause +: clauses) + .to(Set) + .map(Context.insertGuards) + .map(_.collapse) + .map(quantify) + + protected def isFree(v: Variable): Boolean = + freeVariables.exists(_ == v) || predicates.exists(_ == v) + + protected def quantify(clause: Expr): Expr = + val frees = exprOps + .variablesOf(clause) + .filterNot(isFree) + .map(_.toVal) + .toSeq + + Forall(frees, clause) + + /* Communicate with solver */ + + protected lazy val emptyModel: underlyingHorn.Model = + inox.Model(targetProgram)(Map.empty, Map.empty) + // underlyingHorn.checkAssumptions(SolverResponses.Model)(Set.empty) match + // case SolverResponses.SatWithModel(model) => model + // case _ => throw new Exception("Could not construct empty model.") + + protected def emptyProgramModel: Model = + inox.Model(program)(Map.empty, Map.empty) + + private def encodeAssumptions(assumptions: Set[Source]): Set[Encoded] = + assumptions + .map(encode) + .map(purify) + .map(encodeClauses(_, List.empty)) + .flatMap { + case ClauseResult(clauses, guards, assertions, predicate) => + // false :- assumption /\ guards + val topClause = BooleanLiteral(false) :- (predicate(BooleanLiteral(true)) +: guards) + + clauses += topClause + + // do we ever expect these to be non-empty? Don't think so + // your assertions should not come with assume statements inside them + assert(assertions.isEmpty) + + // assertions.foreach: conds => // SHOULD be empty though? FIXME: ? + // clauses += (BooleanLiteral(false) :- conds) + + clauses + } + .map(Context.insertGuards) + .map(_.collapse) + .map(quantify) + + private def encodeFunctionsForAssumptions(assumptions: Set[Source]): Set[Encoded] = + // all functions that appear in the assumptions, transitively + val baseCalls = assumptions + .map(encode) + .flatMap(exprOps.functionCallsOf) + .filter(f => Context.toReplace(f.id)) + + def transitiveCalls(calls: Set[FunctionInvocation]): Set[FunctionInvocation] = + calls.flatMap: call => + val funDef = targetProgram.symbols.getFunction(call.id) + val typedDef = funDef.typed(call.tps) + val body = typedDef.fullBody + val newCalls = exprOps.functionCallsOf(body) + .filter(f => Context.toReplace(f.id)) + newCalls + call + + // termination is guaranteed by Stainless' type checking beforehand + val calls = utils.fixpoint(transitiveCalls)(baseCalls) + + calls + .flatMap: call => + val funDef = targetProgram.symbols.getFunction(call.id) + val typedDef = funDef.typed(call.tps) + Context.definingClauses(typedDef) + + /** + * Invariant-generating implementation of [[checkAssumptions]]. + */ + private def checkAssumptions_(config: Configuration)(assumptions: Set[Source]): config.Response[Model, Assumptions] = + + // send constraints to solver + val totalAssumptions = assumptions ++ constraints + + // Horn encode assumptions + val assumptionClauses = encodeAssumptions(totalAssumptions) + + // find and encode all function calls (recursively) + val definitionClauses = encodeFunctionsForAssumptions(totalAssumptions) + + // declare variables + predicates.foreach(underlyingHorn.registerPredicate) + + + (assumptionClauses ++ definitionClauses).foreach(underlyingHorn.assertCnstr) + + // check satisfiability + val underlyingResult = underlyingHorn.checkAssumptions(config)(Set.empty) + + // interpret result + val res = + underlyingResult match + case SolverResponses.SatWithModel(model) => + // report discovery of invariants + reportInvariants(model) + // discard underlying model. We cannot construct a program model (cex) from + // the Horn model + config.cast(SolverResponses.Unsat) + + case SolverResponses.Check(r) => + lazy val satRes = if config.withModel then SolverResponses.SatWithModel(emptyProgramModel) else SolverResponses.Sat + config.cast(if r then SolverResponses.Unsat else satRes) + + case _ => config.cast(SolverResponses.Unknown) // unknown or unreachable + + res + +end AbstractInvariantSolver diff --git a/src/main/scala/inox/solvers/smtlib/EldaricaInterpreter.scala b/src/main/scala/inox/solvers/smtlib/EldaricaInterpreter.scala new file mode 100644 index 000000000..2d26d0895 --- /dev/null +++ b/src/main/scala/inox/solvers/smtlib/EldaricaInterpreter.scala @@ -0,0 +1,177 @@ +package inox.solvers.smtlib + +import _root_.smtlib.trees.Terms.* +import _root_.smtlib.printer.Printer +import _root_.smtlib.parser.Parser +import _root_.smtlib.Interpreter +import _root_.smtlib.theories.* +import java.io.BufferedReader +import _root_.smtlib.trees.CommandsResponses.* +import _root_.smtlib.trees.Commands.* +import java.io.StringReader +import java.util.concurrent.Future + +/** + * + * + * @param printer + * @param parser + */ +class EldaricaInterpreter(val printer: Printer, val parserCtor: BufferedReader => Parser) extends Interpreter { + + import collection.mutable.{Stack => MStack, Seq => MSeq} + + private val commands = MStack(MSeq.empty[SExpr]) + + private var lastModelResponse: Option[GetModelResponse] = None + + val parser = parserCtor(new BufferedReader(new StringReader(""))) // dummy parser + + private class InterruptibleExecutor[T]: + private var task: Option[Future[T]] = None + + private val executor = scala.concurrent.ExecutionContext.fromExecutorService(null) + + private def asCallable[A](block: => A): java.util.concurrent.Callable[A] = + new java.util.concurrent.Callable[A] { def call(): A = block } + + def execute(block: => T): Option[T] = + this.synchronized: // run only one task at a time + task = Some(executor.submit(asCallable(block))) + + val res = + try + Some(task.get.get()) // block for result or interrupt + catch + case e: java.util.concurrent.CancellationException => None // externally interrupted + case e: Exception => throw e + + task = None + res + + def interrupt(): Unit = + task.foreach(_.cancel(true)) + + private val executor = new InterruptibleExecutor[SExpr] + + /** + * args to run Eldarica calls under + */ + private val eldArgs = Array( + "-in", // read input from (simulated) stdin + "-hsmt", // use SMT-LIB2 input format + "-disj" // use disjunctive interpolation + ) + + def eval(cmd: SExpr): SExpr = + cmd match + case CheckSat() => + checkSat + case CheckSatAssuming(assumptions) => + def toAssertion(lit: PropLiteral): SExpr = + val PropLiteral(sym, polarity) = lit + val id = QualifiedIdentifier(SimpleIdentifier(sym), Some(Core.BoolSort())) + val term = if polarity then id else Core.Not(id) + Assert(term) + + commands.push(assumptions.map(toAssertion).to(MSeq)) + val res = checkSat + commands.pop() + res + case Echo(value) => + EchoResponseSuccess(value.toString) + case Exit() => + // equivalent to reset + commands.clear() + trySuccess + case GetInfo(flag) => + flag match + case VersionInfoFlag() => + GetInfoResponseSuccess(VersionInfoResponse("0.1"), Seq.empty) + case _ => Unsupported + case GetModel() => + getModel + case Pop(n) => + (1 to n).foreach(_ => commands.pop()) + trySuccess + case Push(n) => + (1 to n).foreach(_ => commands.push(MSeq())) + trySuccess + case Reset() => + commands.clear() + trySuccess + case SetOption(option) => + // slightly haphazard + // but we always expect that PrintSuccess(true) has been passed as the first command + trySuccess + case _ => + commands.push(commands.pop() :+ cmd) + trySuccess + + //A free method is kind of justified by the need for the IO streams to be closed, and + //there seems to be a decent case in general to have such a method for things like solvers + //note that free can be used even if the solver is currently solving, and act as a sort of interrupt + def free(): Unit = + commands.clear() + + def interrupt(): Unit = + executor.interrupt() + + private def trySuccess: SExpr = Success + + private def collapsedCommands = commands.toSeq.flatten + + private def checkSat: SExpr = + this.synchronized { + // reset last model + setLastModelResponse(None) + executor + .execute(seqCheckSat) + .getOrElse(CheckSatStatus(UnknownStatus)) + } + + private def seqCheckSat: SExpr = + val commands = collapsedCommands :+ CheckSat() + val script = commands.map(printer.toString).mkString("\n") + + val inputStream = new java.io.StringReader(script) + + val buffer = new java.io.ByteArrayOutputStream + val printStream = new java.io.PrintStream(buffer) + + // actually check sat, requesting a model if possible + Console.withIn(inputStream): + Console.withOut(printStream): + lazabs.Main.doMain(eldArgs, false) + + val eldRes = new java.io.BufferedReader(new java.io.StringReader(buffer.toString)) + + val parser = parserCtor(eldRes) + + val result = parser.parseCheckSatResponse + + result match + case CheckSatStatus(SatStatus) => + // FIXME: @sg: disabled due to non-SMTLIB compliant model printing from eldarica + // there will be a parser exception if we attemp this + // // if Sat, parse and store model + // val model = parser.parseGetModelResponse + // // could be a model or an error, in either case, this is the response for (get-model) + // setLastModelResponse(Some(model)) + setLastModelResponse(Some(GetModelResponseSuccess(Nil))) // empty model + result + case _ => + // if unsat or unknown, reset the model + setLastModelResponse(None) + result + + private def setLastModelResponse(model: Option[GetModelResponse]): Unit = + lastModelResponse = model + + private def getModel: SExpr = + lastModelResponse match + case Some(modelResponse) => + modelResponse + case None => + Error("No model available") +} diff --git a/src/main/scala/inox/solvers/smtlib/EldaricaSolver.scala b/src/main/scala/inox/solvers/smtlib/EldaricaSolver.scala new file mode 100644 index 000000000..760e1b534 --- /dev/null +++ b/src/main/scala/inox/solvers/smtlib/EldaricaSolver.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2018 EPFL, Lausanne */ + +package inox +package solvers +package smtlib + +import _root_.{smtlib => sl} +import _root_.smtlib.trees.Terms.{Identifier => _, _} +import _root_.smtlib.trees.CommandsResponses._ + +trait EldaricaSolver extends SMTLIBSolver with EldaricaTarget { + + protected val interpreter: sl.Interpreter = + new EldaricaInterpreter(sl.printer.RecursivePrinter, out => sl.parser.Parser(sl.extensions.tip.Lexer(out))) + + def targetName = "eldarica" + + protected def interpreterOpts: Seq[String] = Seq.empty +} diff --git a/src/main/scala/inox/solvers/smtlib/EldaricaTarget.scala b/src/main/scala/inox/solvers/smtlib/EldaricaTarget.scala new file mode 100644 index 000000000..0da970233 --- /dev/null +++ b/src/main/scala/inox/solvers/smtlib/EldaricaTarget.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2018 EPFL, Lausanne */ + +package inox +package solvers +package smtlib + +import _root_.smtlib.trees.Terms.{Identifier => SMTIdentifier, _} +import _root_.smtlib.trees.Commands._ +import _root_.smtlib.theories._ +import _root_.smtlib.theories.cvc._ + +trait EldaricaTarget extends SMTLIBTarget with SMTLIBDebugger { + import context.{given, _} + import program._ + import program.trees._ + import program.symbols.{given, _} + + override protected def toSMT(e: Expr)(using bindings: Map[Identifier, Term]) = super.toSMT(e) +} diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala index 2641612db..b1f6cae2f 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala @@ -223,6 +223,10 @@ trait SMTLIBParser { val MapType(from, to) = fromSMT(sort): @unchecked FiniteMap(Seq.empty, d, from, to) + case AnnotatedTerm(term, attribute, attributes) => + // discard general annotations + fromSMT(term, otpe) + case _ => throw new MissformedSMTException(term, s"Unknown SMT term of class: ${term.getClass}:\n$term" ) diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala index 77b1a1e8e..f3f0887f1 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala @@ -63,10 +63,11 @@ abstract class SMTLIBSolver private(override val program: Program, config: Configuration, res: SExpr, assumptions: Set[Expr] - ): config.Response[Model, Assumptions] = config.cast(res match { + ): config.Response[Model, Assumptions] = + config.cast(res match { case CheckSatStatus(SatStatus) => if (config.withModel) { - val syms = variables.bSet + val syms = variables.bSet ++ predicates.bSet emit(GetModel()) match { case GetModelResponseSuccess(smodel) => // first-pass to gather functions @@ -78,11 +79,12 @@ abstract class SMTLIBSolver private(override val program: Program, val ctx = new Context(variables.bToA, Map(), modelFunDefs) val vars = smodel.flatMap { - case DefineFun(SMTFunDef(s, _, _, e)) if syms(s) => + case DefineFun(SMTFunDef(s, args, _, e)) if syms(s) => try { - val v = variables.toA(s) - val value = fromSMT(e, v.getType)(using ctx) - Some(v.toVal -> value) + val v = variables.getA(s).getOrElse(predicates.toA(s)) + val vargs = args.map(fromSMT(_)(using ctx).toVariable) + val value = fromSMT(e, v.getType)(using ctx.withVariables(args.map(_.name) zip vargs)) + Some(v.toVal -> value) } catch { case _: Unsupported => None case _: java.lang.StackOverflowError => None diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index 70c3076ac..e0e09f4c4 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -72,13 +72,6 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers { } } - def parseSuccess() = { - val res = interpreter.parser.parseGenResponse - if (res != Success) { - reporter.warning("Unnexpected result from " + targetName + ": " + res + " expected success") - } - } - /* * Translation from Inox Expressions to SMTLIB terms and reverse */ @@ -115,6 +108,13 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers { protected val sorts = new IncrementalBijection[Type, Sort] protected val functions = new IncrementalBijection[TypedFunDef, SSymbol] protected val lambdas = new IncrementalBijection[FunctionType, SSymbol] + protected val predicates = new IncrementalBijection[Variable, SSymbol] + protected val predicateCalls= new IncrementalBijection[Expr, SSymbol] + + def registerPredicate(v: Variable): Unit = { + val s = id2sym(v.id) + predicates += v -> s + } /* Helper functions */ @@ -140,6 +140,10 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers { quantifiedTerm(quantifier, exprOps.variablesOf(body).toSeq.map(_.toVal), body) protected final def declareSort(tpe: Type): Sort = + // println(s"@@@@ CALLED WITH $tpe with ${tpe match + // case ADTType(id, _) => id.uniqueName + // case _ => "None" + // } and found? ${if sorts.containsA(tpe) then "cached" else "NONONO"}") sorts.cachedB(tpe)(computeSort(tpe)) protected def computeSort(t: Type): Sort = t match { @@ -200,8 +204,12 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers { protected def declareVariable(v: Variable): SSymbol = { variables.cachedB(v) { val s = id2sym(v.id) - val cmd = DeclareFun(s, List(), declareSort(v.getType)) - emit(cmd) + if predicates.containsA(v) then + () // handled by [[declarePredicate]] separately already + else + val cmd = DeclareFun(s, List(), declareSort(v.getType)) + emit(cmd) + s } } @@ -342,8 +350,23 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers { * ===== Everything else ===== */ case ap @ Application(caller, args) => - val dyn = declareLambda(caller.getType.asInstanceOf[FunctionType]) - FunctionApplication(dyn, (caller +: args).map(toSMT)) + caller match + case v: Variable if predicates.containsA(v) => + val pred = predicateCalls.cachedB(caller) { + caller match + case pred @ Variable(id, FunctionType(from, to), flags) => + val s = id2sym(id) + emit(DeclareFun( + s, + from.map(declareSort), + declareSort(to) + )) + s + } + FunctionApplication(pred, args.map(toSMT)) + case _ => // normal lambda + val dyn = declareLambda(caller.getType.asInstanceOf[FunctionType]) + FunctionApplication(dyn, (caller +: args).map(toSMT)) case Not(u) => Core.Not(toSMT(u)) diff --git a/src/test/scala/inox/solvers/InvariantSolverSuite.scala b/src/test/scala/inox/solvers/InvariantSolverSuite.scala new file mode 100644 index 000000000..1b70efcb1 --- /dev/null +++ b/src/test/scala/inox/solvers/InvariantSolverSuite.scala @@ -0,0 +1,72 @@ +/* Copyright 2009-2024 EPFL, Lausanne */ + +package inox +package solvers + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.Tag + +class InvariantSolverSuite extends AnyFunSuite { + import inox.trees._ + import dsl._ + + val ctx = TestContext.empty + val p = InoxProgram(NoSymbols) + + val solverOpts = List( + (() => SolverFactory.hasZ3, "inv-z3"), + (() => true, "inv-eld"), + ) + + val solverNames = solverOpts.flatMap {case (cond, name) => if cond() then List(name) else Nil} + + def getSolver(ctx: Context) = + SimpleSolverAPI(p.getSolver(ctx)) + + protected def test(name: String, tags: Tag*)(body: Context => Unit): Unit = { + for + sname <- solverNames + ctx = TestContext(Options(Seq(optSelectedSolvers(Set(sname))))) + do + super.test(s"$name ($sname)", tags*) { + body(ctx) + } + } + + test("Validity of true") { ctx => + val solver = getSolver(ctx) + solver.solveVALID(BooleanLiteral(true)) match { + case Some(true) => () + case Some(false) => fail("True is not valid") + case None => fail("Solver returned unknown") + } + } + + test("Unsatisfiability of false") { ctx => + val solver = getSolver(ctx) + solver.solveSAT(BooleanLiteral(false)) match { + case SolverResponses.Unsat => () + case SolverResponses.SatWithModel(_) => fail("False is not invalid") + case _ => fail("Solver returned unknown response") + } + } + + test("Integer arithmetic") { ctx => + val solver = getSolver(ctx) + val x = Variable.fresh("x", IntegerType()) + val y = Variable.fresh("y", IntegerType()) + val z = Variable.fresh("z", IntegerType()) + + val eqs = List( + x === y + z, + y === IntegerLiteral(1), + z === IntegerLiteral(2) + ) + + solver.solveSAT(andJoin(eqs)) match { + case SolverResponses.SatWithModel(model) => () + case SolverResponses.Unsat => fail("Trivial integer arithmetic is unsat") + case _ => fail("Solver returned unknown response") + } + } +}