Skip to content

Commit

Permalink
[feat] Use FastCexSolver in restricted cases that are are fairly ea…
Browse files Browse the repository at this point in the history
…sy to solve
  • Loading branch information
misonijnik committed Oct 24, 2023
1 parent c70df43 commit 3ac7f67
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 87 deletions.
21 changes: 21 additions & 0 deletions include/klee/ADT/DisjointSetUnion.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ class DisjointSetUnion {
}
}

void getAllDependentSets(ValueType value,
std::vector<ref<const SetType>> &result) const {
ref<const SetType> compare = new SetType(value);
for (auto &r : roots) {
ref<const SetType> ics = disjointSets.at(r);
if (SetType::intersects(ics, compare)) {
result.push_back(ics);
}
}
}
void getAllIndependentSets(ValueType value,
std::vector<ref<const SetType>> &result) const {
ref<const SetType> compare = new SetType(value);
for (auto &r : roots) {
ref<const SetType> ics = disjointSets.at(r);
if (!SetType::intersects(ics, compare)) {
result.push_back(ics);
}
}
}

DisjointSetUnion() {}

DisjointSetUnion(const internalStorage_ty &is) {
Expand Down
6 changes: 5 additions & 1 deletion include/klee/Solver/IncompleteSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,18 @@ class IncompleteSolver {
/// StagedSolver - Adapter class for staging an incomplete solver with
/// a complete secondary solver, to form an (optimized) complete
/// solver.

typedef std::function<bool(const Query &)> QueryPredicate;

class StagedSolverImpl : public SolverImpl {
private:
std::unique_ptr<IncompleteSolver> primary;
std::unique_ptr<Solver> secondary;
QueryPredicate predicate;

public:
StagedSolverImpl(std::unique_ptr<IncompleteSolver> primary,
std::unique_ptr<Solver> secondary);
std::unique_ptr<Solver> secondary, QueryPredicate predicate);

bool computeTruth(const Query &, bool &isValid);
bool computeValidity(const Query &, PartialValidity &result);
Expand Down
2 changes: 1 addition & 1 deletion lib/ADT/SparseStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void SparseStorage<unsigned char>::print(llvm::raw_ostream &os,
}
os << "] default: ";
}
os << defaultValue;
os << ((unsigned)defaultValue);
}

