diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 7f6e6d11e5f53a..0ddc227e3a02b4 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -29,35 +29,61 @@ namespace llvm::sandboxir { +class DependencyGraph; +class MemDGNode; + +/// SubclassIDs for isa/dyn_cast etc. +enum class DGNodeID { + DGNode, + MemDGNode, +}; + /// A DependencyGraph Node that points to an Instruction and contains memory /// dependency edges. class DGNode { +protected: Instruction *I; + // TODO: Use a PointerIntPair for SubclassID and I. + /// For isa/dyn_cast etc. + DGNodeID SubclassID; /// Memory predecessors. - DenseSet MemPreds; - /// This is true if this may read/write memory, or if it has some ordering - /// constraints, like with stacksave/stackrestore and alloca/inalloca. - bool IsMem; + DenseSet MemPreds; + + DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {} + friend class MemDGNode; // For constructor. public: - DGNode(Instruction *I) : I(I) { - IsMem = I->isMemDepCandidate() || - (isa(I) && cast(I)->isUsedWithInAlloca()) || - I->isStackSaveOrRestoreIntrinsic(); + DGNode(Instruction *I) : I(I), SubclassID(DGNodeID::DGNode) { + assert(!isMemDepCandidate(I) && "Expected Non-Mem instruction, "); + } + DGNode(const DGNode &Other) = delete; + virtual ~DGNode() = default; + /// \Returns true if this is before \p Other in program order. + bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); } + /// \Returns true if \p I is a memory dependency candidate instruction. + static bool isMemDepCandidate(Instruction *I) { + AllocaInst *Alloca; + return I->isMemDepCandidate() || + ((Alloca = dyn_cast(I)) && + Alloca->isUsedWithInAlloca()) || + I->isStackSaveOrRestoreIntrinsic(); } + Instruction *getInstruction() const { return I; } - void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); } + void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); } /// \Returns all memory dependency predecessors. - iterator_range::const_iterator> memPreds() const { + iterator_range::const_iterator> memPreds() const { return make_range(MemPreds.begin(), MemPreds.end()); } /// \Returns true if there is a memory dependency N->this. - bool hasMemPred(DGNode *N) const { return MemPreds.count(N); } - /// \Returns true if this may read/write memory, or if it has some ordering - /// constraints, like with stacksave/stackrestore and alloca/inalloca. - bool isMem() const { return IsMem; } + bool hasMemPred(DGNode *N) const { + if (auto *MN = dyn_cast(N)) + return MemPreds.count(MN); + return false; + } + #ifndef NDEBUG - void print(raw_ostream &OS, bool PrintDeps = true) const; + virtual void print(raw_ostream &OS, bool PrintDeps = true) const; friend raw_ostream &operator<<(DGNode &N, raw_ostream &OS) { N.print(OS); return OS; @@ -66,9 +92,46 @@ class DGNode { #endif // NDEBUG }; +/// A DependencyGraph Node for instructions that may read/write memory, or have +/// some ordering constraints, like with stacksave/stackrestore and +/// alloca/inalloca. +class MemDGNode final : public DGNode { + MemDGNode *PrevMemN = nullptr; + MemDGNode *NextMemN = nullptr; + + void setNextNode(MemDGNode *N) { NextMemN = N; } + void setPrevNode(MemDGNode *N) { PrevMemN = N; } + friend class DependencyGraph; // For setNextNode(), setPrevNode(). + +public: + MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) { + assert(isMemDepCandidate(I) && "Expected Mem instruction!"); + } + static bool classof(const DGNode *Other) { + return Other->SubclassID == DGNodeID::MemDGNode; + } + /// \Returns the previous Mem DGNode in instruction order. + MemDGNode *getPrevNode() const { return PrevMemN; } + /// \Returns the next Mem DGNode in instruction order. + MemDGNode *getNextNode() const { return NextMemN; } +}; + +/// Convenience builders for a MemDGNode interval. +class MemDGNodeIntervalBuilder { +public: + /// Given \p Instrs it finds their closest mem nodes in the interval and + /// returns the corresponding mem range. Note: BotN (or its neighboring mem + /// node) is included in the range. + static Interval make(const Interval &Instrs, + DependencyGraph &DAG); + static Interval makeEmpty() { return {}; } +}; + class DependencyGraph { private: DenseMap> InstrToNodeMap; + /// The DAG spans across all instructions in this interval. + Interval DAGInterval; public: DependencyGraph() {} @@ -77,10 +140,20 @@ class DependencyGraph { auto It = InstrToNodeMap.find(I); return It != InstrToNodeMap.end() ? It->second.get() : nullptr; } + /// Like getNode() but returns nullptr if \p I is nullptr. + DGNode *getNodeOrNull(Instruction *I) const { + if (I == nullptr) + return nullptr; + return getNode(I); + } DGNode *getOrCreateNode(Instruction *I) { auto [It, NotInMap] = InstrToNodeMap.try_emplace(I); - if (NotInMap) - It->second = std::make_unique(I); + if (NotInMap) { + if (DGNode::isMemDepCandidate(I)) + It->second = std::make_unique(I); + else + It->second = std::make_unique(I); + } return It->second.get(); } /// Build/extend the dependency graph such that it includes \p Instrs. Returns diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 67b56451c7b594..ce295e8bf5df3f 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -31,6 +31,25 @@ void DGNode::dump() const { } #endif // NDEBUG +Interval +MemDGNodeIntervalBuilder::make(const Interval &Instrs, + DependencyGraph &DAG) { + // If top or bottom instructions are not mem-dep candidate nodes we need to + // walk down/up the chain and find the mem-dep ones. + Instruction *MemTopI = Instrs.top(); + Instruction *MemBotI = Instrs.bottom(); + while (!DGNode::isMemDepCandidate(MemTopI) && MemTopI != MemBotI) + MemTopI = MemTopI->getNextNode(); + while (!DGNode::isMemDepCandidate(MemBotI) && MemBotI != MemTopI) + MemBotI = MemBotI->getPrevNode(); + // If we couldn't find a mem node in range TopN - BotN then it's empty. + if (!DGNode::isMemDepCandidate(MemTopI)) + return {}; + // Now that we have the mem-dep nodes, create and return the range. + return Interval(cast(DAG.getNode(MemTopI)), + cast(DAG.getNode(MemBotI))); +} + Interval DependencyGraph::extend(ArrayRef Instrs) { if (Instrs.empty()) return {}; @@ -39,10 +58,18 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { auto *TopI = Interval.top(); auto *BotI = Interval.bottom(); DGNode *LastN = getOrCreateNode(TopI); + MemDGNode *LastMemN = dyn_cast(LastN); for (Instruction *I = TopI->getNextNode(), *E = BotI->getNextNode(); I != E; I = I->getNextNode()) { auto *N = getOrCreateNode(I); - N->addMemPred(LastN); + N->addMemPred(LastMemN); + // Build the Mem node chain. + if (auto *MemN = dyn_cast(N)) { + MemN->setPrevNode(LastMemN); + if (LastMemN != nullptr) + LastMemN->setNextNode(MemN); + LastMemN = MemN; + } LastN = N; } return Interval; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index d8b6f519982eb1..28ab38ce3d3536 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -29,7 +29,7 @@ struct DependencyGraphTest : public testing::Test { } }; -TEST_F(DependencyGraphTest, DGNode_IsMem) { +TEST_F(DependencyGraphTest, MemDGNode) { parseIR(C, R"IR( declare void @llvm.sideeffect() declare void @llvm.pseudoprobe(i64, i64, i32, i64) @@ -66,16 +66,16 @@ define void @foo(i8 %v1, ptr %ptr) { sandboxir::DependencyGraph DAG; DAG.extend({&*BB->begin(), BB->getTerminator()}); - EXPECT_TRUE(DAG.getNode(Store)->isMem()); - EXPECT_TRUE(DAG.getNode(Load)->isMem()); - EXPECT_FALSE(DAG.getNode(Add)->isMem()); - EXPECT_TRUE(DAG.getNode(StackSave)->isMem()); - EXPECT_TRUE(DAG.getNode(StackRestore)->isMem()); - EXPECT_FALSE(DAG.getNode(SideEffect)->isMem()); - EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem()); - EXPECT_TRUE(DAG.getNode(FakeUse)->isMem()); - EXPECT_TRUE(DAG.getNode(Call)->isMem()); - EXPECT_FALSE(DAG.getNode(Ret)->isMem()); + EXPECT_TRUE(isa(DAG.getNode(Store))); + EXPECT_TRUE(isa(DAG.getNode(Load))); + EXPECT_FALSE(isa(DAG.getNode(Add))); + EXPECT_TRUE(isa(DAG.getNode(StackSave))); + EXPECT_TRUE(isa(DAG.getNode(StackRestore))); + EXPECT_FALSE(isa(DAG.getNode(SideEffect))); + EXPECT_FALSE(isa(DAG.getNode(PseudoProbe))); + EXPECT_TRUE(isa(DAG.getNode(FakeUse))); + EXPECT_TRUE(isa(DAG.getNode(Call))); + EXPECT_FALSE(isa(DAG.getNode(Ret))); } TEST_F(DependencyGraphTest, Basic) { @@ -115,3 +115,100 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0)); EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1)); } + +TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v0, i8 %v1) { + store i8 %v0, ptr %ptr + add i8 %v0, %v0 + store i8 %v1, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *S0 = cast(&*It++); + [[maybe_unused]] auto *Add = cast(&*It++); + auto *S1 = cast(&*It++); + [[maybe_unused]] auto *Ret = cast(&*It++); + + sandboxir::DependencyGraph DAG; + DAG.extend({&*BB->begin(), BB->getTerminator()}); + + auto *S0N = cast(DAG.getNode(S0)); + auto *S1N = cast(DAG.getNode(S1)); + + EXPECT_EQ(S0N->getPrevNode(), nullptr); + EXPECT_EQ(S0N->getNextNode(), S1N); + + EXPECT_EQ(S1N->getPrevNode(), S0N); + EXPECT_EQ(S1N->getNextNode(), nullptr); +} + +TEST_F(DependencyGraphTest, DGNodeRange) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v0, i8 %v1) { + add i8 %v0, %v0 + store i8 %v0, ptr %ptr + add i8 %v0, %v0 + store i8 %v1, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *Add0 = cast(&*It++); + auto *S0 = cast(&*It++); + auto *Add1 = cast(&*It++); + auto *S1 = cast(&*It++); + auto *Ret = cast(&*It++); + + sandboxir::DependencyGraph DAG; + DAG.extend({&*BB->begin(), BB->getTerminator()}); + + auto *S0N = cast(DAG.getNode(S0)); + auto *S1N = cast(DAG.getNode(S1)); + + // Check empty range. + EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(), + testing::ElementsAre()); + + // Returns the pointers in Range. + auto getPtrVec = [](const auto &Range) { + SmallVector Vec; + for (const sandboxir::DGNode &N : Range) + Vec.push_back(&N); + return Vec; + }; + // Both TopN and BotN are memory. + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, S1}, DAG)), + testing::ElementsAre(S0N, S1N)); + // Only TopN is memory. + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, Ret}, DAG)), + testing::ElementsAre(S0N, S1N)); + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, Add1}, DAG)), + testing::ElementsAre(S0N)); + // Only BotN is memory. + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, S1}, DAG)), + testing::ElementsAre(S0N, S1N)); + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, S0}, DAG)), + testing::ElementsAre(S0N)); + // Neither TopN or BotN is memory. + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, Ret}, DAG)), + testing::ElementsAre(S0N, S1N)); + EXPECT_THAT( + getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, Add0}, DAG)), + testing::ElementsAre()); +}