Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for pre-visitor callbacks. #84

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ std::unique_ptr<Module> createModuleExample(LLVMContext &context) {

struct VisitorInnermost {
int counter = 0;
raw_ostream *out = nullptr;
};

struct VisitorNest {
Expand All @@ -177,6 +178,13 @@ struct llvm_dialects::VisitorPayloadProjection<VisitorNest, raw_ostream> {
static raw_ostream &project(VisitorNest &nest) { return *nest.out; }
};

template <>
struct llvm_dialects::VisitorPayloadProjection<VisitorInnermost, raw_ostream> {
static raw_ostream &project(VisitorInnermost &innerMost) {
return *innerMost.out;
}
};

LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorContainer, nest)
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner)

Expand Down Expand Up @@ -215,8 +223,8 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
b.addSet(complexSet, [](VisitorNest &self, llvm::Instruction &op) {
assert((op.getOpcode() == Instruction::Ret ||
(isa<IntrinsicInst>(&op) &&
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
Intrinsic::umin)) &&
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
Intrinsic::umin)) &&
"Unexpected operation detected while visiting OpSet!");

if (op.getOpcode() == Instruction::Ret) {
Expand Down Expand Up @@ -249,10 +257,36 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) {
out << "visiting umax intrinsic: " << umax << '\n';
});
b.addPreVisitCallback<xd::ReadOp, xd::WriteOp>(
[](raw_ostream &out, llvm::Instruction &inst) {
if (isa<xd::ReadOp>(inst))
out << "Will visit ReadOp next: " << inst << '\n';
else if (isa<xd::WriteOp>(inst))
out << "Will visit WriteOp next: " << inst << '\n';
else
llvm_unreachable("Unexpected op!");
});

