diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index a4512862136a8b..ed1cb8488c29eb 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -284,6 +284,39 @@ class SeedContainer {
 #endif // NDEBUG
 };
 
+/// Collects the vectorization seeds (store and load seed bundles) of a basic
+/// block into one SeedContainer per seed kind.
+class SeedCollector {
+  SeedContainer StoreSeeds;
+  SeedContainer LoadSeeds;
+  Context &Ctx;
+
+  /// \Returns the number of SeedBundle groups for all seed types.
+  /// This is to be used for limiting compilation time.
+  unsigned totalNumSeedGroups() const {
+    return StoreSeeds.size() + LoadSeeds.size();
+  }
+
+public:
+  /// Collects the load/store seeds of \p BB. \p SE is passed on to both seed
+  /// containers.
+  SeedCollector(BasicBlock *BB, ScalarEvolution &SE);
+  ~SeedCollector();
+
+  /// \Returns a range over the collected store seed bundles.
+  iterator_range<SeedContainer::iterator> getStoreSeeds() {
+    return {StoreSeeds.begin(), StoreSeeds.end()};
+  }
+  /// \Returns a range over the collected load seed bundles.
+  iterator_range<SeedContainer::iterator> getLoadSeeds() {
+    return {LoadSeeds.begin(), LoadSeeds.end()};
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 } // namespace llvm::sandboxir
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
new file mode 100644
index 00000000000000..64f57edb38484e
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -0,0 +1,30 @@
+//===- VecUtils.h -----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Collection of SandboxVectorizer related convenience functions that don't
+// belong in other classes.
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+
+class Utils {
+public:
+  /// \Returns the number of elements in \p Ty. That is the number of lanes if a
+  /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
+  /// therefore are unsupported.
+  static int getNumElements(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty));
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
+  }
+  /// \Returns \p Ty if scalar or its element type if vector.
+  static Type *getElementType(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
+  }
+};
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 66fac080a7b7cc..0d928af1902073 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -22,6 +22,16 @@ namespace llvm::sandboxir {
 cl::opt<unsigned> SeedBundleSizeLimit(
     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
     cl::desc("Limit the size of the seed bundle to cap compilation time."));
+#define LoadSeedsDef "loads"
+#define StoreSeedsDef "stores"
+cl::opt<std::string> CollectSeeds(
+    "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
+    cl::desc("Collect these seeds. Use empty for none or a comma-separated "
+             "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
+cl::opt<unsigned> SeedGroupsLimit(
+    "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
+    cl::desc("Limit the number of collected seeds groups in a BB to "
+             "cap compilation time."));
 
 MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
                                                     unsigned MaxVecRegBits,
@@ -131,4 +141,65 @@ void SeedContainer::print(raw_ostream &OS) const {
 LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
 #endif // NDEBUG
 
+/// \Returns true if \p LSI can be used as a seed: it must be a simple access
+/// (as reported by isSimple()) and the type given by Utils::getExpectedType()
+/// must be a valid vector element type, or a fixed vector of such elements.
+template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
+  // Accesses that are not simple are never seeds, whatever their type.
+  if (!LSI->isSimple())
+    return false;
+  auto *Ty = Utils::getExpectedType(LSI);
+  // Omit types that are architecturally unvectorizable
+  if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
+    return false;
+  // Omit vector types without compile-time-known lane counts
+  if (isa<ScalableVectorType>(Ty))
+    return false;
+  if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
+    return VectorType::isValidElementType(VTy->getElementType());
+  return VectorType::isValidElementType(Ty);
+}
+
+template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
+template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
+
+SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
+    : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
+  // TODO: Register a callback for updating the Collector data structures upon
+  // instr removal
+
+  bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
+  bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
+  if (!CollectStores && !CollectLoads)
+    return;
+  // Actually collect the seeds.
+  for (auto &I : *BB) {
+    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
+      if (CollectStores && isValidMemSeed(SI))
+        StoreSeeds.insert(SI);
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
+      if (CollectLoads && isValidMemSeed(LI))
+        LoadSeeds.insert(LI);
+    // Cap compilation time.
+    if (totalNumSeedGroups() > SeedGroupsLimit)
+      break;
+  }
+}
+
+SeedCollector::~SeedCollector() {
+  // TODO: Unregister the callback for updating the seed datastructures upon
+  // instr removal
+}
+
+#ifndef NDEBUG
+void SeedCollector::print(raw_ostream &OS) const {
+  OS << "=== StoreSeeds ===\n";
+  StoreSeeds.print(OS);
+  OS << "=== LoadSeeds ===\n";
+  LoadSeeds.print(OS);
+}
+
+void SeedCollector::dump() const { print(dbgs()); }
+#endif
+
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 82b230d50c4ec9..4e28413a931a61 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -268,3 +268,171 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
   }
   EXPECT_EQ(Cnt, 0u);
 }
+
+TEST_F(SeedBundleTest, ConsecutiveStores) {
+  // Where "Consecutive" means the stores address consecutive locations in
+  // memory, but not in program order. Check to see that the collector puts them
+  // in the proper order for vectorization.
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr2 = getelementptr float, ptr %ptr, i32 2
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  // Expect just one vector of store seeds
+  ASSERT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+}
+
+TEST_F(SeedBundleTest, StoresWithGaps) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 3
+  %ptr2 = getelementptr float, ptr %ptr, i32 5
+  %ptr3 = getelementptr float, ptr %ptr, i32 7
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  // Expect just one vector of store seeds
+  ASSERT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+}
+
+TEST_F(SeedBundleTest, VectorStores) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  store <2 x float> %val, ptr %ptr1
+  store <2 x float> %val, ptr %ptr0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 2);
+  // StX with X as the order by offset in memory
+  auto *St1 = &*It++;
+  auto *St0 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  ASSERT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1));
+}
+
+TEST_F(SeedBundleTest, MixedScalarVectors) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %v, ptr %ptr0
+  store float %v, ptr %ptr3
+  store <2 x float> %val, ptr %ptr1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 3);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St3 = &*It++;
+  auto *St1 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  ASSERT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St3));
+}