template <>
Expand Down
18 changes: 2 additions & 16 deletions lib/Expr/IndependentConstraintSetUnion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,13 @@ void IndependentConstraintSetUnion::reEvaluateConcretization(
void IndependentConstraintSetUnion::getAllIndependentConstraintSets(
ref<Expr> e,
std::vector<ref<const IndependentConstraintSet>> &result) const {
ref<const IndependentConstraintSet> compare =
new IndependentConstraintSet(new ExprEitherSymcrete::left(e));
for (auto &r : roots) {
ref<const IndependentConstraintSet> ics = disjointSets.at(r);
if (!IndependentConstraintSet::intersects(ics, compare)) {
result.push_back(ics);
}
}
getAllIndependentSets(new ExprEitherSymcrete::left(e), result);
}

void IndependentConstraintSetUnion::getAllDependentConstraintSets(
ref<Expr> e,
std::vector<ref<const IndependentConstraintSet>> &result) const {
ref<const IndependentConstraintSet> compare =
new IndependentConstraintSet(new ExprEitherSymcrete::left(e));
for (auto &r : roots) {
ref<const IndependentConstraintSet> ics = disjointSets.at(r);
if (IndependentConstraintSet::intersects(ics, compare)) {
result.push_back(ics);
}
}
getAllDependentSets(new ExprEitherSymcrete::left(e), result);
}

void IndependentConstraintSetUnion::addExpr(ref<Expr> e) {
Expand Down
84 changes: 70 additions & 14 deletions lib/Solver/FastCexSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "klee/Solver/IncompleteSolver.h"
#include "klee/Support/Debug.h"
#include "klee/Support/ErrorHandling.h"
#include "klee/Support/OptionCategories.h"

#include "klee/Support/CompilerWarning.h"
DISABLE_WARNING_PUSH
Expand All @@ -33,6 +34,20 @@ DISABLE_WARNING_POP
#include <vector>

using namespace klee;
using namespace llvm;

namespace {
enum class FastCexSolverType { EQUALITY, ALL };

cl::opt<FastCexSolverType> FastCexFor(
"fast-cex-for",
cl::desc(
"Specifiy a query predicate to filter queries for FastCexSolver using"),
cl::values(clEnumValN(FastCexSolverType::EQUALITY, "equality",
"Query with only equality expressions"),
clEnumValN(FastCexSolverType::ALL, "all", "All queries")),
cl::init(FastCexSolverType::EQUALITY), cl::cat(SolvingCat));
} // namespace

// Hacker's Delight, pgs 58-63
static uint64_t minOR(uint64_t a, uint64_t b, uint64_t c, uint64_t d) {
Expand Down Expand Up @@ -403,10 +418,11 @@ class CexPossibleEvaluator : public ExprEvaluator {
ref<Expr> getInitialValue(const Array &array, unsigned index) {
// If the index is out of range, we cannot assign it a value, since that
// value cannot be part of the assignment.
ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(array.size);
ref<ConstantExpr> constantArraySize =
dyn_cast<ConstantExpr>(visit(array.size));
if (!constantArraySize) {
klee_error(
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
klee_error("FIXME: CexPossibleEvaluator: Arrays of symbolic sizes are "
"unsupported in FastCex\n");
std::abort();
}

Expand All @@ -433,11 +449,11 @@ class CexExactEvaluator : public ExprEvaluator {
ref<Expr> getInitialValue(const Array &array, unsigned index) {
// If the index is out of range, we cannot assign it a value, since that
// value cannot be part of the assignment.
ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(array.size);
ref<ConstantExpr> constantArraySize =
dyn_cast<ConstantExpr>(visit(array.size));
if (!constantArraySize) {
klee_error(
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
std::abort();
return ReadExpr::create(UpdateList(&array, 0),
ConstantExpr::alloc(index, array.getDomain()));
}

if (index >= constantArraySize->getZExtValue()) {
Expand Down Expand Up @@ -485,10 +501,11 @@ class CexData {
CexObjectData &getObjectData(const Array *A) {
CexObjectData *&Entry = objects[A];

ref<ConstantExpr> constantArraySize = dyn_cast<ConstantExpr>(A->size);
ref<ConstantExpr> constantArraySize =
dyn_cast<ConstantExpr>(evaluatePossible(A->size));
if (!constantArraySize) {
klee_error(
"FIXME: Arrays of symbolic sizes are unsupported in FastCex\n");
klee_error("FIXME: CexData: Arrays of symbolic sizes are unsupported in "
"FastCex\n");
std::abort();
}

Expand Down Expand Up @@ -529,7 +546,7 @@ class CexData {
// to see if this is an initial read or not.
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(re->index)) {
if (ref<ConstantExpr> constantArraySize =
dyn_cast<ConstantExpr>(array->size)) {
dyn_cast<ConstantExpr>(evaluatePossible(array->size))) {
uint64_t index = CE->getZExtValue();

if (index < constantArraySize->getZExtValue()) {
Expand Down Expand Up @@ -1171,6 +1188,7 @@ bool FastCexSolver::computeInitialValues(
const Query &query, const std::vector<const Array *> &objects,
std::vector<SparseStorage<unsigned char>> &values, bool &hasSolution) {
CexData cd;
query.dump();

bool isValid;
bool success = propagateValues(query, cd, true, isValid);
Expand All @@ -1187,7 +1205,7 @@ bool FastCexSolver::computeInitialValues(
for (unsigned i = 0; i != objects.size(); ++i) {
const Array *array = objects[i];
assert(array);
SparseStorage<unsigned char> data;
SparseStorage<unsigned char> data(0);
ref<ConstantExpr> arrayConstantSize =
dyn_cast<ConstantExpr>(cd.evaluatePossible(array->size));
assert(arrayConstantSize &&
Expand All @@ -1212,7 +1230,45 @@ bool FastCexSolver::computeInitialValues(
return true;
}

class OnlyEqualityWithConstantQueryPredicate {
public:
explicit OnlyEqualityWithConstantQueryPredicate() {}

bool operator()(const Query &query) const {
for (auto constraint : query.constraints.cs()) {
if (const EqExpr *ee = dyn_cast<EqExpr>(constraint)) {
if (!isa<ConstantExpr>(ee->left)) {
return false;
}
} else {
return false;
}
}
if (ref<EqExpr> ee = dyn_cast<EqExpr>(query.negateExpr().expr)) {
if (!isa<ConstantExpr>(ee->left)) {
return false;
}
} else {
return false;
}
return true;
}
};

class TrueQueryPredicate {
public:
explicit TrueQueryPredicate() {}

bool operator()(const Query &query) const { return true; }
};

std::unique_ptr<Solver> klee::createFastCexSolver(std::unique_ptr<Solver> s) {
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
std::make_unique<FastCexSolver>(), std::move(s)));
if (FastCexFor == FastCexSolverType::EQUALITY) {
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
std::make_unique<FastCexSolver>(), std::move(s),
OnlyEqualityWithConstantQueryPredicate()));
} else {
return std::make_unique<Solver>(std::make_unique<StagedSolverImpl>(
std::make_unique<FastCexSolver>(), std::move(s), TrueQueryPredicate()));
}
}
117 changes: 68 additions & 49 deletions lib/Solver/IncompleteSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,60 +49,68 @@ PartialValidity IncompleteSolver::computeValidity(const Query &query) {
/***/

StagedSolverImpl::StagedSolverImpl(std::unique_ptr<IncompleteSolver> primary,
std::unique_ptr<Solver> secondary)
: primary(std::move(primary)), secondary(std::move(secondary)) {}
std::unique_ptr<Solver> secondary,
QueryPredicate predicate)
: primary(std::move(primary)), secondary(std::move(secondary)),
predicate(predicate) {}

bool StagedSolverImpl::computeTruth(const Query &query, bool &isValid) {
PartialValidity trueResult = primary->computeTruth(query);
if (predicate(query)) {
PartialValidity trueResult = primary->computeTruth(query);

if (trueResult != PValidity::None) {
isValid = (trueResult == PValidity::MustBeTrue);
return true;
if (trueResult != PValidity::None) {
isValid = (trueResult == PValidity::MustBeTrue);
return true;
}
}

return secondary->impl->computeTruth(query, isValid);
}

bool StagedSolverImpl::computeValidity(const Query &query,
PartialValidity &result) {
bool tmp;

switch (primary->computeValidity(query)) {
case PValidity::MustBeTrue:
result = PValidity::MustBeTrue;
break;
case PValidity::MustBeFalse:
result = PValidity::MustBeFalse;
break;
case PValidity::TrueOrFalse:
result = PValidity::TrueOrFalse;
break;
case PValidity::MayBeTrue:
if (secondary->impl->computeTruth(query, tmp)) {

result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
} else {
result = PValidity::MayBeTrue;
}
break;
case PValidity::MayBeFalse:
if (secondary->impl->computeTruth(query.negateExpr(), tmp)) {
result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
} else {
result = PValidity::MayBeFalse;
if (predicate(query)) {
bool tmp;

switch (primary->computeValidity(query)) {
case PValidity::MustBeTrue:
result = PValidity::MustBeTrue;
break;
case PValidity::MustBeFalse:
result = PValidity::MustBeFalse;
break;
case PValidity::TrueOrFalse:
result = PValidity::TrueOrFalse;
break;
case PValidity::MayBeTrue:
if (secondary->impl->computeTruth(query, tmp)) {

result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
} else {
result = PValidity::MayBeTrue;
}
break;
case PValidity::MayBeFalse:
if (secondary->impl->computeTruth(query.negateExpr(), tmp)) {
result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
} else {
result = PValidity::MayBeFalse;
}
break;
default:
if (!secondary->impl->computeValidity(query, result))
return false;
break;
}
break;
default:
if (!secondary->impl->computeValidity(query, result))
return false;
break;
} else {
return secondary->impl->computeValidity(query, result);
}

return true;
}

bool StagedSolverImpl::computeValue(const Query &query, ref<Expr> &result) {
if (primary->computeValue(query, result))
if (predicate(query) && primary->computeValue(query, result))
return true;

return secondary->impl->computeValue(query, result);
Expand All @@ -111,25 +119,28 @@ bool StagedSolverImpl::computeValue(const Query &query, ref<Expr> &result) {
bool StagedSolverImpl::computeInitialValues(
const Query &query, const std::vector<const Array *> &objects,
std::vector<SparseStorage<unsigned char>> &values, bool &hasSolution) {
if (primary->computeInitialValues(query, objects, values, hasSolution))
if (predicate(query) &&
primary->computeInitialValues(query, objects, values, hasSolution))
return true;

return secondary->impl->computeInitialValues(query, objects, values,
hasSolution);
}

bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
std::vector<const Array *> objects;
findSymbolicObjects(query, objects);
std::vector<SparseStorage<unsigned char>> values;

bool hasSolution;

bool primaryResult =
primary->computeInitialValues(query, objects, values, hasSolution);
if (primaryResult && hasSolution) {
result = new InvalidResponse(objects, values);
return true;
if (predicate(query)) {
std::vector<const Array *> objects;
findSymbolicObjects(query, objects);
std::vector<SparseStorage<unsigned char>> values;

bool hasSolution;

bool primaryResult =
primary->computeInitialValues(query, objects, values, hasSolution);
if (primaryResult && hasSolution) {
result = new InvalidResponse(objects, values);
return true;
}
}

return secondary->impl->check(query, result);
Expand All @@ -138,6 +149,14 @@ bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
bool StagedSolverImpl::computeValidityCore(const Query &query,
ValidityCore &validityCore,
bool &isValid) {
if (predicate(query)) {
PartialValidity trueResult = primary->computeTruth(query);

if (trueResult == PValidity::MayBeFalse) {
isValid = false;
return true;
}
}
return secondary->impl->computeValidityCore(query, validityCore, isValid);
}

Expand Down
Loading

0 comments on commit 3ac7f67

Please sign in to comment.