b.addPreVisitCallback([](raw_ostream &out, Instruction &inst) {
if (isa<IntrinsicInst>(inst))
out << "Pre-visiting intrinsic instruction: " << inst << '\n';
});
});
b.nest<VisitorInnermost>([](VisitorBuilder<VisitorInnermost> &b) {
b.add<xd::ITruncOp>([](VisitorInnermost &inner,
xd::ITruncOp &op) { inner.counter++; });
b.add<xd::ITruncOp>(
[](VisitorInnermost &inner, xd::ITruncOp &op) {
inner.counter++;
*inner.out
<< "Counter after visiting ITruncOp: " << inner.counter
<< '\n';
});

b.addPreVisitCallback<xd::ITruncOp>(
[](VisitorInnermost &inner, Instruction &op) {
if (isa<xd::ITruncOp>(op))
*inner.out << "Counter before visiting ITruncOp: "
<< inner.counter << '\n';
});
});
})
.setStrategy(rpot ? VisitorStrategy::ReversePostOrder
Expand All @@ -267,6 +301,7 @@ void exampleVisit(Module &module) {

VisitorContainer container;
container.nest.out = &outs();
container.nest.inner.out = &outs();
visitor.visit(container, module);

outs() << "inner.counter = " << container.nest.inner.counter << '\n';
Expand Down
7 changes: 6 additions & 1 deletion include/llvm-dialects/Dialect/OpSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class OpSet final {
// arguments.
template <typename... OpTs> static const OpSet get() {
static OpSet set;
(... && appendT<OpTs>(set));
(void)(... && appendT<OpTs>(set));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine I suppose, but see #87

return set;
}

Expand Down Expand Up @@ -153,6 +153,11 @@ class OpSet final {
return isMatchingDialectOp(func.getName());
}

bool empty() const {
return m_coreOpcodes.empty() && m_intrinsicIDs.empty() &&
m_dialectOps.empty();
}

// -------------------------------------------------------------
// Convenience getters to access the internal data structures.
// -------------------------------------------------------------
Expand Down
50 changes: 47 additions & 3 deletions include/llvm-dialects/Dialect/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,27 @@ class VisitorTemplate {
friend class VisitorBuilderBase;

public:
enum class VisitorCallbackType : uint8_t { PreVisit = 0, Visit = 1 };

void setStrategy(VisitorStrategy strategy);
void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data,
VisitorHandler::Projection projection);
VisitorHandler::Projection projection,
VisitorCallbackType visitorCallbackTy = VisitorCallbackType::Visit);

private:
void storeHandlersInOpMap(const VisitorKey &key, unsigned handlerIdx,
VisitorCallbackType callbackTy);

VisitorStrategy m_strategy = VisitorStrategy::Default;
std::vector<PayloadProjection> m_projections;
std::vector<VisitorHandler> m_handlers;
OpMap<llvm::SmallVector<unsigned>> m_opMap;

struct Handlers {
llvm::SmallVector<unsigned> PreVisitHandlers;
llvm::SmallVector<unsigned> VisitHandlers;
};

OpMap<Handlers> m_opMap;
};

/// @brief Base class for VisitorBuilders
Expand Down Expand Up @@ -279,6 +291,9 @@ class VisitorBuilderBase {

void setStrategy(VisitorStrategy strategy);

void addPreVisitCallback(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data);

void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data);

VisitorBase build();
Expand Down Expand Up @@ -307,6 +322,11 @@ class VisitorBase {
class BuildHelper;
using HandlerRange = std::pair<unsigned, unsigned>;

struct MappedHandlers {
HandlerRange PreVisitCallbacks;
HandlerRange VisitCallbacks;
};

void call(HandlerRange handlers, void *payload,
llvm::Instruction &inst) const;
VisitorResult call(const VisitorHandler &handler, void *payload,
Expand All @@ -319,7 +339,7 @@ class VisitorBase {
VisitorStrategy m_strategy;
std::vector<PayloadProjection> m_projections;
std::vector<VisitorHandler> m_handlers;
OpMap<HandlerRange> m_opMap;
OpMap<MappedHandlers> m_opMap;
};

} // namespace detail
Expand Down Expand Up @@ -386,6 +406,20 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
return *this;
}

VisitorBuilder &
addPreVisitCallback(const OpSet &opSet,
VisitorResult (*fn)(PayloadT &, llvm::Instruction &I)) {
addPreVisitCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

template <typename... OpTs>
VisitorBuilder &addPreVisitCallback(void (*fn)(PayloadT &,
llvm::Instruction &I)) {
addPreVisitCase(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }

template <typename OpT>
Expand Down Expand Up @@ -510,6 +544,16 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data);
}

template <typename ReturnT>
void addPreVisitCase(detail::VisitorKey key,
ReturnT (*fn)(PayloadT &, llvm::Instruction &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::addPreVisitCallback(
key, &VisitorBuilder::setForwarder<ReturnT>, data);
}

template <typename OpT, typename ReturnT>
void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) {
detail::VisitorCallbackData data{};
Expand Down
106 changes: 74 additions & 32 deletions lib/Dialect/Visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"

Expand All @@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) {
m_strategy = strategy;
}

void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data,
VisitorHandler::Projection projection) {
VisitorHandler handler;
handler.callback = fn;
handler.data = data;
handler.projection = projection;

m_handlers.emplace_back(handler);
void VisitorTemplate::storeHandlersInOpMap(
const VisitorKey &key, unsigned handlerIdx,
VisitorCallbackType visitorCallbackTy) {
const auto HandlerList =
[&](const OpDescription &opDescription) -> llvm::SmallVector<unsigned> & {
if (visitorCallbackTy == VisitorCallbackType::PreVisit)
return m_opMap[opDescription].PreVisitHandlers;

const unsigned handlerIdx = m_handlers.size() - 1;
return m_opMap[opDescription].VisitHandlers;
};

if (key.m_kind == VisitorKey::Kind::Intrinsic) {
m_opMap[OpDescription::fromIntrinsic(key.m_intrinsicId)].push_back(
handlerIdx);
HandlerList(OpDescription::fromIntrinsic(key.m_intrinsicId))
.push_back(handlerIdx);
} else if (key.m_kind == VisitorKey::Kind::OpDescription) {
const OpDescription *opDesc = key.m_description;

if (opDesc->isCoreOp()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx);
HandlerList(OpDescription::fromCoreOp(op)).push_back(handlerIdx);
} else if (opDesc->isIntrinsic()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx);
HandlerList(OpDescription::fromIntrinsic(op)).push_back(handlerIdx);
} else {
m_opMap[*opDesc].push_back(handlerIdx);
HandlerList(*opDesc).push_back(handlerIdx);
}
} else if (key.m_kind == VisitorKey::Kind::OpSet) {
const OpSet *opSet = key.m_set;

if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty()) {
// This adds a handler for every stored op.
// Note: should be used with caution.
for (auto it : m_opMap)
it.second.PreVisitHandlers.push_back(handlerIdx);

return;
}

