Skip to content

Commit

Permalink
Generalize constants
Browse files Browse the repository at this point in the history
  • Loading branch information
manasij7479 committed Dec 2, 2020
1 parent 63c8fc3 commit b06aef9
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 5 deletions.
6 changes: 6 additions & 0 deletions include/souper/Infer/EnumerativeSynthesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace souper {

class EnumerativeSynthesis {
public:
EnumerativeSynthesis();

// Synthesize an instruction from the specification in LHS
std::error_code synthesize(SMTLIBSolver *SMTSolver,
const BlockPCs &BPCs,
Expand All @@ -38,6 +40,10 @@ class EnumerativeSynthesis {
bool CheckAllGuesses,
InstContext &IC, unsigned Timeout);

std::vector<Inst *>
generateExprs(InstContext &IC, size_t CountLimit,
std::vector<Inst *> Vars, size_t Width);

};
}

Expand Down
45 changes: 43 additions & 2 deletions lib/Infer/EnumerativeSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ extern unsigned DebugLevel;
using namespace souper;
using namespace llvm;

static const std::vector<Inst::Kind> UnaryOperators = {
Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz, Inst::Freeze
static std::vector<Inst::Kind> UnaryOperators = {
Inst::CtPop, Inst::BSwap, Inst::BitReverse, Inst::Cttz, Inst::Ctlz
};

static const std::vector<Inst::Kind> BinaryOperators = {
Expand Down Expand Up @@ -91,6 +91,9 @@ namespace {
static cl::opt<bool> IgnoreCost("souper-enumerative-synthesis-ignore-cost",
cl::desc("Ignore cost of RHSes -- just generate them. (default=false)"),
cl::init(false));
static cl::opt<bool> SynFreeze("souper-synthesize-freeze",
cl::desc("Generate Freeze (default=true)"),
cl::init(true));
static cl::opt<unsigned> MaxLHSCands("souper-max-lhs-cands",
cl::desc("Gather at most this many values from a LHS to use as synthesis inputs (default=8)"),
cl::init(8));
Expand Down Expand Up @@ -881,3 +884,41 @@ EnumerativeSynthesis::synthesize(SMTLIBSolver *SMTSolver,

return EC;
}

EnumerativeSynthesis::EnumerativeSynthesis() {
if (SynFreeze) {
UnaryOperators.push_back(Inst::Freeze);
}
}

std::vector<Inst *>
EnumerativeSynthesis::generateExprs(InstContext &IC, size_t CountLimit,
std::vector<Inst *> Vars, size_t Width) {
MaxNumInstructions = CountLimit;

std::set<Inst*> Visited;
std::vector<PruneFunc> PruneFuncs = { [&Visited](Inst *I, std::vector<Inst*> &ReservedInsts) {
return CountPrune(I, ReservedInsts, Visited);
}};
auto PruneCallback = MkPruneFunc(PruneFuncs);

std::vector<Inst *> Guesses;

int TooExpensive = CountLimit + 1;

for (auto I : Vars) {
if (I->Width == Width)
addGuess(I, Width, IC, TooExpensive, Guesses, TooExpensive);
}

auto Generate = [&Guesses](Inst *Guess) {
Guesses.push_back(Guess);
return true;
};

getGuesses(Vars, Width, TooExpensive, IC, nullptr,
nullptr, TooExpensive, PruneCallback, Generate);

return Guesses;
}

11 changes: 11 additions & 0 deletions test/Generalize/symbolize.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; REQUIRES: solver, synthesis
; RUN: %generalize -symbolize --souper-synthesize-freeze=false --generalization-num-results=2 %s | %souper-check > %t
; RUN: %FileCheck %s < %t

%x:i8 = var
%foo = add %x, 2
%bar = sub %foo, %x
infer %bar
result 2:i8
;CHECK: LGTM
;CHECK: LGTM
128 changes: 125 additions & 3 deletions tools/generalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "llvm/Support/KnownBits.h"

#include "souper/Infer/Preconditions.h"

#include "souper/Infer/EnumerativeSynthesis.h"
#include "souper/Inst/InstGraph.h"
#include "souper/Parser/Parser.h"
#include "souper/Tool/GetSolver.h"
Expand Down Expand Up @@ -31,11 +31,21 @@ static llvm::cl::opt<bool> RemoveLeaf("remove-leaf",
"(default=false)"),
llvm::cl::init(false));

static llvm::cl::opt<bool> SymbolizeConstant("symbolize",
llvm::cl::desc("Try to replace a concrete constant with a symbolic constant."
"(default=false)"),
llvm::cl::init(false));

static llvm::cl::opt<bool> FixIt("fixit",
llvm::cl::desc("Given an invalid optimization, generate a valid one."
"(default=false)"),
llvm::cl::init(false));

static cl::opt<size_t> NumResults("generalization-num-results",
cl::desc("Number of Generalization Results"),
cl::init(5));


void Generalize(InstContext &IC, Solver *S, ParsedReplacement Input) {
bool FoundWP = false;
std::vector<std::map<Inst *, llvm::KnownBits>> Results;
Expand All @@ -54,10 +64,120 @@ void Generalize(InstContext &IC, Solver *S, ParsedReplacement Input) {
}
}

void SymbolizeAndGeneralize(InstContext &IC,
Solver *S, ParsedReplacement Input) {
std::vector<Inst *> LHSConsts, RHSConsts;
auto Pred = [](Inst *I) {return I->K == Inst::Const;};
findInsts(Input.Mapping.LHS, LHSConsts, Pred);
findInsts(Input.Mapping.RHS, RHSConsts, Pred);

if (LHSConsts.size() != 1 || RHSConsts.size() != 1) {
return;
// TODO: Relax this restriction later
}

// Replace LHSConst[0] with a new variable and RHSConst[0]
// with a synthesized function.

auto FakeConst = IC.createVar(LHSConsts[0]->Width, "fakeconst");

// Does it makes sense for the expression to depend on other variables?
// If yes, expand the third argument to include inputs
EnumerativeSynthesis ES;
auto Guesses = ES.generateExprs(IC, 2, {FakeConst},
RHSConsts[0]->Width);

// Discarding guesses with symbolic constants
// Find a way to avoid this
// Here is the problem that needs to be solved for this:
// Given f and g\C, find N and P such that P -> (f == c\N)
// Weakest possible P is desirable.

std::vector<Inst *> Filtered;
for (auto Guess : Guesses) {
std::set<Inst *> ConstSet;
std::map <Inst *, llvm::APInt> ResultConstMap;
souper::getConstants(Guess, ConstSet);
if (ConstSet.empty()) {
Filtered.push_back(Guess);
}
}
std::swap(Guesses, Filtered);

std::vector<std::vector<std::map<Inst *, llvm::KnownBits>>>
Preconditions;

std::map<Inst *, Inst *> InstCache{{LHSConsts[0], FakeConst}};
std::map<Block *, Block *> BlockCache;
std::map<Inst *, APInt> ConstMap;
auto LHS = getInstCopy(Input.Mapping.LHS, IC, InstCache,
BlockCache, &ConstMap, false);
for (auto &Guess : Guesses) {
std::map<Inst *, Inst *> InstCache{{RHSConsts[0], Guess}};
auto RHS = getInstCopy(Input.Mapping.RHS, IC, InstCache,
BlockCache, &ConstMap, false);
std::vector<std::map<Inst *, llvm::KnownBits>> Results;
bool FoundWP = false;
InstMapping Mapping(LHS, RHS);
S->abstractPrecondition(Input.BPCs, Input.PCs, Mapping, IC, FoundWP, Results);

Preconditions.push_back(Results);
if (!FoundWP) {
Guess = nullptr; // TODO: Better failure indicator
} else {
Guess = RHS;
}
}

std::vector<size_t> Idx;
std::vector<int> Utility;
for (size_t i = 0; i < Guesses.size(); ++i) {
Idx.push_back(i);
}
for (size_t i = 0; i < Preconditions.size(); ++i) {
Utility.push_back(0);
if (!Guesses[i]) continue;
if (Preconditions[i].empty()) {
Utility[i] = 1000; // High magic number
}

for (auto V : Preconditions[i]) {
for (auto P : V) {
auto W = P.second.getBitWidth();
Utility[i] += (W - P.second.Zero.countPopulation());
Utility[i] += (W - P.second.One.countPopulation());
}
}
}

std::sort(Idx.begin(), Idx.end(), [&Utility](size_t a, size_t b) {
return Utility[a] > Utility[b];
});
for (size_t i = 0; i < std::min(Idx.size(), NumResults.getValue()); ++i) {
if (Preconditions[Idx[i]].empty()) {
ReplacementContext RC;
auto LHSStr = RC.printInst(LHS, llvm::outs(), true);
llvm::outs() << "infer " << LHSStr << "\n";
auto RHSStr = RC.printInst(Guesses[Idx[i]], llvm::outs(), true);
llvm::outs() << "result " << RHSStr << "\n\n";
}
for (auto Results : Preconditions[Idx[i]]) {
for (auto Pair : Results) {
Pair.first->KnownOnes = Pair.second.One;
Pair.first->KnownZeros = Pair.second.Zero;
}
ReplacementContext RC;
auto LHSStr = RC.printInst(LHS, llvm::outs(), true);
llvm::outs() << "infer " << LHSStr << "\n";
auto RHSStr = RC.printInst(Guesses[Idx[i]], llvm::outs(), true);
llvm::outs() << "result " << RHSStr << "\n\n";
}
}
}

// TODO: Return modified instructions instead of just printing out
void RemoveLeafAndGeneralize(InstContext &IC,
Solver *S, ParsedReplacement Input) {

if (DebugLevel > 1) {
llvm::errs() << "Attempting to generalize by removing leaf.\n";
}
Expand Down Expand Up @@ -155,7 +275,9 @@ int main(int argc, char **argv) {
RemoveLeafAndGeneralize(IC, S.get(), Input);
}
// if (EviscerateRoot) {...}
// if (SymbolizeConstant) {...}
if (SymbolizeConstant) {
SymbolizeAndGeneralize(IC, S.get(), Input);
}
// if (LiberateWidth) {...}
}

Expand Down

0 comments on commit b06aef9

Please sign in to comment.