Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Use LHS components as synthesis components #327

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions include/souper/Infer/InstSynthesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ typedef std::pair<unsigned, unsigned> LocVar;
typedef std::pair<LocVar, Inst *> LocInst;

/// A component is a fixed-width instruction kind
/// or created from Origin
struct Component {
Inst::Kind Kind;
unsigned Width;
std::vector<unsigned> OpWidths;
Inst *Origin;
std::vector<Inst *> OriginOps;
};

/// Unsupported components kinds
Expand All @@ -94,35 +97,35 @@ static const std::set<Inst::Kind> UnsupportedCompKinds = {
/// a component of that width is instantiated.
/// Again, note that constants are treated as ordinary inputs
static const std::vector<Component> CompLibrary = {
Component{Inst::Add, 0, {0,0}},
Component{Inst::Sub, 0, {0,0}},
Component{Inst::Mul, 0, {0,0}},
Component{Inst::UDiv, 0, {0,0}},
Component{Inst::SDiv, 0, {0,0}},
Component{Inst::UDivExact, 0, {0,0}},
Component{Inst::SDivExact, 0, {0,0}},
Component{Inst::URem, 0, {0,0}},
Component{Inst::SRem, 0, {0,0}},
Component{Inst::And, 0, {0,0}},
Component{Inst::Or, 0, {0,0}},
Component{Inst::Xor, 0, {0,0}},
Component{Inst::Shl, 0, {0,0}},
Component{Inst::LShr, 0, {0,0}},
Component{Inst::LShrExact, 0, {0,0}},
Component{Inst::AShr, 0, {0,0}},
Component{Inst::AShrExact, 0, {0,0}},
Component{Inst::Select, 0, {1,0,0}},
Component{Inst::Eq, 1, {0,0}},
Component{Inst::Ne, 1, {0,0}},
Component{Inst::Ult, 1, {0,0}},
Component{Inst::Slt, 1, {0,0}},
Component{Inst::Ule, 1, {0,0}},
Component{Inst::Sle, 1, {0,0}},
Component{Inst::Add, 0, {0,0}, 0, {}},
Component{Inst::Sub, 0, {0,0}, 0, {}},
Component{Inst::Mul, 0, {0,0}, 0, {}},
Component{Inst::UDiv, 0, {0,0}, 0, {}},
Component{Inst::SDiv, 0, {0,0}, 0, {}},
Component{Inst::UDivExact, 0, {0,0}, 0, {}},
Component{Inst::SDivExact, 0, {0,0}, 0, {}},
Component{Inst::URem, 0, {0,0}, 0, {}},
Component{Inst::SRem, 0, {0,0}, 0, {}},
Component{Inst::And, 0, {0,0}, 0, {}},
Component{Inst::Or, 0, {0,0}, 0, {}},
Component{Inst::Xor, 0, {0,0}, 0, {}},
Component{Inst::Shl, 0, {0,0}, 0, {}},
Component{Inst::LShr, 0, {0,0}, 0, {}},
Component{Inst::LShrExact, 0, {0,0}, 0, {}},
Component{Inst::AShr, 0, {0,0}, 0, {}},
Component{Inst::AShrExact, 0, {0,0}, 0, {}},
Component{Inst::Select, 0, {1,0,0}, 0, {}},
Component{Inst::Eq, 1, {0,0}, 0, {}},
Component{Inst::Ne, 1, {0,0}, 0, {}},
Component{Inst::Ult, 1, {0,0}, 0, {}},
Component{Inst::Slt, 1, {0,0}, 0, {}},
Component{Inst::Ule, 1, {0,0}, 0, {}},
Component{Inst::Sle, 1, {0,0}, 0, {}},
//
Component{Inst::CtPop, 0, {0}},
Component{Inst::BSwap, 0, {0}},
Component{Inst::Cttz, 0, {0}},
Component{Inst::Ctlz, 0, {0}}
Component{Inst::CtPop, 0, {0}, 0, {}},
Component{Inst::BSwap, 0, {0}, 0, {}},
Component{Inst::Cttz, 0, {0}, 0, {}},
Component{Inst::Ctlz, 0, {0}, 0, {}}
};

class InstSynthesis {
Expand All @@ -132,13 +135,15 @@ class InstSynthesis {
const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
Inst *TargetLHS, Inst *&RHS,
const std::vector<Inst *> &LHSComps,
InstContext &IC, unsigned Timeout);

private:
/// Local references
SMTLIBSolver *LSMTSolver;
const BlockPCs *LBPCs;
const std::vector<InstMapping> *LPCs;
const std::vector<Inst *> *LLHSComps;
InstContext *LIC;
unsigned LTimeout;

Expand Down Expand Up @@ -291,6 +296,7 @@ class InstSynthesis {

/// Helper functions
void filterFixedWidthIntrinsicComps();
Component getCompFromInst(Inst *);
void getInputVars(Inst *I, std::vector<Inst *> &InputVars);
std::string getLocVarStr(const LocVar &Loc, const std::string Prefix="");
LocVar getLocVarFromStr(const std::string &Str);
Expand Down Expand Up @@ -318,7 +324,7 @@ class InstSynthesis {
};

void findCands(Inst *Root, std::vector<Inst *> &Guesses, InstContext &IC,
int Max);
bool WidthMustMatch, bool FilterVars, int Max);

Inst *getInstCopy(Inst *I, InstContext &IC,
std::map<Inst *, Inst *> &InstCache,
Expand Down
11 changes: 8 additions & 3 deletions lib/Extractor/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ static cl::opt<bool> InferNop("souper-infer-nop",
static cl::opt<bool> StressNop("souper-stress-nop",
cl::desc("stress-test big queries in nop synthesis by always performing all of the small queries (slow!) (default=false)"),
cl::init(false));
static cl::opt<int>MaxNops("souper-max-nops",
static cl::opt<int>MaxCands("souper-max-cands",
cl::desc("maximum number of values from the LHS to try to use as the RHS (default=20)"),
cl::init(20));
static cl::opt<bool> InferInts("souper-infer-iN",
Expand Down Expand Up @@ -145,7 +145,8 @@ class BaseSolver : public Solver {

if (InferNop) {
std::vector<Inst *> Guesses;
findCands(LHS, Guesses, IC, MaxNops);
findCands(LHS, Guesses, IC, /*WidthMustMatch=*/true, /*FilterVars=*/false,
MaxCands);

Inst *Ante = IC.getConst(APInt(1, true));
BlockPCs BPCsCopy;
Expand Down Expand Up @@ -206,8 +207,12 @@ class BaseSolver : public Solver {
}

if (InferInsts && SMTSolver->supportsModels()) {
std::vector<Inst *> LHSComps;
findCands(LHS, LHSComps, IC, /*WidthMustMatch=*/false, /*FilterVars=*/true,
MaxCands);
InstSynthesis IS;
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout);
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS,
LHSComps, IC, Timeout);
if (EC || RHS)
return EC;
}
Expand Down
111 changes: 80 additions & 31 deletions lib/Infer/InstSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,
const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
Inst *TargetLHS, Inst *&RHS,
const std::vector<Inst *> &LHSComps,
InstContext &IC, unsigned Timeout) {
std::error_code EC;

// init local refs
LSMTSolver = SMTSolver;
LBPCs = &BPCs;
LPCs = &PCs;
LLHSComps = &LHSComps;
LIC = &IC;
LTimeout = Timeout;

Expand All @@ -91,7 +93,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,

if (DebugLevel > 0) {
llvm::outs() << "; starting synthesis for LHS\n";
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context);
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context, true);
if (DebugLevel > 2)
printInitInfo();
}
Expand Down Expand Up @@ -322,7 +324,7 @@ void InstSynthesis::setCompLibrary() {
for (auto KindStr : splitString(CmdUserCompKinds.c_str())) {
Inst::Kind K = Inst::getKind(KindStr);
if (KindStr == Inst::getKindName(Inst::Const)) // Special case
InitConstComps.push_back(Component{Inst::Const, 0, {}});
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
else if (K == Inst::ZExt || K == Inst::SExt || K == Inst::Trunc)
report_fatal_error("don't use zext/sext/trunc explicitly");
else if (K == Inst::None)
Expand All @@ -338,13 +340,13 @@ void InstSynthesis::setCompLibrary() {
InitComps.push_back(Comp);
} else {
InitComps = CompLibrary;
InitConstComps.push_back(Component{Inst::Const, 0, {}});
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
}
for (auto const &In : Inputs) {
if (In->Width == DefaultWidth)
continue;
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}});
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}});
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}, 0, {}});
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}, 0, {}});
}
// Second, for each input/constant create a component of DefaultWidth
for (auto &Comp : InitComps) {
Expand All @@ -362,7 +364,11 @@ void InstSynthesis::setCompLibrary() {
}
// Third, create one trunc comp to match the output width if necessary
if (LHS->Width < DefaultWidth)
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}});
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}, 0, {}});
// Finally, add LHS components (if provided) directly to Comps,
// their widths are already initialized.
for (auto I : *LLHSComps)
Comps.push_back(getCompFromInst(I));
}

