diff --git a/include/souper/Infer/EnumerativeSynthesis.h b/include/souper/Infer/EnumerativeSynthesis.h index c15b15d14..84f8ad659 100644 --- a/include/souper/Infer/EnumerativeSynthesis.h +++ b/include/souper/Infer/EnumerativeSynthesis.h @@ -30,6 +30,8 @@ namespace souper { class EnumerativeSynthesis { public: + EnumerativeSynthesis(); + // Synthesize an instruction from the specification in LHS std::error_code synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, @@ -38,6 +40,10 @@ class EnumerativeSynthesis { bool CheckAllGuesses, InstContext &IC, unsigned Timeout); + std::vector + generateExprs(InstContext &IC, size_t CountLimit, + std::vector Vars, size_t Width); + }; } diff --git a/lib/Infer/EnumerativeSynthesis.cpp b/lib/Infer/EnumerativeSynthesis.cpp index 66bce8352..4ee94d45f 100644 --- a/lib/Infer/EnumerativeSynthesis.cpp +++ b/lib/Infer/EnumerativeSynthesis.cpp @@ -33,8 +33,8 @@ extern unsigned DebugLevel; using namespace souper; using namespace llvm; -static const std::vector UnaryOperators = { - Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz, Inst::Freeze +static std::vector UnaryOperators = { + Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz }; static const std::vector BinaryOperators = { @@ -91,6 +91,9 @@ namespace { static cl::opt IgnoreCost("souper-enumerative-synthesis-ignore-cost", cl::desc("Ignore cost of RHSes -- just generate them. (default=false)"), cl::init(false)); + static cl::opt SynFreeze("souper-synthesize-freeze", + cl::desc("Generate Freeze (default=true)"), + cl::init(true)); static cl::opt MaxLHSCands("souper-max-lhs-cands", cl::desc("Gather at most this many values from a LHS to use as synthesis inputs (default=8)"), cl::init(8)); @@ -881,3 +884,41 @@ EnumerativeSynthesis::synthesize(SMTLIBSolver *SMTSolver, return EC; } + +EnumerativeSynthesis::EnumerativeSynthesis() { + if (SynFreeze) { + UnaryOperators.push_back(Inst::Freeze); + } +} + +std::vector +EnumerativeSynthesis::generateExprs(InstContext &IC, size_t CountLimit, + std::vector Vars, size_t Width) { + MaxNumInstructions = CountLimit; + + std::set Visited; + std::vector PruneFuncs = { [&Visited](Inst *I, std::vector &ReservedInsts) { + return CountPrune(I, ReservedInsts, Visited); + }}; + auto PruneCallback = MkPruneFunc(PruneFuncs); + + std::vector Guesses; + + int TooExpensive = CountLimit + 1; + + for (auto I : Vars) { + if (I->Width == Width) + addGuess(I, Width, IC, TooExpensive, Guesses, TooExpensive); + } + + auto Generate = [&Guesses](Inst *Guess) { + Guesses.push_back(Guess); + return true; + }; + + getGuesses(Vars, Width, TooExpensive, IC, nullptr, + nullptr, TooExpensive, PruneCallback, Generate); + + return Guesses; +} + diff --git a/test/Generalize/symbolize.opt b/test/Generalize/symbolize.opt new file mode 100644 index 000000000..7c52b493a --- /dev/null +++ b/test/Generalize/symbolize.opt @@ -0,0 +1,11 @@ +; REQUIRES: solver, synthesis +; RUN: %generalize -symbolize --souper-synthesize-freeze=false --generalization-num-results=2 %s | %souper-check > %t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%foo = add %x, 2 +%bar = sub %foo, %x +infer %bar +result 2:i8 +;CHECK: LGTM +;CHECK: LGTM diff --git a/tools/generalize.cpp b/tools/generalize.cpp index 865125c6f..92393e101 100644 --- a/tools/generalize.cpp +++ b/tools/generalize.cpp @@ -3,7 +3,7 @@ #include "llvm/Support/KnownBits.h" #include "souper/Infer/Preconditions.h" - +#include "souper/Infer/EnumerativeSynthesis.h" #include "souper/Inst/InstGraph.h" #include "souper/Parser/Parser.h" #include "souper/Tool/GetSolver.h" @@ -31,11 +31,21 @@ static llvm::cl::opt RemoveLeaf("remove-leaf", "(default=false)"), llvm::cl::init(false)); +static llvm::cl::opt SymbolizeConstant("symbolize", + llvm::cl::desc("Try to replace a concrete constant with a symbolic constant." + "(default=false)"), + llvm::cl::init(false)); + static llvm::cl::opt FixIt("fixit", llvm::cl::desc("Given an invalid optimization, generate a valid one." "(default=false)"), llvm::cl::init(false)); +static cl::opt NumResults("generalization-num-results", + cl::desc("Number of Generalization Results"), + cl::init(5)); + + void Generalize(InstContext &IC, Solver *S, ParsedReplacement Input) { bool FoundWP = false; std::vector> Results; @@ -54,10 +64,120 @@ void Generalize(InstContext &IC, Solver *S, ParsedReplacement Input) { } } +void SymbolizeAndGeneralize(InstContext &IC, + Solver *S, ParsedReplacement Input) { + std::vector LHSConsts, RHSConsts; + auto Pred = [](Inst *I) {return I->K == Inst::Const;}; + findInsts(Input.Mapping.LHS, LHSConsts, Pred); + findInsts(Input.Mapping.RHS, RHSConsts, Pred); + + if (LHSConsts.size() != 1 || RHSConsts.size() != 1) { + return; + // TODO: Relax this restriction later + } + + // Replace LHSConst[0] with a new variable and RHSConst[0] + // with a synthesized function. + + auto FakeConst = IC.createVar(LHSConsts[0]->Width, "fakeconst"); + + // Does it makes sense for the expression to depend on other variables? + // If yes, expand the third argument to include inputs + EnumerativeSynthesis ES; + auto Guesses = ES.generateExprs(IC, 2, {FakeConst}, + RHSConsts[0]->Width); + + // Discarding guesses with symbolic constants + // Find a way to avoid this + // Here is the problem that needs to be solved for this: + // Given f and g\C, find N and P such that P -> (f == c\N) + // Weakest possible P is desirable. + + std::vector Filtered; + for (auto Guess : Guesses) { + std::set ConstSet; + std::map ResultConstMap; + souper::getConstants(Guess, ConstSet); + if (ConstSet.empty()) { + Filtered.push_back(Guess); + } + } + std::swap(Guesses, Filtered); + + std::vector>> + Preconditions; + + std::map InstCache{{LHSConsts[0], FakeConst}}; + std::map BlockCache; + std::map ConstMap; + auto LHS = getInstCopy(Input.Mapping.LHS, IC, InstCache, + BlockCache, &ConstMap, false); + for (auto &Guess : Guesses) { + std::map InstCache{{RHSConsts[0], Guess}}; + auto RHS = getInstCopy(Input.Mapping.RHS, IC, InstCache, + BlockCache, &ConstMap, false); + std::vector> Results; + bool FoundWP = false; + InstMapping Mapping(LHS, RHS); + S->abstractPrecondition(Input.BPCs, Input.PCs, Mapping, IC, FoundWP, Results); + + Preconditions.push_back(Results); + if (!FoundWP) { + Guess = nullptr; // TODO: Better failure indicator + } else { + Guess = RHS; + } + } + + std::vector Idx; + std::vector Utility; + for (size_t i = 0; i < Guesses.size(); ++i) { + Idx.push_back(i); + } + for (size_t i = 0; i < Preconditions.size(); ++i) { + Utility.push_back(0); + if (!Guesses[i]) continue; + if (Preconditions[i].empty()) { + Utility[i] = 1000; // High magic number + } + + for (auto V : Preconditions[i]) { + for (auto P : V) { + auto W = P.second.getBitWidth(); + Utility[i] += (W - P.second.Zero.countPopulation()); + Utility[i] += (W - P.second.One.countPopulation()); + } + } + } + + std::sort(Idx.begin(), Idx.end(), [&Utility](size_t a, size_t b) { + return Utility[a] > Utility[b]; + }); + for (size_t i = 0; i < std::min(Idx.size(), NumResults.getValue()); ++i) { + if (Preconditions[Idx[i]].empty()) { + ReplacementContext RC; + auto LHSStr = RC.printInst(LHS, llvm::outs(), true); + llvm::outs() << "infer " << LHSStr << "\n"; + auto RHSStr = RC.printInst(Guesses[Idx[i]], llvm::outs(), true); + llvm::outs() << "result " << RHSStr << "\n\n"; + } + for (auto Results : Preconditions[Idx[i]]) { + for (auto Pair : Results) { + Pair.first->KnownOnes = Pair.second.One; + Pair.first->KnownZeros = Pair.second.Zero; + } + ReplacementContext RC; + auto LHSStr = RC.printInst(LHS, llvm::outs(), true); + llvm::outs() << "infer " << LHSStr << "\n"; + auto RHSStr = RC.printInst(Guesses[Idx[i]], llvm::outs(), true); + llvm::outs() << "result " << RHSStr << "\n\n"; + } + } +} + // TODO: Return modified instructions instead of just printing out void RemoveLeafAndGeneralize(InstContext &IC, Solver *S, ParsedReplacement Input) { - if (DebugLevel > 1) { llvm::errs() << "Attempting to generalize by removing leaf.\n"; } @@ -155,7 +275,9 @@ int main(int argc, char **argv) { RemoveLeafAndGeneralize(IC, S.get(), Input); } // if (EviscerateRoot) {...} - // if (SymbolizeConstant) {...} + if (SymbolizeConstant) { + SymbolizeAndGeneralize(IC, S.get(), Input); + } // if (LiberateWidth) {...} }