for (unsigned opcode : opSet->getCoreOpcodes())
m_opMap[OpDescription::fromCoreOp(opcode)].push_back(handlerIdx);
HandlerList(OpDescription::fromCoreOp(opcode)).push_back(handlerIdx);

for (unsigned intrinsicID : opSet->getIntrinsicIDs())
m_opMap[OpDescription::fromIntrinsic(intrinsicID)].push_back(handlerIdx);
HandlerList(OpDescription::fromIntrinsic(intrinsicID))
.push_back(handlerIdx);

for (const auto &dialectOpPair : opSet->getDialectOps()) {
m_opMap[OpDescription::fromDialectOp(dialectOpPair.isOverload,
dialectOpPair.mnemonic)]
for (const auto &dialectOpPair : opSet->getDialectOps())
HandlerList(OpDescription::fromDialectOp(dialectOpPair.isOverload,
dialectOpPair.mnemonic))
.push_back(handlerIdx);
}
}
}

void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data,
VisitorHandler::Projection projection,
VisitorCallbackType visitorCallbackTy) {
assert(visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set);

VisitorHandler handler;
handler.callback = fn;
handler.data = data;
handler.projection = projection;

m_handlers.emplace_back(handler);

const unsigned handlerIdx = m_handlers.size() - 1;

storeHandlersInOpMap(key, handlerIdx, visitorCallbackTy);
}

VisitorBuilderBase::VisitorBuilderBase() : m_template(&m_ownedTemplate) {}

VisitorBuilderBase::VisitorBuilderBase(VisitorBuilderBase *parent,
Expand Down Expand Up @@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) {
m_template->setStrategy(strategy);
}

void VisitorBuilderBase::addPreVisitCallback(VisitorKey key,
VisitorCallback *fn,
VisitorCallbackData data) {
m_template->add(key, fn, data, m_projection,
VisitorTemplate::VisitorCallbackType::PreVisit);
}

void VisitorBuilderBase::add(VisitorKey key, VisitorCallback *fn,
VisitorCallbackData data) {
m_template->add(key, fn, data, m_projection);
Expand Down Expand Up @@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
BuildHelper helper(*this, templ.m_handlers);

m_opMap.reserve(templ.m_opMap);

for (auto it : templ.m_opMap)
m_opMap[it.first] = helper.mapHandlers(it.second);
for (auto it : templ.m_opMap) {
m_opMap[it.first].PreVisitCallbacks =
helper.mapHandlers(it.second.PreVisitHandlers);
m_opMap[it.first].VisitCallbacks =
helper.mapHandlers(it.second.VisitHandlers);
}
}

void VisitorBase::call(HandlerRange handlers, void *payload,
Expand Down Expand Up @@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
}

void VisitorBase::visit(void *payload, Instruction &inst) const {
auto handlers = m_opMap.find(inst);
if (!handlers)
auto mappedHandlers = m_opMap.find(inst);
if (!mappedHandlers)
return;

call(*handlers.val(), payload, inst);
auto &callbacks = *mappedHandlers.val();

call(callbacks.PreVisitCallbacks, payload, inst);
call(callbacks.VisitCallbacks, payload, inst);
}

template <typename FilterT>
Expand All @@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module,

LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n');

auto handlers = m_opMap.find(decl);
if (!handlers) {
auto mappedHandlers = m_opMap.find(decl);
if (!mappedHandlers) {
// Neither a matched intrinsic nor a matched dialect op; skip.
continue;
}

auto &callbacks = *mappedHandlers.val();

for (Use &use : make_early_inc_range(decl.uses())) {
if (auto *inst = dyn_cast<Instruction>(use.getUser())) {
if (!filter(*inst))
continue;
if (auto *callInst = dyn_cast<CallInst>(inst)) {
if (&use == &callInst->getCalledOperandUse())
call(*handlers.val(), payload, *callInst);
if (&use == &callInst->getCalledOperandUse()) {
call(callbacks.PreVisitCallbacks, payload, *callInst);
call(callbacks.VisitCallbacks, payload, *callInst);
}
}
}
}
Expand Down
Loading
Loading