Remove support for multi-op OpDescriptions. #72

Open
wants to merge 1 commit into
base: dev
15 changes: 9 additions & 6 deletions example/ExampleMain.cpp
@@ -32,6 +32,7 @@

#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
@@ -145,7 +146,7 @@ struct VisitorNest {
raw_ostream *out = nullptr;
VisitorInnermost inner;

void visitBinaryOperator(BinaryOperator &inst) {
void visitBinaryOperator(Instruction &inst) {
Member

This is a regression in a sense, and it would be nice if we could avoid it. What if the VisitorBuilder always used an OpSet, even for singletons?

Basically, what I'm thinking of is to get rid of VisitorKey entirely and always use OpSet, or at least to remove the OpDescription case from the VisitorKey (and keep the set and intrinsic cases).

This technically makes some of the VisitorBuilder code less efficient, but:

  • By design, visitor building should happen only once, so a small regression isn't too bad
  • All the OpSet::get implementations should still be fast since they should return references to static local variables
  • Fixing this particular regression and cleaning up the code slightly is a benefit that I'd say outweighs a minor performance difference during visitor building
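
To illustrate the second point, here is a minimal sketch of the "reference to a static local" pattern. ToyOpSet, ToyBinaryOps and the opcode values are hypothetical stand-ins rather than the llvm-dialects types; the sketch only shows that the set is constructed on the first call and that later calls return the same object. It assumes C++17.

```cpp
#include <cassert>
#include <initializer_list>
#include <set>

// Hypothetical stand-in for OpSet; not the llvm-dialects class.
class ToyOpSet {
public:
  ToyOpSet(std::initializer_list<unsigned> opcodes) : m_opcodes(opcodes) {
    ++s_constructed; // count constructions to show the set is built once
  }

  bool contains(unsigned opcode) const { return m_opcodes.count(opcode) != 0; }

  static int constructedCount() { return s_constructed; }

private:
  std::set<unsigned> m_opcodes;
  static inline int s_constructed = 0;
};

// Hypothetical tag type standing in for an instruction class.
struct ToyBinaryOps {};

// Primary template is only declared; each class provides a specialization.
template <typename ClassT> const ToyOpSet &getClass();

template <> const ToyOpSet &getClass<ToyBinaryOps>() {
  // Initialized on first use; later calls only return the reference.
  static const ToyOpSet set{13, 14, 15}; // made-up opcode values
  return set;
}

// Repeated lookups yield the same object and trigger no new construction.
inline void checkBuiltOnce() {
  const ToyOpSet &first = getClass<ToyBinaryOps>();
  const ToyOpSet &second = getClass<ToyBinaryOps>();
  assert(&first == &second);
  assert(ToyOpSet::constructedCount() == 1);
}
```

After the first call, the cost per lookup is the initialization guard check plus returning a reference, which is why always going through a set should stay cheap while a visitor is being built.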

(That would also mean removing OpDescription::get ... actually, I wonder if perhaps we end up removing OpDescription altogether in the end? I haven't fully thought this through to the end.)

Contributor Author

Theoretically, that would be possible. Maybe let us discuss a mid-term plan for that offline, since I regard OpDescription as a useful layer of abstraction right now.

Contributor Author (@tsymalla, Nov 8, 2023)


Thinking a bit more, I doubt an intermediate solution makes sense. The issue here is the groups of operations, namely UnaryInstruction and BinaryOperator, as currently implemented by the getClass specialization approach.

If we want to register a BinaryOperator case to the Visitor and make it forward to the actual ::get specialization of the OpSet, this requires a not-so-beautiful template overload for the parameter pack case:

  template <typename OpT> static const OpSet &get();

  // Construct an OpSet from a set of dialect ops, given as template
  // arguments.
  template <typename... OpTs, std::size_t Count = sizeof...(OpTs),
            std::enable_if_t<(Count > 1), bool> = true>
  static const OpSet get() {
    static OpSet set;
    (... && appendT<OpTs>(set));
    return set;
  }

Otherwise, the compiler will complain about ambiguity, and rightly so. With a single-argument template overload of ::get we can make it work, but that requires the dialect generator to generate OpSet overloads instead of OpDescription overloads, because otherwise nothing else will work.
That in turn will require us to modify OpMap to accept OpSets instead of OpDescriptions. We cannot make use of the getClass approach without adding a specific VisitorBuilder::addClass method that only forwards to the specific OpSet::get specialization.
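
As a rough, self-contained illustration of the ambiguity and the enable_if workaround (ToySet, OpA and OpB are hypothetical names, and unlike the snippet above the pack overload returns a reference here), the following sketch compiles as C++17. Without the `Count > 1` constraint, a call such as `get<OpA>()` would match both overloads and be rejected as ambiguous.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <type_traits>

// Hypothetical op types, each carrying a made-up opcode.
struct OpA { static constexpr unsigned Opcode = 1; };
struct OpB { static constexpr unsigned Opcode = 2; };

// Hypothetical stand-in for OpSet; not the llvm-dialects class.
class ToySet {
public:
  // Single-op overload: chosen for get<OpA>().
  template <typename OpT> static const ToySet &get() {
    static const ToySet set = make<OpT>();
    return set;
  }

  // Pack overload: only participates when more than one op is given,
  // so get<OpA>() no longer matches both overloads.
  template <typename... OpTs, std::size_t Count = sizeof...(OpTs),
            std::enable_if_t<(Count > 1), bool> = true>
  static const ToySet &get() {
    static const ToySet set = make<OpTs...>();
    return set;
  }

  bool contains(unsigned opcode) const { return m_opcodes.count(opcode) != 0; }
  std::size_t size() const { return m_opcodes.size(); }

private:
  // Builds a set from the opcodes of all given op types.
  template <typename... OpTs> static ToySet make() {
    ToySet set;
    (set.m_opcodes.insert(OpTs::Opcode), ...);
    return set;
  }

  std::set<unsigned> m_opcodes;
};
```

With this in place, `ToySet::get<OpA>()` resolves to the single-op overload and `ToySet::get<OpA, OpB>()` to the pack overload. Whether the real OpSet should use the same constraint is exactly the design question raised above.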

Of course we can make it work, but at this point, we have several choices:

  • Introduce OpSet support as a first-class citizen, removing support for BinaryOperator and UnaryInstruction
  • Fully remove support for OpDescription::get as part of this PR and do all of the required changes (I have a half-baked version on my disk)
  • Integrate OpSet in the Visitor and only add specific OpSet::get specializations next to the existing OpDescription overloads
  • Leave it in the current state as of this PR, but accept the regression

I lean towards fully dropping support for OpDescription::get as part of this PR.
We will of course require special overloads of the VisitorBuilder::add methods to prefer single-argument template invocations, so we can do forwarding with the specific OpT as argument type (OpT &Op instead of llvm::Instruction &I), but maybe that will be a bit cleaner.
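
To make the forwarding idea concrete, here is a small sketch of a builder whose single-op `add` hands the handler the specific `OpT &`, while `addSet` hands it the generic instruction. All names (ToyInst, ToyAdd, ToySub, ToyVisitorBuilder, ToyPayload) are hypothetical, and the sketch stores handlers in std::function instead of the memcpy-based VisitorCallbackData used in Visitor.h, so it shows the intended calling convention rather than the actual implementation. It assumes C++17.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical instruction hierarchy; not llvm::Instruction.
struct ToyInst {
  explicit ToyInst(unsigned op) : opcode(op) {}
  virtual ~ToyInst() = default;
  unsigned opcode;
};
struct ToyAdd : ToyInst {
  static constexpr unsigned Opcode = 1;
  ToyAdd() : ToyInst(Opcode) {}
};
struct ToySub : ToyInst {
  static constexpr unsigned Opcode = 2;
  ToySub() : ToyInst(Opcode) {}
};

template <typename PayloadT> class ToyVisitorBuilder {
public:
  // Single-op case: the handler receives the specific OpT &.
  template <typename OpT> ToyVisitorBuilder &add(void (PayloadT::*fn)(OpT &)) {
    m_cases[OpT::Opcode].push_back([fn](PayloadT &self, ToyInst &inst) {
      (self.*fn)(static_cast<OpT &>(inst)); // safe: registered under OpT's opcode
    });
    return *this;
  }

  // Set case: the handler only receives the generic ToyInst &.
  ToyVisitorBuilder &addSet(const std::vector<unsigned> &opcodes,
                            void (PayloadT::*fn)(ToyInst &)) {
    for (unsigned opcode : opcodes)
      m_cases[opcode].push_back(
          [fn](PayloadT &self, ToyInst &inst) { (self.*fn)(inst); });
    return *this;
  }

  // Invokes all handlers registered for the instruction's opcode, in order.
  void visit(PayloadT &self, ToyInst &inst) const {
    auto it = m_cases.find(inst.opcode);
    if (it == m_cases.end())
      return;
    for (const auto &handler : it->second)
      handler(self, inst);
  }

private:
  using Handler = std::function<void(PayloadT &, ToyInst &)>;
  std::map<unsigned, std::vector<Handler>> m_cases;
};

// Hypothetical payload that records which handlers ran.
struct ToyPayload {
  std::string log;
  void visitAdd(ToyAdd &) { log += "add;"; }
  void visitAny(ToyInst &inst) {
    log += "any:" + std::to_string(inst.opcode) + ";";
  }
};
```

A call like `b.add(&ToyPayload::visitAdd)` deduces `OpT` from the member function's parameter, which is the "prefer single-argument template invocations" behaviour described above; set handlers keep the `ToyInst &` signature because no single concrete type applies to every member of the set.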

*out << "visiting BinaryOperator: " << inst << '\n';
}
};
@@ -204,11 +205,13 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
}
}
});
b.add<UnaryInstruction>(
[](VisitorNest &self, UnaryInstruction &inst) {
*self.out << "visiting UnaryInstruction: " << inst << '\n';
});
b.add(&VisitorNest::visitBinaryOperator);
b.addSet(OpSet::getClass<UnaryInstruction>(),
[](VisitorNest &self, llvm::Instruction &inst) {
*self.out << "visiting UnaryInstruction: " << inst
<< '\n';
});
b.addSet(OpSet::getClass<BinaryOperator>(),
&VisitorNest::visitBinaryOperator);
b.nest<raw_ostream>([](VisitorBuilder<raw_ostream> &b) {
b.add<xd::WriteOp>([](raw_ostream &out, xd::WriteOp &op) {
out << "visiting WriteOp: " << op << '\n';
7 changes: 1 addition & 6 deletions include/llvm-dialects/Dialect/OpDescription.h
@@ -16,7 +16,6 @@

#pragma once

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include <variant>
@@ -44,7 +43,6 @@ class OpDescription {
: m_kind(hasOverloads ? Kind::DialectWithOverloads : Kind::Dialect),
m_op(mnemonic) {}
OpDescription(Kind kind, unsigned opcode) : m_kind(kind), m_op(opcode) {}
OpDescription(Kind kind, llvm::MutableArrayRef<unsigned> opcodes);

static OpDescription fromCoreOp(unsigned op) { return {Kind::Core, op}; }

@@ -69,8 +67,6 @@ class OpDescription {

unsigned getOpcode() const;

llvm::ArrayRef<unsigned> getOpcodes() const;

llvm::StringRef getMnemonic() const {
assert(m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads);
return std::get<llvm::StringRef>(m_op);
@@ -92,9 +88,8 @@ class OpDescription {

// Holds one of:
// - core instruction opcode or intrinsic ID
// - sorted array of opcodes or intrinsic IDs
// - mnemonic
std::variant<unsigned, llvm::ArrayRef<unsigned>, llvm::StringRef> m_op;
std::variant<unsigned, llvm::StringRef> m_op;
};

} // namespace llvm_dialects
15 changes: 0 additions & 15 deletions include/llvm-dialects/Dialect/OpMap.h
@@ -137,10 +137,6 @@ template <typename ValueT> class OpMap final {
// Check if the map contains an op described by an OpDescription.
bool contains(const OpDescription &desc) const {
if (desc.isCoreOp() || desc.isIntrinsic()) {
assert(desc.getOpcodes().size() == 1 &&
"OpMap only supports querying of single core opcodes and "
"intrinsics.");

const unsigned op = desc.getOpcode();
return (desc.isCoreOp() && containsCoreOp(op)) ||
(desc.isIntrinsic() && containsIntrinsic(op));
@@ -240,9 +236,6 @@ template <typename ValueT> class OpMap final {
return {found, false};

if (desc.isCoreOp() || desc.isIntrinsic()) {
assert(desc.getOpcodes().size() == 1 &&
"OpMap: Can only emplace a single op at a time.");

const unsigned op = desc.getOpcode();
if (desc.isCoreOp()) {
auto [it, inserted] =
@@ -578,10 +571,6 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
OpMapIteratorBase(OpMapT *map, const OpDescription &desc)
: m_map{map}, m_desc{desc} {
if (desc.isCoreOp() || desc.isIntrinsic()) {
assert(desc.getOpcodes().size() == 1 &&
"OpMapIterator only supports querying of single core opcodes and "
"intrinsics.");

const unsigned op = desc.getOpcode();

if (desc.isCoreOp()) {
@@ -659,10 +648,6 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
template <bool Proxy = isConst, typename = std::enable_if_t<!Proxy>>
bool erase() {
if (m_desc.isCoreOp() || m_desc.isIntrinsic()) {
assert(m_desc.getOpcodes().size() == 1 &&
"OpMapIterator only supports erasing of single core opcodes and "
"intrinsics.");

const unsigned op = m_desc.getOpcode();

if (m_desc.isCoreOp())
12 changes: 4 additions & 8 deletions include/llvm-dialects/Dialect/OpSet.h
@@ -89,6 +89,8 @@ class OpSet final {
return set;
}

template <typename ClassT> static const OpSet &getClass();

// Construct an OpSet from a set of dialect ops, given as template
// arguments.
template <typename... OpTs> static const OpSet get() {
@@ -119,10 +121,6 @@ class OpSet final {
// Checks if a given OpDescription is stored in the set.
bool contains(const OpDescription &desc) const {
if (desc.isCoreOp() || desc.isIntrinsic()) {
assert(desc.getOpcodes().size() == 1 &&
"OpSet only supports querying of single core opcodes and "
"intrinsics.");

const unsigned op = desc.getOpcode();
return (desc.isCoreOp() && containsCoreOp(op)) ||
(desc.isIntrinsic() && containsIntrinsicID(op));
@@ -189,15 +187,13 @@ class OpSet final {
// Tries to insert a given description in the internal data structures.
void tryInsertOp(const OpDescription &desc) {
if (desc.isCoreOp()) {
for (const unsigned op : desc.getOpcodes())
m_coreOpcodes.insert(op);
m_coreOpcodes.insert(desc.getOpcode());

return;
}

if (desc.isIntrinsic()) {
for (const unsigned op : desc.getOpcodes())
m_intrinsicIDs.insert(op);
m_intrinsicIDs.insert(desc.getOpcode());

return;
}
67 changes: 67 additions & 0 deletions include/llvm-dialects/Dialect/Visitor.h
@@ -369,28 +369,60 @@ class VisitorBuilder : private detail::VisitorBuilderBase {

Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }

//
// Add a simple visitor case for a single templatized OpDescription and a
// functor.
//
template <typename OpT> VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) {
addCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

//
// Add a visitor case for a templatized OpSet and a functor.
//
template <typename... OpTs>
VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

//
// Add a visitor case for a OpSet passed by const reference and a functor.
//
VisitorBuilder &addSet(const OpSet &opSet,
void (*fn)(PayloadT &, llvm::Instruction &I)) {
addSetCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

//
// Add a visitor case for a member function and a single OpDescription OpSet.
//
template <typename OpT> VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) {
addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
return *this;
}

//
// Add a visitor case for a member function and a templatized OpSet.
//
template <typename... OpTs>
VisitorBuilder &addSet(void (PayloadT::*fn)(llvm::Instruction &)) {
addMemberFnSetCase<OpTs...>(detail::VisitorKey::opSet<OpTs...>(), fn);
return *this;
}

//
// Add a visitor case for a member function and a OpSet, passed by const
// reference.
//
VisitorBuilder &addSet(const OpSet &opSet,
void (PayloadT::*fn)(llvm::Instruction &)) {
addMemberFnSetCase(detail::VisitorKey::opSet(opSet), fn);
return *this;
}

VisitorBuilder &addIntrinsic(unsigned id,
void (*fn)(PayloadT &, llvm::IntrinsicInst &)) {
addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
@@ -457,6 +489,24 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder<OpT>, data);
}

template <typename... OpTs>
void addMemberFnSetCase(detail::VisitorKey key,
void (PayloadT::*fn)(Instruction &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::add(key, &VisitorBuilder::memberFnSetForwarder<OpTs>...,
data);
}

void addMemberFnSetCase(detail::VisitorKey key,
void (PayloadT::*fn)(Instruction &)) {
detail::VisitorCallbackData data{};
static_assert(sizeof(fn) <= sizeof(data.data));
memcpy(&data.data, &fn, sizeof(fn));
VisitorBuilderBase::add(key, &VisitorBuilder::memberFnSetForwarder, data);
}

template <typename OpT>
static void forwarder(const detail::VisitorCallbackData &data, void *payload,
llvm::Instruction *op) {
@@ -480,6 +530,23 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
PayloadT *self = static_cast<PayloadT *>(payload);
(self->*fn)(*llvm::cast<OpT>(op));
}

template <typename... OpTs>
static void memberFnSetForwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
void (PayloadT::*fn)(Instruction &);
memcpy(&fn, &data.data, sizeof(fn));
PayloadT *self = static_cast<PayloadT *>(payload);
(self->*fn)(*op);
}

static void memberFnSetForwarder(const detail::VisitorCallbackData &data,
void *payload, llvm::Instruction *op) {
void (PayloadT::*fn)(Instruction &);
memcpy(&fn, &data.data, sizeof(fn));
PayloadT *self = static_cast<PayloadT *>(payload);
(self->*fn)(*op);
}
};

} // namespace llvm_dialects
44 changes: 13 additions & 31 deletions lib/Dialect/OpDescription.cpp
@@ -15,6 +15,7 @@
*/

#include "llvm-dialects/Dialect/OpDescription.h"
#include "llvm-dialects/Dialect/OpSet.h"

#include "llvm-dialects/Dialect/Dialect.h"

@@ -25,25 +26,11 @@
using namespace llvm_dialects;
using namespace llvm;

OpDescription::OpDescription(Kind kind, MutableArrayRef<unsigned> opcodes)
: m_kind(kind), m_op(opcodes) {
llvm::sort(opcodes);
}

unsigned OpDescription::getOpcode() const {
const ArrayRef<unsigned> opcodes{getOpcodes()};
assert(!opcodes.empty() && "OpDescription does not contain any opcode!");

return opcodes.front();
}

ArrayRef<unsigned> OpDescription::getOpcodes() const {
assert(m_kind == Kind::Core || m_kind == Kind::Intrinsic);

if (auto *op = std::get_if<unsigned>(&m_op))
return *op;

return std::get<ArrayRef<unsigned>>(m_op);
llvm_unreachable("OpDescription does not contain any opcode!");
}

bool OpDescription::matchInstruction(const Instruction &inst) const {
@@ -57,9 +44,7 @@ bool OpDescription::matchInstruction(const Instruction &inst) const {
if (auto *op = std::get_if<unsigned>(&m_op))
return inst.getOpcode() == *op;

auto opcodes = std::get<ArrayRef<unsigned>>(m_op);
auto it = llvm::lower_bound(opcodes, inst.getOpcode());
return it != opcodes.end() && *it == inst.getOpcode();
return false;
}

if (auto *call = dyn_cast<CallInst>(&inst)) {
@@ -87,15 +72,13 @@ bool OpDescription::matchIntrinsic(unsigned intrinsicId) const {
if (auto *op = std::get_if<unsigned>(&m_op))
return *op == intrinsicId;

auto opcodes = std::get<ArrayRef<unsigned>>(m_op);
auto it = llvm::lower_bound(opcodes, intrinsicId);
return it != opcodes.end() && *it == intrinsicId;
return false;
}

// ============================================================================
// Descriptions of core instructions.

template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
template <> const OpSet &OpSet::getClass<UnaryInstruction>() {
static unsigned opcodes[] = {
Instruction::Alloca,
Instruction::Load,
@@ -106,17 +89,17 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
#define HANDLE_CAST_INST(num, opcode, Class) Instruction::opcode,
#include "llvm/IR/Instruction.def"
};
static const OpDescription desc{Kind::Core, opcodes};
return desc;
static const OpSet set = OpSet::fromCoreOpcodes(opcodes);
return set;
}

template <> const OpDescription &OpDescription::get<BinaryOperator>() {
template <> const OpSet &OpSet::getClass<BinaryOperator>() {
static unsigned opcodes[] = {
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
#include "llvm/IR/Instruction.def"
};
static const OpDescription desc{Kind::Core, opcodes};
return desc;
static const OpSet set = OpSet::fromCoreOpcodes({opcodes});
return set;
}

// Generate OpDescription for all dedicate instruction classes.
@@ -136,10 +119,9 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
return desc; \
}
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
template <> const OpDescription &OpDescription::get<Class>() { \
static unsigned opcodes[] = {__VA_ARGS__}; \
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
return desc; \
template <> const OpSet &OpSet::getClass<Class>() { \
static const OpSet set = OpSet::fromIntrinsicIDs({__VA_ARGS__}); \
return set; \
}

// ============================================================================
8 changes: 4 additions & 4 deletions lib/Dialect/Visitor.cpp
@@ -63,11 +63,11 @@ void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
const OpDescription *opDesc = key.m_description;

if (opDesc->isCoreOp()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx);
m_opMap[OpDescription::fromCoreOp(opDesc->getOpcode())].push_back(
handlerIdx);
} else if (opDesc->isIntrinsic()) {
for (const unsigned op : opDesc->getOpcodes())
m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx);
m_opMap[OpDescription::fromIntrinsic(opDesc->getOpcode())].push_back(
handlerIdx);
} else {
m_opMap[*opDesc].push_back(handlerIdx);
}