diff --git a/include/klee/ADT/DisjointSetUnion.h b/include/klee/ADT/DisjointSetUnion.h index b479ca0765d..746fdc38b16 100644 --- a/include/klee/ADT/DisjointSetUnion.h +++ b/include/klee/ADT/DisjointSetUnion.h @@ -150,6 +150,27 @@ class DisjointSetUnion { } } + void getAllDependentSets(ValueType value, + std::vector> &result) const { + ref compare = new SetType(value); + for (auto &r : roots) { + ref ics = disjointSets.at(r); + if (SetType::intersects(ics, compare)) { + result.push_back(ics); + } + } + } + void getAllIndependentSets(ValueType value, + std::vector> &result) const { + ref compare = new SetType(value); + for (auto &r : roots) { + ref ics = disjointSets.at(r); + if (!SetType::intersects(ics, compare)) { + result.push_back(ics); + } + } + } + DisjointSetUnion() {} DisjointSetUnion(const internalStorage_ty &is) { diff --git a/include/klee/Solver/IncompleteSolver.h b/include/klee/Solver/IncompleteSolver.h index 65dac30c4be..777e6722c43 100644 --- a/include/klee/Solver/IncompleteSolver.h +++ b/include/klee/Solver/IncompleteSolver.h @@ -58,14 +58,18 @@ class IncompleteSolver { /// StagedSolver - Adapter class for staging an incomplete solver with /// a complete secondary solver, to form an (optimized) complete /// solver. + +typedef std::function QueryPredicate; + class StagedSolverImpl : public SolverImpl { private: std::unique_ptr primary; std::unique_ptr secondary; + QueryPredicate predicate; public: StagedSolverImpl(std::unique_ptr primary, - std::unique_ptr secondary); + std::unique_ptr secondary, QueryPredicate predicate); bool computeTruth(const Query &, bool &isValid); bool computeValidity(const Query &, PartialValidity &result); diff --git a/lib/ADT/SparseStorage.cpp b/lib/ADT/SparseStorage.cpp index dc24e7d9937..935cb6922ab 100644 --- a/lib/ADT/SparseStorage.cpp +++ b/lib/ADT/SparseStorage.cpp @@ -35,7 +35,7 @@ void SparseStorage::print(llvm::raw_ostream &os, } os << "] default: "; } - os << defaultValue; + os << ((unsigned)defaultValue); } template <> diff --git a/lib/Expr/IndependentConstraintSetUnion.cpp b/lib/Expr/IndependentConstraintSetUnion.cpp index 01358e26234..f4ecf876de7 100644 --- a/lib/Expr/IndependentConstraintSetUnion.cpp +++ b/lib/Expr/IndependentConstraintSetUnion.cpp @@ -95,27 +95,13 @@ void IndependentConstraintSetUnion::reEvaluateConcretization( void IndependentConstraintSetUnion::getAllIndependentConstraintSets( ref e, std::vector> &result) const { - ref compare = - new IndependentConstraintSet(new ExprEitherSymcrete::left(e)); - for (auto &r : roots) { - ref ics = disjointSets.at(r); - if (!IndependentConstraintSet::intersects(ics, compare)) { - result.push_back(ics); - } - } + getAllIndependentSets(new ExprEitherSymcrete::left(e), result); } void IndependentConstraintSetUnion::getAllDependentConstraintSets( ref e, std::vector> &result) const { - ref compare = - new IndependentConstraintSet(new ExprEitherSymcrete::left(e)); - for (auto &r : roots) { - ref ics = disjointSets.at(r); - if (IndependentConstraintSet::intersects(ics, compare)) { - result.push_back(ics); - } - } + getAllDependentSets(new ExprEitherSymcrete::left(e), result); } void IndependentConstraintSetUnion::addExpr(ref e) { diff --git a/lib/Solver/FastCexSolver.cpp b/lib/Solver/FastCexSolver.cpp index f4c1b377aad..6d640eff863 100644 --- a/lib/Solver/FastCexSolver.cpp +++ b/lib/Solver/FastCexSolver.cpp @@ -18,6 +18,7 @@ #include "klee/Solver/IncompleteSolver.h" #include "klee/Support/Debug.h" #include "klee/Support/ErrorHandling.h" +#include "klee/Support/OptionCategories.h" #include "klee/Support/CompilerWarning.h" DISABLE_WARNING_PUSH @@ -33,6 +34,20 @@ DISABLE_WARNING_POP #include using namespace klee; +using namespace llvm; + +namespace { +enum class FastCexSolverType { EQUALITY, ALL }; + +cl::opt FastCexFor( + "fast-cex-for", + cl::desc( + "Specifiy a query predicate to filter queries for FastCexSolver using"), + cl::values(clEnumValN(FastCexSolverType::EQUALITY, "equality", + "Query with only equality expressions"), + clEnumValN(FastCexSolverType::ALL, "all", "All queries")), + cl::init(FastCexSolverType::EQUALITY), cl::cat(SolvingCat)); +} // namespace // Hacker's Delight, pgs 58-63 static uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) { @@ -403,10 +418,12 @@ class CexPossibleEvaluator : public ExprEvaluator { ref getInitialValue(const Array &array, unsigned index) { // If the index is out of range, we cannot assign it a value, since that // value cannot be part of the assignment. - ref constantArraySize = dyn_cast(array.size); + ref constantArraySize = + dyn_cast(visit(array.size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); + visit(array.size)->dump(); + klee_error("FIXME: CexPossibleEvaluator: Arrays of symbolic sizes are " + "unsupported in FastCex\n"); std::abort(); } @@ -433,11 +450,11 @@ class CexExactEvaluator : public ExprEvaluator { ref getInitialValue(const Array &array, unsigned index) { // If the index is out of range, we cannot assign it a value, since that // value cannot be part of the assignment. - ref constantArraySize = dyn_cast(array.size); + ref constantArraySize = + dyn_cast(visit(array.size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); - std::abort(); + return ReadExpr::create(UpdateList(&array, 0), + ConstantExpr::alloc(index, array.getDomain())); } if (index >= constantArraySize->getZExtValue()) { @@ -485,10 +502,11 @@ class CexData { CexObjectData &getObjectData(const Array *A) { CexObjectData *&Entry = objects[A]; - ref constantArraySize = dyn_cast(A->size); + ref constantArraySize = + dyn_cast(evaluatePossible(A->size)); if (!constantArraySize) { - klee_error( - "FIXME: Arrays of symbolic sizes are unsupported in FastCex\n"); + klee_error("FIXME: CexData: Arrays of symbolic sizes are unsupported in " + "FastCex\n"); std::abort(); } @@ -529,7 +547,7 @@ class CexData { // to see if this is an initial read or not. if (ConstantExpr *CE = dyn_cast(re->index)) { if (ref constantArraySize = - dyn_cast(array->size)) { + dyn_cast(evaluatePossible(array->size))) { uint64_t index = CE->getZExtValue(); if (index < constantArraySize->getZExtValue()) { @@ -1171,6 +1189,7 @@ bool FastCexSolver::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { CexData cd; + query.dump(); bool isValid; bool success = propagateValues(query, cd, true, isValid); @@ -1187,7 +1206,7 @@ bool FastCexSolver::computeInitialValues( for (unsigned i = 0; i != objects.size(); ++i) { const Array *array = objects[i]; assert(array); - SparseStorage data; + SparseStorage data(0); ref arrayConstantSize = dyn_cast(cd.evaluatePossible(array->size)); assert(arrayConstantSize && @@ -1212,7 +1231,45 @@ bool FastCexSolver::computeInitialValues( return true; } +class OnlyEqualityWithConstantQueryPredicate { +public: + explicit OnlyEqualityWithConstantQueryPredicate() {} + + bool operator()(const Query &query) const { + for (auto constraint : query.constraints.cs()) { + if (const EqExpr *ee = dyn_cast(constraint)) { + if (!isa(ee->left)) { + return false; + } + } else { + return false; + } + } + if (ref ee = dyn_cast(query.negateExpr().expr)) { + if (!isa(ee->left)) { + return false; + } + } else { + return false; + } + return true; + } +}; + +class TrueQueryPredicate { +public: + explicit TrueQueryPredicate() {} + + bool operator()(const Query &query) const { return true; } +}; + std::unique_ptr klee::createFastCexSolver(std::unique_ptr s) { - return std::make_unique(std::make_unique( - std::make_unique(), std::move(s))); + if (FastCexFor == FastCexSolverType::EQUALITY) { + return std::make_unique(std::make_unique( + std::make_unique(), std::move(s), + OnlyEqualityWithConstantQueryPredicate())); + } else { + return std::make_unique(std::make_unique( + std::make_unique(), std::move(s), TrueQueryPredicate())); + } } diff --git a/lib/Solver/IncompleteSolver.cpp b/lib/Solver/IncompleteSolver.cpp index 85ad5a8d6d2..10436c4f634 100644 --- a/lib/Solver/IncompleteSolver.cpp +++ b/lib/Solver/IncompleteSolver.cpp @@ -49,15 +49,19 @@ PartialValidity IncompleteSolver::computeValidity(const Query &query) { /***/ StagedSolverImpl::StagedSolverImpl(std::unique_ptr primary, - std::unique_ptr secondary) - : primary(std::move(primary)), secondary(std::move(secondary)) {} + std::unique_ptr secondary, + QueryPredicate predicate) + : primary(std::move(primary)), secondary(std::move(secondary)), + predicate(predicate) {} bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) { - PartialValidity trueResult = primary->computeTruth(query); + if (predicate(query)) { + PartialValidity trueResult = primary->computeTruth(query); - if (trueResult != PValidity::None) { - isValid = (trueResult == PValidity::MustBeTrue); - return true; + if (trueResult != PValidity::None) { + isValid = (trueResult == PValidity::MustBeTrue); + return true; + } } return secondary->impl->computeTruth(query, isValid); @@ -65,44 +69,48 @@ bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) { bool StagedSolverImpl::computeValidity(const Query &query, PartialValidity &result) { - bool tmp; - - switch (primary->computeValidity(query)) { - case PValidity::MustBeTrue: - result = PValidity::MustBeTrue; - break; - case PValidity::MustBeFalse: - result = PValidity::MustBeFalse; - break; - case PValidity::TrueOrFalse: - result = PValidity::TrueOrFalse; - break; - case PValidity::MayBeTrue: - if (secondary->impl->computeTruth(query, tmp)) { - - result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse; - } else { - result = PValidity::MayBeTrue; - } - break; - case PValidity::MayBeFalse: - if (secondary->impl->computeTruth(query.negateExpr(), tmp)) { - result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse; - } else { - result = PValidity::MayBeFalse; + if (predicate(query)) { + bool tmp; + + switch (primary->computeValidity(query)) { + case PValidity::MustBeTrue: + result = PValidity::MustBeTrue; + break; + case PValidity::MustBeFalse: + result = PValidity::MustBeFalse; + break; + case PValidity::TrueOrFalse: + result = PValidity::TrueOrFalse; + break; + case PValidity::MayBeTrue: + if (secondary->impl->computeTruth(query, tmp)) { + + result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse; + } else { + result = PValidity::MayBeTrue; + } + break; + case PValidity::MayBeFalse: + if (secondary->impl->computeTruth(query.negateExpr(), tmp)) { + result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse; + } else { + result = PValidity::MayBeFalse; + } + break; + default: + if (!secondary->impl->computeValidity(query, result)) + return false; + break; } - break; - default: - if (!secondary->impl->computeValidity(query, result)) - return false; - break; + } else { + return secondary->impl->computeValidity(query, result); } return true; } bool StagedSolverImpl::computeValue(const Query &query, ref &result) { - if (primary->computeValue(query, result)) + if (predicate(query) && primary->computeValue(query, result)) return true; return secondary->impl->computeValue(query, result); @@ -111,7 +119,8 @@ bool StagedSolverImpl::computeValue(const Query &query, ref &result) { bool StagedSolverImpl::computeInitialValues( const Query &query, const std::vector &objects, std::vector> &values, bool &hasSolution) { - if (primary->computeInitialValues(query, objects, values, hasSolution)) + if (predicate(query) && + primary->computeInitialValues(query, objects, values, hasSolution)) return true; return secondary->impl->computeInitialValues(query, objects, values, @@ -119,17 +128,19 @@ bool StagedSolverImpl::computeInitialValues( } bool StagedSolverImpl::check(const Query &query, ref &result) { - std::vector objects; - findSymbolicObjects(query, objects); - std::vector> values; - - bool hasSolution; - - bool primaryResult = - primary->computeInitialValues(query, objects, values, hasSolution); - if (primaryResult && hasSolution) { - result = new InvalidResponse(objects, values); - return true; + if (predicate(query)) { + std::vector objects; + findSymbolicObjects(query, objects); + std::vector> values; + + bool hasSolution; + + bool primaryResult = + primary->computeInitialValues(query, objects, values, hasSolution); + if (primaryResult && hasSolution) { + result = new InvalidResponse(objects, values); + return true; + } } return secondary->impl->check(query, result); @@ -138,6 +149,14 @@ bool StagedSolverImpl::check(const Query &query, ref &result) { bool StagedSolverImpl::computeValidityCore(const Query &query, ValidityCore &validityCore, bool &isValid) { + if (predicate(query)) { + PartialValidity trueResult = primary->computeTruth(query); + + if (trueResult == PValidity::MayBeFalse) { + isValid = false; + return true; + } + } return secondary->impl->computeValidityCore(query, validityCore, isValid); } diff --git a/lib/Solver/SolverCmdLine.cpp b/lib/Solver/SolverCmdLine.cpp index 0f51525d535..a71fc0100b3 100644 --- a/lib/Solver/SolverCmdLine.cpp +++ b/lib/Solver/SolverCmdLine.cpp @@ -43,7 +43,7 @@ cl::OptionCategory SolvingCat("Constraint solving options", "These options impact constraint solving."); cl::opt UseFastCexSolver( - "use-fast-cex-solver", cl::init(false), + "use-fast-cex-solver", cl::init(true), cl::desc("Enable an experimental range-based solver (default=false)"), cl::cat(SolvingCat)); diff --git a/test/Feature/DanglingConcreteReadExpr.c b/test/Feature/DanglingConcreteReadExpr.c index 1f8a5a347aa..588072125cc 100644 --- a/test/Feature/DanglingConcreteReadExpr.c +++ b/test/Feature/DanglingConcreteReadExpr.c @@ -1,7 +1,7 @@ // RUN: %clang %s -emit-llvm %O0opt -c -o %t1.bc // RUN: rm -rf %t.klee-out // RUN: %klee --optimize=false --output-dir=%t.klee-out %t1.bc -// RUN: grep "total queries = 1" %t.klee-out/info +// RUN: grep "total queries = 0" %t.klee-out/info #include @@ -12,8 +12,7 @@ int main() { y = x; - // should be exactly one query (prove x is 10) - // eventually should be 0 when we have fast solver + // should be exactly 0 query, finally we have enough optimizations if (x == 10) { assert(y == 10); } diff --git a/test/Solver/FastCexSolver.kquery b/test/Solver/FastCexSolver.kquery index 271609859bb..8e82c844683 100644 --- a/test/Solver/FastCexSolver.kquery +++ b/test/Solver/FastCexSolver.kquery @@ -1,4 +1,4 @@ -# RUN: %kleaver --use-fast-cex-solver --solver-backend=dummy %s > %t +# RUN: %kleaver --use-fast-cex-solver --fast-cex-for=all --solver-backend=dummy %s > %t # RUN: not grep FAIL %t makeSymbolic0 : (array (w64 4) (makeSymbolic arr1 0))