diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index 91203cc..c961474 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -32,6 +32,7 @@ #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" @@ -145,7 +146,7 @@ struct VisitorNest { raw_ostream *out = nullptr; VisitorInnermost inner; - void visitBinaryOperator(BinaryOperator &inst) { + void visitBinaryOperator(Instruction &inst) { *out << "visiting BinaryOperator: " << inst << '\n'; } }; @@ -204,11 +205,13 @@ template const Visitor &getExampleVisitor() { } } }); - b.add( - [](VisitorNest &self, UnaryInstruction &inst) { - *self.out << "visiting UnaryInstruction: " << inst << '\n'; - }); - b.add(&VisitorNest::visitBinaryOperator); + b.addSet(OpSet::getClass(), + [](VisitorNest &self, llvm::Instruction &inst) { + *self.out << "visiting UnaryInstruction: " << inst + << '\n'; + }); + b.addSet(OpSet::getClass(), + &VisitorNest::visitBinaryOperator); b.nest([](VisitorBuilder &b) { b.add([](raw_ostream &out, xd::WriteOp &op) { out << "visiting WriteOp: " << op << '\n'; diff --git a/include/llvm-dialects/Dialect/OpDescription.h b/include/llvm-dialects/Dialect/OpDescription.h index 1a19a21..f1cc3c6 100644 --- a/include/llvm-dialects/Dialect/OpDescription.h +++ b/include/llvm-dialects/Dialect/OpDescription.h @@ -16,7 +16,6 @@ #pragma once -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include @@ -44,7 +43,6 @@ class OpDescription { : m_kind(hasOverloads ? Kind::DialectWithOverloads : Kind::Dialect), m_op(mnemonic) {} OpDescription(Kind kind, unsigned opcode) : m_kind(kind), m_op(opcode) {} - OpDescription(Kind kind, llvm::MutableArrayRef opcodes); static OpDescription fromCoreOp(unsigned op) { return {Kind::Core, op}; } @@ -69,8 +67,6 @@ class OpDescription { unsigned getOpcode() const; - llvm::ArrayRef getOpcodes() const; - llvm::StringRef getMnemonic() const { assert(m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads); return std::get(m_op); @@ -92,9 +88,8 @@ class OpDescription { // Holds one of: // - core instruction opcode or intrinsic ID - // - sorted array of opcodes or intrinsic IDs // - mnemonic - std::variant, llvm::StringRef> m_op; + std::variant m_op; }; } // namespace llvm_dialects diff --git a/include/llvm-dialects/Dialect/OpMap.h b/include/llvm-dialects/Dialect/OpMap.h index 8324a03..26517fa 100644 --- a/include/llvm-dialects/Dialect/OpMap.h +++ b/include/llvm-dialects/Dialect/OpMap.h @@ -137,10 +137,6 @@ template class OpMap final { // Check if the map contains an op described by an OpDescription. bool contains(const OpDescription &desc) const { if (desc.isCoreOp() || desc.isIntrinsic()) { - assert(desc.getOpcodes().size() == 1 && - "OpMap only supports querying of single core opcodes and " - "intrinsics."); - const unsigned op = desc.getOpcode(); return (desc.isCoreOp() && containsCoreOp(op)) || (desc.isIntrinsic() && containsIntrinsic(op)); @@ -240,9 +236,6 @@ template class OpMap final { return {found, false}; if (desc.isCoreOp() || desc.isIntrinsic()) { - assert(desc.getOpcodes().size() == 1 && - "OpMap: Can only emplace a single op at a time."); - const unsigned op = desc.getOpcode(); if (desc.isCoreOp()) { auto [it, inserted] = @@ -578,10 +571,6 @@ template class OpMapIteratorBase final { OpMapIteratorBase(OpMapT *map, const OpDescription &desc) : m_map{map}, m_desc{desc} { if (desc.isCoreOp() || desc.isIntrinsic()) { - assert(desc.getOpcodes().size() == 1 && - "OpMapIterator only supports querying of single core opcodes and " - "intrinsics."); - const unsigned op = desc.getOpcode(); if (desc.isCoreOp()) { @@ -659,10 +648,6 @@ template class OpMapIteratorBase final { template > bool erase() { if (m_desc.isCoreOp() || m_desc.isIntrinsic()) { - assert(m_desc.getOpcodes().size() == 1 && - "OpMapIterator only supports erasing of single core opcodes and " - "intrinsics."); - const unsigned op = m_desc.getOpcode(); if (m_desc.isCoreOp()) diff --git a/include/llvm-dialects/Dialect/OpSet.h b/include/llvm-dialects/Dialect/OpSet.h index c585285..2cb9882 100644 --- a/include/llvm-dialects/Dialect/OpSet.h +++ b/include/llvm-dialects/Dialect/OpSet.h @@ -89,6 +89,8 @@ class OpSet final { return set; } + template static const OpSet &getClass(); + // Construct an OpSet from a set of dialect ops, given as template // arguments. template static const OpSet get() { @@ -119,10 +121,6 @@ class OpSet final { // Checks if a given OpDescription is stored in the set. bool contains(const OpDescription &desc) const { if (desc.isCoreOp() || desc.isIntrinsic()) { - assert(desc.getOpcodes().size() == 1 && - "OpSet only supports querying of single core opcodes and " - "intrinsics."); - const unsigned op = desc.getOpcode(); return (desc.isCoreOp() && containsCoreOp(op)) || (desc.isIntrinsic() && containsIntrinsicID(op)); @@ -189,15 +187,13 @@ class OpSet final { // Tries to insert a given description in the internal data structures. void tryInsertOp(const OpDescription &desc) { if (desc.isCoreOp()) { - for (const unsigned op : desc.getOpcodes()) - m_coreOpcodes.insert(op); + m_coreOpcodes.insert(desc.getOpcode()); return; } if (desc.isIntrinsic()) { - for (const unsigned op : desc.getOpcodes()) - m_intrinsicIDs.insert(op); + m_intrinsicIDs.insert(desc.getOpcode()); return; } diff --git a/include/llvm-dialects/Dialect/Visitor.h b/include/llvm-dialects/Dialect/Visitor.h index a3706f1..a22cbf2 100644 --- a/include/llvm-dialects/Dialect/Visitor.h +++ b/include/llvm-dialects/Dialect/Visitor.h @@ -369,28 +369,60 @@ class VisitorBuilder : private detail::VisitorBuilderBase { Visitor build() { return VisitorBuilderBase::build(); } + // + // Add a simple visitor case for a single templatized OpDescription and a + // functor. + // template VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) { addCase(detail::VisitorKey::op(), fn); return *this; } + // + // Add a visitor case for a templatized OpSet and a functor. + // template VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) { addSetCase(detail::VisitorKey::opSet(), fn); return *this; } + // + // Add a visitor case for a OpSet passed by const reference and a functor. + // VisitorBuilder &addSet(const OpSet &opSet, void (*fn)(PayloadT &, llvm::Instruction &I)) { addSetCase(detail::VisitorKey::opSet(opSet), fn); return *this; } + // + // Add a visitor case for a member function and a single OpDescription OpSet. + // template VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) { addMemberFnCase(detail::VisitorKey::op(), fn); return *this; } + // + // Add a visitor case for a member function and a templatized OpSet. + // + template + VisitorBuilder &addSet(void (PayloadT::*fn)(llvm::Instruction &)) { + addMemberFnSetCase(detail::VisitorKey::opSet(), fn); + return *this; + } + + // + // Add a visitor case for a member function and a OpSet, passed by const + // reference. + // + VisitorBuilder &addSet(const OpSet &opSet, + void (PayloadT::*fn)(llvm::Instruction &)) { + addMemberFnSetCase(detail::VisitorKey::opSet(opSet), fn); + return *this; + } + VisitorBuilder &addIntrinsic(unsigned id, void (*fn)(PayloadT &, llvm::IntrinsicInst &)) { addCase(detail::VisitorKey::intrinsic(id), fn); @@ -457,6 +489,24 @@ class VisitorBuilder : private detail::VisitorBuilderBase { VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder, data); } + template + void addMemberFnSetCase(detail::VisitorKey key, + void (PayloadT::*fn)(Instruction &)) { + detail::VisitorCallbackData data{}; + static_assert(sizeof(fn) <= sizeof(data.data)); + memcpy(&data.data, &fn, sizeof(fn)); + VisitorBuilderBase::add(key, &VisitorBuilder::memberFnSetForwarder..., + data); + } + + void addMemberFnSetCase(detail::VisitorKey key, + void (PayloadT::*fn)(Instruction &)) { + detail::VisitorCallbackData data{}; + static_assert(sizeof(fn) <= sizeof(data.data)); + memcpy(&data.data, &fn, sizeof(fn)); + VisitorBuilderBase::add(key, &VisitorBuilder::memberFnSetForwarder, data); + } + template static void forwarder(const detail::VisitorCallbackData &data, void *payload, llvm::Instruction *op) { @@ -480,6 +530,23 @@ class VisitorBuilder : private detail::VisitorBuilderBase { PayloadT *self = static_cast(payload); (self->*fn)(*llvm::cast(op)); } + + template + static void memberFnSetForwarder(const detail::VisitorCallbackData &data, + void *payload, llvm::Instruction *op) { + void (PayloadT::*fn)(Instruction &); + memcpy(&fn, &data.data, sizeof(fn)); + PayloadT *self = static_cast(payload); + (self->*fn)(*op); + } + + static void memberFnSetForwarder(const detail::VisitorCallbackData &data, + void *payload, llvm::Instruction *op) { + void (PayloadT::*fn)(Instruction &); + memcpy(&fn, &data.data, sizeof(fn)); + PayloadT *self = static_cast(payload); + (self->*fn)(*op); + } }; } // namespace llvm_dialects diff --git a/lib/Dialect/OpDescription.cpp b/lib/Dialect/OpDescription.cpp index d01a682..c3fe01b 100644 --- a/lib/Dialect/OpDescription.cpp +++ b/lib/Dialect/OpDescription.cpp @@ -15,6 +15,7 @@ */ #include "llvm-dialects/Dialect/OpDescription.h" +#include "llvm-dialects/Dialect/OpSet.h" #include "llvm-dialects/Dialect/Dialect.h" @@ -25,25 +26,11 @@ using namespace llvm_dialects; using namespace llvm; -OpDescription::OpDescription(Kind kind, MutableArrayRef opcodes) - : m_kind(kind), m_op(opcodes) { - llvm::sort(opcodes); -} - unsigned OpDescription::getOpcode() const { - const ArrayRef opcodes{getOpcodes()}; - assert(!opcodes.empty() && "OpDescription does not contain any opcode!"); - - return opcodes.front(); -} - -ArrayRef OpDescription::getOpcodes() const { - assert(m_kind == Kind::Core || m_kind == Kind::Intrinsic); - if (auto *op = std::get_if(&m_op)) return *op; - return std::get>(m_op); + llvm_unreachable("OpDescription does not contain any opcode!"); } bool OpDescription::matchInstruction(const Instruction &inst) const { @@ -57,9 +44,7 @@ bool OpDescription::matchInstruction(const Instruction &inst) const { if (auto *op = std::get_if(&m_op)) return inst.getOpcode() == *op; - auto opcodes = std::get>(m_op); - auto it = llvm::lower_bound(opcodes, inst.getOpcode()); - return it != opcodes.end() && *it == inst.getOpcode(); + return false; } if (auto *call = dyn_cast(&inst)) { @@ -87,15 +72,13 @@ bool OpDescription::matchIntrinsic(unsigned intrinsicId) const { if (auto *op = std::get_if(&m_op)) return *op == intrinsicId; - auto opcodes = std::get>(m_op); - auto it = llvm::lower_bound(opcodes, intrinsicId); - return it != opcodes.end() && *it == intrinsicId; + return false; } // ============================================================================ // Descriptions of core instructions. -template <> const OpDescription &OpDescription::get() { +template <> const OpSet &OpSet::getClass() { static unsigned opcodes[] = { Instruction::Alloca, Instruction::Load, @@ -106,17 +89,17 @@ template <> const OpDescription &OpDescription::get() { #define HANDLE_CAST_INST(num, opcode, Class) Instruction::opcode, #include "llvm/IR/Instruction.def" }; - static const OpDescription desc{Kind::Core, opcodes}; - return desc; + static const OpSet set = OpSet::fromCoreOpcodes(opcodes); + return set; } -template <> const OpDescription &OpDescription::get() { +template <> const OpSet &OpSet::getClass() { static unsigned opcodes[] = { #define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode, #include "llvm/IR/Instruction.def" }; - static const OpDescription desc{Kind::Core, opcodes}; - return desc; + static const OpSet set = OpSet::fromCoreOpcodes({opcodes}); + return set; } // Generate OpDescription for all dedicate instruction classes. @@ -136,10 +119,9 @@ template <> const OpDescription &OpDescription::get() { return desc; \ } #define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \ - template <> const OpDescription &OpDescription::get() { \ - static unsigned opcodes[] = {__VA_ARGS__}; \ - static const OpDescription desc{Kind::Intrinsic, opcodes}; \ - return desc; \ + template <> const OpSet &OpSet::getClass() { \ + static const OpSet set = OpSet::fromIntrinsicIDs({__VA_ARGS__}); \ + return set; \ } // ============================================================================ diff --git a/lib/Dialect/Visitor.cpp b/lib/Dialect/Visitor.cpp index 2096be0..e740f60 100644 --- a/lib/Dialect/Visitor.cpp +++ b/lib/Dialect/Visitor.cpp @@ -63,11 +63,11 @@ void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn, const OpDescription *opDesc = key.m_description; if (opDesc->isCoreOp()) { - for (const unsigned op : opDesc->getOpcodes()) - m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx); + m_opMap[OpDescription::fromCoreOp(opDesc->getOpcode())].push_back( + handlerIdx); } else if (opDesc->isIntrinsic()) { - for (const unsigned op : opDesc->getOpcodes()) - m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx); + m_opMap[OpDescription::fromIntrinsic(opDesc->getOpcode())].push_back( + handlerIdx); } else { m_opMap[*opDesc].push_back(handlerIdx); }