void InstSynthesis::initInputVars(InstContext &IC) {
Expand Down Expand Up @@ -438,10 +444,11 @@ void InstSynthesis::filterFixedWidthIntrinsicComps() {

void InstSynthesis::initComponents(InstContext &IC) {
for (unsigned J = 0; J < Comps.size(); ++J) {
auto const &Comp = Comps[J];
auto &Comp = Comps[J];
std::string LocVarStr;
// First, init component inputs
std::vector<Inst *> CompOps;
std::map<Inst *, Inst *> OpsReplacements;
std::vector<LocVar> OpsLocVar;
for (unsigned K = 0; K < Comp.OpWidths.size(); ++K) {
LocVar In = std::make_pair(J+1, K+1);
Expand All @@ -464,6 +471,11 @@ void InstSynthesis::initComponents(InstContext &IC) {
CompInstMap[In] = OpInst;
CompOps.push_back(OpInst);
OpsLocVar.push_back(In);
// Update OpsReplacements
if (Comp.Origin) {
assert(Comp.OriginOps.size());
OpsReplacements.insert(std::make_pair(Comp.OriginOps[K], OpInst));
}
}
// Store all input locations
CompOpLocVars.push_back(OpsLocVar);
Expand All @@ -479,13 +491,23 @@ void InstSynthesis::initComponents(InstContext &IC) {
// Third, instantiate the component (aka Inst)
assert(Comp.Width && "comp width not set");
Inst *CompInst;
if (Comp.Kind == Inst::Select) {
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
} else {
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
if (Comp.Origin) {
assert(Comp.OriginOps.size() == CompOps.size());
CompInst = replaceVars(Comp.Origin, *LIC, OpsReplacements);
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
// Update LHS component
Comp.Origin = CompInst;
Comp.OriginOps = CompOps;
} else {
if (Comp.Kind == Inst::Select) {
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
} else {
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
}
}
// Update CompInstMap map with concrete Inst
CompInstMap[Out] = CompInst;
Expand Down Expand Up @@ -517,12 +539,14 @@ void InstSynthesis::printInitInfo() {
llvm::outs() << "N: " << N << ", M: " << M << "\n";
llvm::outs() << "default width: " << DefaultWidth << "\n";
llvm::outs() << "output width: " << LHS->Width << "\n";
llvm::outs() << "component library: ";
llvm::outs() << "component library: " << Comps.size() << "\n";
for (auto const &Comp : Comps) {
llvm::outs() << Inst::getKindName(Comp.Kind) << " (" << Comp.Width << ", { ";
for (auto const &Width : Comp.OpWidths)
llvm::outs() << Width << " ";
llvm::outs() << "}); ";
llvm::outs() << "})\n";
if (Comp.Origin)
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context, true);
}
if (Comps.size())
llvm::outs() << "\n";
Expand Down Expand Up @@ -980,19 +1004,35 @@ Inst *InstSynthesis::createInstFromWiring(
llvm::outs() << "- creating inst " << Inst::getKindName(Comp.Kind)
<< ", width " << Comp.Width << "\n";
llvm::outs() << "before junk removal:\n";
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
Context);
if (Comp.Origin)
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context);
else
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
Context);
}
// Sanity checks
if (Ops.size() == 2 && Ops[0]->K == Inst::Const && Ops[1]->K == Inst::Const)
report_fatal_error("inst operands are constants!");
assert(Comp.Width == 1 || Comp.Width == DefaultWidth ||
Comp.Width == LHS->Width);
// Create instruction
if (Comp.Kind == Inst::Select) {
// Instruction is a LHS component
if (Comp.Origin) {
assert(Comp.OriginOps.size() == Ops.size());
std::map<Inst *, Inst *> OpsReplacements;
for (unsigned J = 0; J < Ops.size(); ++J)
OpsReplacements.insert(std::make_pair(Comp.OriginOps[J], Ops[J]));
Inst *Copy = replaceVars(Comp.Origin, *LIC, OpsReplacements);
// Update ops
Ops = Copy->Ops;
}
// Create instruction from a component
if (Comp.Kind == Inst::Phi) {
assert(Comp.Origin && "Phi support for LHS components only");
return IC.getPhi(Comp.Origin->B, Ops);
} else if (Comp.Kind == Inst::Select) {
Ops[0] = IC.getInst(Inst::Trunc, 1, {Ops[0]});
return createCleanInst(Comp.Kind, Comp.Width, Ops, IC);
} if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) {
} else if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) {
Inst *Ret = createCleanInst(Comp.Kind, Comp.Width, Ops, IC);
return IC.getInst(Inst::ZExt, DefaultWidth, {Ret});
} else
Expand Down Expand Up @@ -1214,6 +1254,18 @@ Inst *InstSynthesis::createCleanInst(Inst::Kind Kind, unsigned Width,
return IC.getInst(Kind, Width, Ops);
}

Component InstSynthesis::getCompFromInst(Inst *I) {
std::vector<Inst *> IV;
getInputVars(I, IV);
sort(IV.begin(), IV.end());
IV.erase(unique(IV.begin(), IV.end()), IV.end());
std::vector<unsigned> OpWidths;
for (auto In : IV)
OpWidths.push_back(In->Width);

return Component{I->K, I->Width, OpWidths, I, IV};
}

void InstSynthesis::getInputVars(Inst *I, std::vector<Inst *> &InputVars) {
if (I->K == Inst::Var)
InputVars.push_back(I);
Expand Down Expand Up @@ -1456,7 +1508,7 @@ void InstSynthesis::constrainConstWiring(const Inst *Cand,
}

void findCands(Inst *Root, std::vector<Inst *> &Guesses, InstContext &IC,
int Max) {
bool WidthMustMatch, bool FilterVars, int Max) {
// breadth-first search
std::set<Inst *> Visited;
std::queue<std::tuple<Inst *,int>> Q;
Expand All @@ -1472,19 +1524,16 @@ void findCands(Inst *Root, std::vector<Inst *> &Guesses, InstContext &IC,
for (auto Op : I->Ops)
Q.push(std::make_tuple(Op, Benefit));
}
if (Benefit > 1 && I->Width == Root->Width && I->Available)
if (Benefit > 1 && I->Available && I->K != Inst::Const
&& I->K != Inst::UntypedConst) {
if (WidthMustMatch && I->Width != Root->Width)
continue;
if (FilterVars && I->K == Inst::Var)
continue;
Guesses.emplace_back(I);
// TODO: run experiments and see if it's worth doing these
if (0) {
if (Benefit > 2 && I->Width > Root->Width)
Guesses.emplace_back(IC.getInst(Inst::Trunc, Root->Width, {I}));
if (Benefit > 2 && I->Width < Root->Width) {
Guesses.emplace_back(IC.getInst(Inst::SExt, Root->Width, {I}));
Guesses.emplace_back(IC.getInst(Inst::ZExt, Root->Width, {I}));
}
if (Guesses.size() >= Max)
return;
}
if (Guesses.size() >= Max)
return;
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions test/Infer/four-adds.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; REQUIRES: solver, solver-model

; -souper-synthesis-comps=const is just a hack to avoid the initialization of the whole component library
; RUN: %souper-check %solver -infer-rhs -souper-infer-inst -souper-synthesis-comps=const -souper-synthesis-ignore-cost %s > %t1
; RUN: %FileCheck %s < %t1

; CHECK: result %4

%0:i32 = var
%1:i32 = add 1:i32, %0
%2:i32 = add 1:i32, %1
%3:i32 = add 1:i32, %2
%4:i32 = add 1:i32, %3
infer %4
Loading