diff --git a/build_tools/scripts/generate_release_index.py b/build_tools/scripts/generate_release_index.py index 70a4eeb57fde..9db8cb591f6a 100755 --- a/build_tools/scripts/generate_release_index.py +++ b/build_tools/scripts/generate_release_index.py @@ -44,7 +44,8 @@ def get_all(self): url = f"https://api.github.com/repos/{self._repo}/releases" page = 1 - while True: + # GitHub limits API responses to the first 1000 results. + while page * self._per_page < 1000: response = self._session.get( url, params={ diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 1fffecae41d4..b5ef37f185da 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -58,7 +58,6 @@ struct CUDAOptions { bool clUsePtxas = false; std::string clUsePtxasFrom; std::string clUsePtxasParams; - bool enableLegacySync = false; void bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category("CUDA HAL Target"); @@ -101,12 +100,6 @@ struct CUDAOptions { "iree-hal-cuda-use-ptxas-params", clUsePtxasParams, llvm::cl::cat(category), llvm::cl::desc("Passes the given additional parameters to ptxas.")); - - binder.opt( - "iree-hal-cuda-enable-legacy-sync", enableLegacySync, - llvm::cl::cat(category), - llvm::cl::desc( - "Enable legacy sync mode that handles semaphores synchronously.")); } }; } // namespace @@ -391,12 +384,6 @@ class CUDATargetBackend final : public TargetBackend { Builder b(context); SmallVector configItems; - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. - if (options.enableLegacySync) { - configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); - } - configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp b/compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp index 622037c28838..3da61cf5ace0 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp @@ -21,7 +21,8 @@ static ArrayAttr getMfmaArrayAttr(MLIRContext *context, } ArrayAttr getROCMSupportedMmaAttrs(MLIRContext *context, StringRef targetArch) { - if (targetArch == "gfx940") { + // MI300a/x + if (targetArch == "gfx940" || targetArch == "gfx942") { return getMfmaArrayAttr(context, {IREE::GPU::MFMAIntrinsic::F16_16x16x16_F32, IREE::GPU::MFMAIntrinsic::F16_32x32x8_F32}); diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index 17fc79e50c4b..dcdd8f1a83d5 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -15,7 +15,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( - ["smoketest.mlir"], + [ + "smoketest.mlir", + "target_device_features.mlir", + ], include = ["*.mlir"], ), cfg = "//compiler:lit.cfg.py", diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index 3f58b6e44ae1..f9805b6059dd 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "smoketest.mlir" + "target_device_features.mlir" TOOLS FileCheck iree-opt diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir new file mode 100644 index 000000000000..9f8bf60f74e3 --- 
/dev/null +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -0,0 +1,27 @@ +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targets=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=MI300 +// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targets=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx942 %s | FileCheck %s --check-prefix=MI300 + +// MI300: mma_intrinsics = [#iree_gpu.mfma_layout, #iree_gpu.mfma_layout] + +stream.executable public @reduce_dispatch { + stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @reduce_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding) { + %c0 = arith.constant 0 : index + %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %0 = tensor.empty() : tensor + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor) { + ^bb0(%arg2: f32, %arg3: f32): + %4 = arith.addf %arg2, %arg3 : f32 + linalg.yield %4 : f32 + } -> tensor + flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor -> !flow.dispatch.tensor> + return + } + } +} diff --git a/compiler/plugins/target/WebGPU/WebGPUTarget.cpp b/compiler/plugins/target/WebGPU/WebGPUTarget.cpp index af0293d25262..6fd63fae5777 100644 --- a/compiler/plugins/target/WebGPU/WebGPUTarget.cpp +++ b/compiler/plugins/target/WebGPU/WebGPUTarget.cpp @@ -75,10 +75,6 @@ class WebGPUTargetBackend : public TargetBackend { Builder b(context); SmallVector configItems; - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. 
- configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); - configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp index 0dd0fd8a2e2e..4252d26cf5aa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp @@ -6,13 +6,13 @@ #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" -#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" +#include "iree/compiler/Codegen/Utils/VectorOpUtils.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#define DEBUG_TYPE "iree-amdgpu-distribute-contract" +#define DEBUG_TYPE "iree-codegen-amdgpu-distribute-contract" namespace mlir::iree_compiler { namespace { @@ -20,75 +20,6 @@ namespace { using namespace mlir::iree_compiler::IREE::VectorExt; using VectorValue = TypedValue; -/// A class for querying information about a contract op. -class ContractOpDetail { -public: - enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN }; - - explicit ContractOpDetail(vector::ContractionOp op) { - opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray()); - } - - OpKind getOpKind() const { return opKind; } - - // Returns the (LHS M, RHS N) dimension index pair. - std::optional> getOperandMNIndex() const { - switch (opKind) { - case OpKind::MK_KN_MN: - return std::make_pair(0, 1); - case OpKind::MK_NK_MN: - return std::make_pair(0, 0); - case OpKind::UNKNOWN: - break; - } - return std::nullopt; - } - - // Returns the (LHS K, RHS K) dimension index pair. - std::optional> getOperandKIndex() const { - switch (opKind) { - case OpKind::MK_KN_MN: - return std::make_pair(1, 0); - case OpKind::MK_NK_MN: - return std::make_pair(1, 1); - case OpKind::UNKNOWN: - break; - } - return std::nullopt; - } - - // Returns the result (M, N) dimension index pair. - std::optional> getResultMNIndex() const { - switch (opKind) { - case OpKind::MK_KN_MN: - case OpKind::MK_NK_MN: - return std::make_pair(0, 1); - default: - break; - } - return std::nullopt; - } - -private: - // Gets the kind of a contract op with the given indexing |maps|. - OpKind inferOpKind(MLIRContext *ctx, SmallVector maps) { - using MapList = ArrayRef>; - auto infer = [&](MapList m) { - return AffineMap::inferFromExprList(m, ctx); - }; - AffineExpr m, n, k; - bindDims(ctx, m, n, k); - if (maps == infer({{m, k}, {k, n}, {m, n}})) - return OpKind::MK_KN_MN; - if (maps == infer({{m, k}, {n, k}, {m, n}})) - return OpKind::MK_NK_MN; - return OpKind::UNKNOWN; - } - -private: - OpKind opKind = OpKind::UNKNOWN; -}; - /// Distributes `vector.contract` ops with nested layouts. struct DistributeContract final : OpDistributionPattern { using OpDistributionPattern::OpDistributionPattern; @@ -140,8 +71,8 @@ struct DistributeContract final : OpDistributionPattern { mfmaParams.blocks = mfmaAttr.getBlockSize(); // Infer the contract kind so that we know know to correlate M/N/K dims. 
- ContractOpDetail opDetail(contractOp); - if (opDetail.getOpKind() == ContractOpDetail::OpKind::UNKNOWN) { + VectorContractOpInfo opDetail(contractOp); + if (opDetail.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) { return rewriter.notifyMatchFailure(contractOp, "unknown contract kind"); } @@ -243,7 +174,7 @@ struct DistributeContract final : OpDistributionPattern { } // Gets the batch size for matmul K dimensions. - std::optional getKBatchSize(const ContractOpDetail &opDetail, + std::optional getKBatchSize(const VectorContractOpInfo &opDetail, NestedLayoutAttr lhsLayout, NestedLayoutAttr rhsLayout) const { auto [lhsK, rhsK] = *opDetail.getOperandKIndex(); @@ -257,7 +188,7 @@ struct DistributeContract final : OpDistributionPattern { // Given a contract op's batch |resultOffsets|, fills its batch offsets for // both LHS and RHS. - void fillOperandBatchOffsets(const ContractOpDetail &opDetail, + void fillOperandBatchOffsets(const VectorContractOpInfo &opDetail, int64_t kOffset, ArrayRef resultOffsets, NestedLayoutAttr resultLayout, SmallVector &lhsOffsets, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 7aa80c6fb83b..4cd91f82e9ec 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -85,6 +85,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Transforms", "//compiler/src/iree/compiler/Codegen/Utils", + "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//llvm-external-projects/iree-dialects:IREELinalgExtDialect", "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms", @@ -125,3 +126,18 @@ iree_compiler_cc_library( "@llvm-project//mlir:VectorTransforms", ], ) + +iree_compiler_cc_library( + name = "GPUHeuristics", + srcs = [ + "GPUHeuristics.cpp", + ], + hdrs = [ + "GPUHeuristics.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index 72651a5055c7..0e6362b51502 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -113,8 +113,23 @@ iree_cc_library( iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect iree::compiler::Codegen::Transforms iree::compiler::Codegen::Utils + iree::compiler::Codegen::Utils::VectorOpUtils iree::compiler::Dialect::HAL::IR PUBLIC ) +iree_cc_library( + NAME + GPUHeuristics + HDRS + "GPUHeuristics.h" + SRCS + "GPUHeuristics.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRSupport + PUBLIC +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp new file mode 100644 index 000000000000..b09a81738d4d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -0,0 +1,113 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "iree-codegen-gpu-heuristics" + +using llvm::APIntOps::GreatestCommonDivisor; + +namespace mlir::iree_compiler { + +std::optional +deduceMMASchedule(const GPUMatmulShapeType &problem, + ArrayRef intrinsics, + const GPUMMAHeuristicSeeds &seeds) { + for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) { + if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType || + problem.cType != intrinsic.cType) { + continue; // Cannot use this intrinsic for mismatched types + } + + if (problem.mSize % intrinsic.mSize != 0 || + problem.nSize % intrinsic.nSize != 0 || + problem.kSize % intrinsic.kSize != 0) { + continue; // Cannot use this intrinsic for misaligned cases + } + + int64_t mTotalTileCount = problem.mSize / intrinsic.mSize; + int64_t nTotalTileCount = problem.nSize / intrinsic.nSize; + + int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup; + int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup; + // Assign more warps to the M dimension (used later) to balance thread + // counts along X and Y dimensions. + int64_t warpSqrt = 1ull + << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2)); + int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); + + int64_t mWarpCount = 0, nWarpCount = 0; + int64_t mTileCount = 0, nTileCount = 0; + + // See if the square root can divide mTotalTileCount. If so it means we can + // distribute to both dimensions evenly. Otherwise, try to distribute to N + // and then M. + if (mTotalTileCount > (warpSqrt * tileSqrt) && + mTotalTileCount % (warpSqrt * tileSqrt) == 0) { + mWarpCount = warpSqrt; + mTileCount = tileSqrt; + + remainingWarps /= warpSqrt; + remainingTiles /= tileSqrt; + + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), + APInt(64, remainingWarps)); + nWarpCount = nGCD.getSExtValue(); + nTotalTileCount /= nWarpCount; + remainingWarps /= nWarpCount; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), + APInt(64, remainingTiles)); + nTileCount = nGCD.getSExtValue(); + } else { + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), + APInt(64, remainingWarps)); + nWarpCount = nGCD.getSExtValue(); + nTotalTileCount /= nWarpCount; + remainingWarps /= nWarpCount; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), + APInt(64, remainingTiles)); + nTileCount = nGCD.getSExtValue(); + remainingTiles /= nTileCount; + + APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), + APInt(64, remainingWarps)); + mWarpCount = mGCD.getSExtValue(); + mTotalTileCount /= mWarpCount; + remainingWarps /= mWarpCount; + + mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), + APInt(64, remainingTiles)); + mTileCount = mGCD.getSExtValue(); + } + + const uint64_t kTotalTileCount = problem.kSize / intrinsic.kSize; + APInt kGCD = GreatestCommonDivisor( + APInt(64, kTotalTileCount), APInt(64, seeds.bestKTileCountPerSubgroup)); + int64_t kTileCount = kGCD.getSExtValue(); + + LLVM_DEBUG({ + llvm::dbgs() << "chosen MMA schedule:\n"; + llvm::dbgs() << " intrinsic (M, N, K) = (" << intrinsic.mSize << ", " + << intrinsic.nSize << ", " << intrinsic.kSize << ")\n"; + llvm::dbgs() << " subgroup count (M, N) = (" << mWarpCount << ", " + << nWarpCount << ")\n"; + llvm::dbgs() << " subgroup tile count (M, N, K) = (" << 
mTileCount + << ", " << nTileCount << ", " << kTileCount << ")\n"; + }); + return GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize, + intrinsic.kSize, mWarpCount, nWarpCount, + mTileCount, nTileCount, kTileCount}; + } + return std::nullopt; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h new file mode 100644 index 000000000000..6bf4ff84f338 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -0,0 +1,54 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "mlir/IR/Types.h" + +namespace mlir::iree_compiler { + +/// Struct containing information about a matmul's shape and type. +struct GPUMatmulShapeType { + int64_t mSize; + int64_t nSize; + int64_t kSize; + Type aType; + Type bType; + Type cType; + + GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c) + : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {} +}; + +/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic. +struct GPUMMAHeuristicSeeds { + // The best number of subgroups to use per workgroup + int64_t bestSubgroupCountPerWorkgroup; + // The best number of total tiles along M*N dimensions per subgroup + int64_t bestMNTileCountPerSubgroup; + // The best number of tiles along K dimension per subgroup + int64_t bestKTileCountPerSubgroup; +}; + +struct GPUMMASchedule { + // Index of the chosen intrinsic into the list of given MMA intrinsics + uint64_t index; + int64_t mSize; // Native MMA size along M dimension + int64_t nSize; // Native MMA size along N dimension + int64_t kSize; // Native MMA size along K dimension + int64_t mWarpCount; // Number of subgroups along M dimension + int64_t nWarpCount; // Number of subgroups along N dimension + int64_t mTileCount; // Number of tiles per subgroup along M dimension + int64_t nTileCount; // Number of tiles per subgroup along N dimension + int64_t kTileCount; // Number of tiles along K dimension +}; + +/// Returns a schedule for using one of the given MMA |intrinsics| to target the +/// input |problem|. Returns std::nullopt if we cannot find such a schedule. 
+std::optional +deduceMMASchedule(const GPUMatmulShapeType &problem, + ArrayRef intrinsics, + const GPUMMAHeuristicSeeds &seeds); + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index 62cd5b2697f5..6114f3fe0c93 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -5,12 +5,17 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h" #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "iree-codegen-gpu-vector-distribution" @@ -220,15 +225,15 @@ static bool canDistribute(Operation *op, VectorLayoutAnalysis &analysis) { }); } -void distributeVectorOps(Operation *root, - RewritePatternSet &distributionPatterns, - VectorLayoutOptions &options) { +LogicalResult distributeVectorOps(Operation *root, + RewritePatternSet &distributionPatterns, + VectorLayoutOptions &options) { // Run the analysis and determine the layouts. LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n"); VectorLayoutAnalysis analysis(root); options.setAnchorOps(analysis); if (failed(analysis.run())) - return; + return failure(); LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n"); LLVM_DEBUG(llvm::dbgs() << "\n\n"); @@ -245,7 +250,38 @@ void distributeVectorOps(Operation *root, LLVM_DEBUG(llvm::dbgs() << "\n\n"); FrozenRewritePatternSet frozenPatterns(std::move(distributionPatterns)); - return applyVectorDistribution(root, frozenPatterns); + applyVectorDistribution(root, frozenPatterns); + + RewritePatternSet patterns(root->getContext()); + IREE::VectorExt::ToSIMDOp::getCanonicalizationPatterns(patterns, + root->getContext()); + IREE::VectorExt::ToSIMTOp::getCanonicalizationPatterns(patterns, + root->getContext()); + if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns)))) { + return failure(); + } + + if (options.verifyConversion()) { + WalkResult hasConversionOp = root->walk([](Operation *op) { + if (isa(op)) { + for (auto user : op->getUsers()) { + if (!isa( + user)) { + LLVM_DEBUG({ + llvm::dbgs() << "Found live cast op: " << *op << "\n"; + llvm::dbgs() << "With live user: " << *user << "\n"; + }); + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + if (hasConversionOp.wasInterrupted()) { + return failure(); + } + } + return success(); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h index 9e365b035297..e740c598125d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h @@ -87,7 +87,11 @@ class OpTraitDistributionPattern : public DistributionPattern { /// distribution. 
class VectorLayoutOptions { public: - VectorLayoutOptions(Operation *root) : root(root) { + VectorLayoutOptions(Operation *root) : root(root), fullConversion(true) { + assert(root && "root operation must be non-null"); + } + VectorLayoutOptions(Operation *root, bool fullConversion) + : root(root), fullConversion(fullConversion) { assert(root && "root operation must be non-null"); } @@ -96,8 +100,11 @@ class VectorLayoutOptions { /// Set the anchor ops in the analysis rooted on the root operation. virtual void setAnchorOps(VectorLayoutAnalysis &analysis) = 0; + bool verifyConversion() const { return fullConversion; } + protected: Operation *root; + bool fullConversion = true; }; // namespace iree_compiler /// Distribute vector operations in the IR rooted at `root`. @@ -112,9 +119,9 @@ class VectorLayoutOptions { /// - Run a global analysis to determine how to distribute rest of the vector /// values keeping the initial anchors in mind. /// - Use the analysis information to distribute each operation. -void distributeVectorOps(Operation *root, - RewritePatternSet &distributionPatterns, - VectorLayoutOptions &options); +LogicalResult distributeVectorOps(Operation *root, + RewritePatternSet &distributionPatterns, + VectorLayoutOptions &options); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir index 0b8078d3ff58..8ee8923dbea5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_contract_amdgpu.mlir @@ -387,13 +387,11 @@ builtin.module attributes { transform.with_named_sequence } { // CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16> // CHECK: %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16> // CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}} -// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<4x4xf32> to vector<4x1x1x4xf32> // CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16> // CHECK: %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16> // CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16> // CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[R0_CAST]] : vector<4x1x1x4xf32> to vector<4x4xf32> -// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[CAST]] +// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]] // CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32> // CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32> // CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32> diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index da6f4ab2985b..78a86660a658 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -939,7 +939,8 @@ void transform_dialect::TestVectorLayoutAnalysisOp::getEffects( class TestVectorLayoutOptions : public VectorLayoutOptions { public: - TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {} + TestVectorLayoutOptions(Operation *root) + : VectorLayoutOptions(root, /*fullConversion=*/false) {} void setAnchorOps(VectorLayoutAnalysis &analysis) override { setAnchorOpsFromAttributes(analysis, root); @@ -970,7 +971,9 @@ transform_dialect::TestGpuVectorDistribution::applyToOne( populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); if (getExperimental()) populateGPULayoutResolutionDistributionPatterns(patterns); - distributeVectorOps(target, patterns, options); + if (failed(distributeVectorOps(target, patterns, options))) { + return emitDefaultDefiniteFailure(target); + } return DiagnosedSilenceableFailure::success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir index f3b11f96e698..afc09ebc3dcb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/remove_trivial_loops.mlir @@ -196,59 +196,55 @@ hal.executable private @both_workgroup_and_workitem { // ----- -#config = #iree_codegen.lowering_config -#device_target_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>]}> #pipeline_layout = #hal.pipeline.layout, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]> #translation = #iree_codegen.translation_info #map0 = affine_map<()[s0] -> (s0 ceildiv 4)> #map1 = affine_map<()[s0] -> (s0 * 4)> #map2 = affine_map<()[s0, s1] -> (-((s0 * -4 + 4) mod (s1 * 4)) + 4)> #map3 = affine_map<(d0)[s0] -> (d0 + s0)> -module attributes {hal.device.targets = [#device_target_cpu]} { - hal.executable private @simple_mul { - hal.executable.variant public @embedded_elf_x86_64 target(#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-none-elf"}>) { - hal.executable.export public @simple_mul ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %c1 = arith.constant 1 : index - %0 = affine.apply #map0()[%arg1] - hal.return %0, %c1, %c1 : index, index, index - } - builtin.module { - func.func @simple_mul() { - %cst = arith.constant 0.000000e+00 : f32 - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> - memref.assume_alignment %0, 64 : memref<4xf32> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> - memref.assume_alignment %1, 64 : memref<4xf32> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> - memref.assume_alignment %2, 64 : memref<4xf32> - %workgroup_id_x = 
hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %3 = affine.apply #map1()[%workgroup_id_x] - %4 = affine.apply #map1()[%workgroup_count_x] - %5 = affine.apply #map2()[%workgroup_id_x, %workgroup_count_x] - scf.for %arg0 = %3 to %5 step %4 { - %6 = memref.subview %2[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %7 = memref.subview %0[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %8 = memref.subview %1[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %9 = vector.transfer_read %7[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> - %10 = vector.transfer_read %8[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> - %11 = arith.mulf %9, %10 : vector<4xf32> - vector.transfer_write %11, %6[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, #map3> - } - scf.for %arg0 = %5 to %c4 step %4 { - %6 = memref.subview %2[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %7 = memref.subview %0[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %8 = memref.subview %1[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> - %9 = vector.transfer_read %7[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> - %10 = vector.transfer_read %8[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> - %11 = arith.mulf %9, %10 : vector<4xf32> - vector.transfer_write %11, %6[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, #map3> - } - return +hal.executable private @simple_mul { + hal.executable.variant public @variant target(#hal.executable.target<"cuda", "cuda-nvptx-fb">) { + hal.executable.export public @simple_mul ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %c1 = arith.constant 1 : index + %0 = affine.apply #map0()[%arg1] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + func.func @simple_mul() { + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> + memref.assume_alignment %0, 64 : memref<4xf32> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> + memref.assume_alignment %1, 64 : memref<4xf32> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<4xf32> + memref.assume_alignment %2, 64 : memref<4xf32> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %3 = affine.apply #map1()[%workgroup_id_x] + %4 = affine.apply #map1()[%workgroup_count_x] + %5 = affine.apply #map2()[%workgroup_id_x, %workgroup_count_x] + scf.for %arg0 = %3 to %5 step %4 { + %6 = memref.subview %2[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %7 = memref.subview %0[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %8 = memref.subview %1[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %9 = vector.transfer_read %7[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> + %10 = vector.transfer_read %8[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> + %11 = arith.mulf %9, %10 : vector<4xf32> + vector.transfer_write %11, %6[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, #map3> } + scf.for %arg0 = %5 to %c4 
step %4 { + %6 = memref.subview %2[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %7 = memref.subview %0[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %8 = memref.subview %1[%arg0] [4] [1] : memref<4xf32> to memref<4xf32, #map3> + %9 = vector.transfer_read %7[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> + %10 = vector.transfer_read %8[%c0], %cst {in_bounds = [true]} : memref<4xf32, #map3>, vector<4xf32> + %11 = arith.mulf %9, %10 : vector<4xf32> + vector.transfer_write %11, %6[%c0] {in_bounds = [true]} : vector<4xf32>, memref<4xf32, #map3> + } + return } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir index 66cb0af5cc69..0e5c817c1f96 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir @@ -2385,28 +2385,26 @@ hal.executable private @scatter { // ----- -module attributes {hal.device.targets = [#hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_86"}>], legacy_sync}>]} { - hal.executable private @collapse_workgroups_dispatch_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb", {target_arch = "sm_86"}>) { - hal.executable.export public @collapse_workgroups_dispatch_dispatch_0_generic_1024x128x16x64 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) { - ^bb0(%arg0: !hal.device): - %x, %y, %z = flow.dispatch.workgroup_count_from_slice - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @collapse_workgroups_dispatch_dispatch_0_generic_1024x128x16x64() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1024, 16, 128, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1024x16x128x64xf32> - %3 = tensor.empty() : tensor<1024x128x16x64xf32> - %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1024x16x128x64xf32>) outs(%3 : tensor<1024x128x16x64xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1024x128x16x64xf32> - flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [1024, 128, 16, 64], strides = [1, 1, 1, 1] : tensor<1024x128x16x64xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @collapse_workgroups_dispatch_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb", {target_arch = "sm_86"}>) { + hal.executable.export public @collapse_workgroups_dispatch_dispatch_0_generic_1024x128x16x64 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @collapse_workgroups_dispatch_dispatch_0_generic_1024x128x16x64() { + %c0 
= arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1024, 16, 128, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1024x16x128x64xf32> + %3 = tensor.empty() : tensor<1024x128x16x64xf32> + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1024x16x128x64xf32>) outs(%3 : tensor<1024x128x16x64xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1024x128x16x64xf32> + flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [1024, 128, 16, 64], strides = [1, 1, 1, 1] : tensor<1024x128x16x64xf32> -> !flow.dispatch.tensor> + return } } } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel index fdbdad32886d..6f6de6b46255 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel @@ -60,6 +60,7 @@ iree_compiler_cc_library( ":IREEGPUInterfaces", "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Utils", + "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils", "//llvm-external-projects/iree-dialects:IREEVectorExtDialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:DialectUtils", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt index 853f8b0b6c77..28d58bf0967f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/CMakeLists.txt @@ -41,6 +41,7 @@ iree_cc_library( MLIRVectorDialect iree::compiler::Codegen::Common iree::compiler::Codegen::Utils + iree::compiler::Codegen::Utils::VectorOpUtils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 4e140f3ccea7..752e8dd3fd57 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -6,16 +6,21 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" +#include "iree/compiler/Codegen/Utils/VectorOpUtils.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/TypeUtilities.h" +#define DEBUG_TYPE "iree-gpu-attrs" + #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.cpp.inc" #define GET_ATTRDEF_CLASSES #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" @@ -27,6 
+32,7 @@ using VectorLayoutInterface = mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface; using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr; using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr; +using NestedLayoutAttr = mlir::iree_compiler::IREE::VectorExt::NestedLayoutAttr; namespace mlir::iree_compiler::IREE::GPU { @@ -324,7 +330,7 @@ MFMAAttr::getContractionLayout(vector::ContractionOp contract) const { return IREE::GPU::getContractionLayout(contract, layout); } -int64_t MFMAAttr::getBlockSize() { +int64_t MFMAAttr::getBlockSize() const { switch (getIntrinsic().getValue()) { case MFMAIntrinsic::F16_16x16x16_F32: { return 1; @@ -337,95 +343,225 @@ int64_t MFMAAttr::getBlockSize() { return 0; } -//===----------------------------------------------------------------------===// -// Initialize attributes -//===----------------------------------------------------------------------===// - -void IREEGPUDialect::registerAttributes() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" // IWYU pragma: keep - >(); +MFMAAttr::SingleSubgroupLayout MFMAAttr::getASingleSubgroupLayoutCount() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*element=*/{1, 4}}; + } + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*element=*/{1, 4}}; + } + } + return {}; } -} // namespace mlir::iree_compiler::IREE::GPU - -namespace mlir::iree_compiler { - -std::optional -getCompatibleMmaAttr(ArrayAttr mmaKinds, vector::ContractionOp contract) { - SmallVector iterationBounds; - contract.getIterationBounds(iterationBounds); - return getCompatibleMmaAttr(mmaKinds, contract.getIndexingMapsArray(), - iterationBounds, contract->getOperandTypes()); +MFMAAttr::SingleSubgroupLayout MFMAAttr::getBSingleSubgroupLayoutCount() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*element=*/{4, 1}}; + } + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}}; + } + } + return {}; } -std::optional -getCompatibleMmaAttr(ArrayAttr mmaKinds, linalg::LinalgOp linalgOp) { - return getCompatibleMmaAttr(mmaKinds, linalgOp.getIndexingMapsArray(), - linalgOp.getStaticLoopRanges(), - linalgOp->getOperandTypes()); +MFMAAttr::SingleSubgroupLayout MFMAAttr::getCSingleSubgroupLayoutCount() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*element=*/{4, 1}}; + } + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}}; + } + } + return {}; } -std::optional -getCompatibleMmaAttr(ArrayAttr mmaKinds, ArrayRef indexingMaps, - ArrayRef iterationBounds, TypeRange inputTypes) { - FailureOr maybeContractionDims = - linalg::inferContractionDims(indexingMaps); - if (failed(maybeContractionDims)) { - return std::nullopt; +MFMAAttr::SingleSubgroupLayout MFMAAttr::getASingleSubgroupLayoutOrder() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{0, 1}, /*thread=*/{1, 0}, /*element=*/{0, 1}}; } - auto contractionDims = *maybeContractionDims; + } + return {}; +} - // TODO: Relax this condition once distribution supports it. 
- if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 || - contractionDims.n.size() != 1) { - return std::nullopt; +MFMAAttr::SingleSubgroupLayout MFMAAttr::getBSingleSubgroupLayoutOrder() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}}; + } } + return {}; +} - unsigned mDim = contractionDims.m[0]; - unsigned nDim = contractionDims.n[0]; - unsigned kDim = contractionDims.k[0]; +MFMAAttr::SingleSubgroupLayout MFMAAttr::getCSingleSubgroupLayoutOrder() const { + switch (getIntrinsic().getValue()) { + case MFMAIntrinsic::F16_16x16x16_F32: + case MFMAIntrinsic::F16_32x32x8_F32: { + return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}}; + } + } + return {}; +} - int64_t problemMSize = iterationBounds[mDim]; - int64_t problemNSize = iterationBounds[nDim]; - int64_t problemKSize = iterationBounds[kDim]; +//===----------------------------------------------------------------------===// +// MMA Schedule Attributes +//===----------------------------------------------------------------------===// - // Bail on dynamic shapes. Once better support for dynamic cases is in place, - // a separate helper should be added for dynamic and unaligned. - if (ShapedType::isDynamic(problemMSize) || - ShapedType::isDynamic(problemNSize) || - ShapedType::isDynamic(problemKSize)) { - return std::nullopt; +NestedLayoutAttr permuteAndCreateNestedLayout( + MLIRContext *context, ArrayRef permute, + SmallVector subgroupCount, + SmallVector subgroupOrder, SmallVector batchCount, + SmallVector batchOrder, SmallVector outerCount, + SmallVector outerOrder, SmallVector threadCount, + SmallVector threadOrder, SmallVector elementCount, + SmallVector elementOrder, ArrayRef subgroupBasis, + ArrayRef threadBasis) { + if (!isIdentityPermutation(permute)) { + applyPermutationToVector(subgroupCount, permute); + applyPermutationToVector(subgroupOrder, permute); + applyPermutationToVector(batchCount, permute); + applyPermutationToVector(batchOrder, permute); + applyPermutationToVector(outerCount, permute); + applyPermutationToVector(outerOrder, permute); + applyPermutationToVector(threadCount, permute); + applyPermutationToVector(threadOrder, permute); + applyPermutationToVector(elementCount, permute); + applyPermutationToVector(elementOrder, permute); } - if (inputTypes.size() != 3) { - return std::nullopt; - } + return NestedLayoutAttr::get(context, subgroupCount, subgroupOrder, + batchCount, batchOrder, outerCount, outerOrder, + threadCount, threadOrder, elementCount, + elementOrder, subgroupBasis, threadBasis); +} - Type lhsType = getElementTypeOrSelf(inputTypes[0]); - Type rhsType = getElementTypeOrSelf(inputTypes[1]); - Type accType = getElementTypeOrSelf(inputTypes[2]); +std::optional> +MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const { + VectorContractOpInfo opInfo(contractOp); + if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) + return std::nullopt; - for (Attribute a : mmaKinds.getValue()) { - auto mmaKind = dyn_cast(a); - if (!mmaKind) { - return std::nullopt; - } + auto [aM, bN] = *opInfo.getOperandMNIndex(); + auto [aK, bK] = *opInfo.getOperandKIndex(); + auto [cM, cN] = *opInfo.getResultMNIndex(); + SmallVector aPermute = {aM, aK}; + SmallVector bPermute = {bK, bN}; + SmallVector cPermute = {cM, cN}; + + auto mfmaAttr = llvm::cast(getIntrinsic()); + MLIRContext *context = getContext(); + + // Get the concrete 
nested layout for each matrix. Note that the struct + // MFMAAttr::SingleSubgroupLayout contains the partial layout for the + // canonical (M, K) x (K, N) -> (M, N) matmul form, while the specific + // contract op we are looking at right now may not be exactly in that form. + // So here we need to permute/transpose the canonical layout to match the + // concrete contract op. + + // Note that no matter how we permute/transpose the input contraction problem, + // the way we view the hardware warps remains the same--that is, from the + // hardware's perspective, a single warp has the same warp ID no matter what + // part of the contraction it works on. Similarly here, we are delinearizing + // the linearized GPU hardware lane ID into an n-D concatenated logical + // warp+thread ID using the subgroup/thread basis, so the subgroup basis should + // remain the same for the A/B/C matrices. + SmallVector subgroupBasis = {getSubgroupMCount(), + getSubgroupNCount()}; + + // For threads, we also need to make sure the basis is consistent across the + // A, B, and C matrices. Here we additionally need to consider how the MMA + // intrinsics expect the threads to be organized and how we distribute the + // large input contraction problem to the threads. + // The intrinsics expect a certain 2-D (x, y) thread layout, where it's not + // guaranteed that y is always the fastest moving dimension. But when we + // distribute the large input contraction problem, we always associate the + // fastest moving dimension with the innermost thread ID dimension. Therefore, + // we need to "adjust" the intrinsic thread shape from the slowest moving + // dimension to the fastest one, that is, apply the corresponding order + // permutation vector. Because of how the intrinsics are designed, the end + // result is that we are basically guaranteed to see the same thread basis + // for the A, B, and C matrices. But still.. 
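+ // As a concrete illustration (derived from the single-subgroup layouts + // defined above): for the F16_16x16x16_F32 intrinsic, A's thread count is + // {16, 4} with thread order {1, 0}, while B's and C's thread counts are + // {4, 16} with order {0, 1}; after applying the order permutation, the + // thread basis computed below is {4, 16} for all three matrices.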
+ + // C matrix layout + MFMAAttr::SingleSubgroupLayout cCounts = + mfmaAttr.getCSingleSubgroupLayoutCount(); + MFMAAttr::SingleSubgroupLayout cOrders = + mfmaAttr.getCSingleSubgroupLayoutOrder(); + + SmallVector cThreadBasis = cCounts.thread; + applyPermutationToVector(cThreadBasis, cOrders.thread); + + auto cLayout = permuteAndCreateNestedLayout( + context, cPermute, + /*subgroupCount=*/{getSubgroupMCount(), getSubgroupNCount()}, + /*subgroupOrder=*/{0, 1}, + /*batchCount=*/{getSubgroupMTileCount(), getSubgroupNTileCount()}, + /*batchOrder=*/{0, 1}, /*outerCount=*/cCounts.outer, + /*outerOrder=*/cOrders.outer, /*threadCount=*/cCounts.thread, + /*threadOrder=*/cOrders.thread, + /*elementCount=*/cCounts.element, /*elementOrder=*/cOrders.element, + subgroupBasis, cThreadBasis); + + // A matrix layout + MFMAAttr::SingleSubgroupLayout aCounts = + mfmaAttr.getASingleSubgroupLayoutCount(); + MFMAAttr::SingleSubgroupLayout aOrders = + mfmaAttr.getASingleSubgroupLayoutOrder(); + + SmallVector aThreadBasis = aCounts.thread; + applyPermutationToVector(aThreadBasis, aOrders.thread); + + auto aLayout = permuteAndCreateNestedLayout( + context, aPermute, + /*subgroupCount=*/{getSubgroupMCount(), 1}, + /*subgroupOrder=*/{0, 1}, + /*batchCount=*/{getSubgroupMTileCount(), getSubgroupKTileCount()}, + /*batchOrder=*/{0, 1}, /*outerCount=*/aCounts.outer, + /*outerOrder=*/aOrders.outer, /*threadCount=*/aCounts.thread, + /*threadOrder=*/aOrders.thread, + /*elementCount=*/aCounts.element, /*elementOrder=*/aOrders.element, + subgroupBasis, aThreadBasis); + + // B matrix layout + MFMAAttr::SingleSubgroupLayout bCounts = + mfmaAttr.getBSingleSubgroupLayoutCount(); + MFMAAttr::SingleSubgroupLayout bOrders = + mfmaAttr.getBSingleSubgroupLayoutOrder(); + + SmallVector bThreadBasis = bCounts.thread; + applyPermutationToVector(bThreadBasis, bOrders.thread); + + auto bLayout = permuteAndCreateNestedLayout( + context, bPermute, + /*subgroupCount=*/{1, getSubgroupNCount()}, + /*subgroupOrder=*/{0, 1}, + /*batchCount=*/{getSubgroupKTileCount(), getSubgroupNTileCount()}, + /*batchOrder=*/{0, 1}, /*outerCount=*/bCounts.outer, + /*outerOrder=*/bOrders.outer, /*threadCount=*/bCounts.thread, + /*threadOrder=*/bOrders.thread, + /*elementCount=*/bCounts.element, /*elementOrder=*/bOrders.element, + subgroupBasis, bThreadBasis); + + return std::make_tuple(aLayout, bLayout, cLayout); +} - auto [typeA, typeB, typeC] = mmaKind.getABCElementTypes(); - if (typeA != lhsType || typeB != rhsType || typeC != accType) { - continue; - } +//===----------------------------------------------------------------------===// +// Attribute Registration +//===----------------------------------------------------------------------===// - auto [sizeM, sizeN, sizeK] = mmaKind.getMNKShape(); - if (problemMSize % sizeM != 0 || problemNSize % sizeN != 0 || - problemKSize % sizeK != 0) { - continue; - } - return mmaKind; - } - return std::nullopt; +void IREEGPUDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" // IWYU pragma: keep + >(); } -} // namespace mlir::iree_compiler +} // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h index 1a1cf4d391a7..2abca0165e80 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h @@ -4,9 +4,6 @@ // See 
https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- IREEGPUAttrs.h - Codegen GPU dialect attributes --------------------===// -//===----------------------------------------------------------------------===// - #ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_IREEGPUATTRS_H_ #define IREE_COMPILER_CODEGEN_DIALECT_GPU_IREEGPUATTRS_H_ @@ -24,30 +21,4 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h.inc" // clang-format on -namespace mlir::linalg { -class LinalgOp; -} // namespace mlir::linalg - -namespace mlir::iree_compiler { - -// Returns an MmaAttr from the array of mmaKinds compatible with the given -// structured operation description. The conditions for compatibility are -// -// 1. The iteration bounds are aligned on the shape of the mma operation. -// 2. The element types of |inputTypes| match with `[aType, bType, cType]` -// -// Returns the first successful match. -std::optional -getCompatibleMmaAttr(ArrayAttr mmaKinds, ArrayRef indexingMaps, - ArrayRef iterationBounds, TypeRange inputTypes); -// Helper for contractions. -std::optional getCompatibleMmaAttr(ArrayAttr mmaKinds, - vector::ContractionOp); -// Helper for linalg ops. Fails if the linalg op is not inferrable as a -// contraction op. -std::optional getCompatibleMmaAttr(ArrayAttr mmaKinds, - linalg::LinalgOp); - -} // namespace mlir::iree_compiler - #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IREEGPUATTRS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index e0dd48ae9c7a..7280d2324862 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -13,8 +13,9 @@ include "mlir/IR/OpBase.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" -def IREEGPU_MmaArrayAttr : TypedArrayAttrBase; +//===----------------------------------------------------------------------===// +// Base MMA Vector Layout Attributes +//===----------------------------------------------------------------------===// class IREEGPU_MmaVectorLayoutAttr : AttrDef; -def IREEGPU_MFMA : IREEGPU_MmaVectorLayoutAttr<"MFMA", "MFMAIntrinsicAttr"> { +def IREEGPU_MFMAAttr : IREEGPU_MmaVectorLayoutAttr<"MFMA", "MFMAIntrinsicAttr"> { let mnemonic = "mfma_layout"; let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; @@ -119,8 +120,70 @@ def IREEGPU_MFMA : IREEGPU_MmaVectorLayoutAttr<"MFMA", "MFMAIntrinsicAttr"> { ]; let extraClassDeclaration = !strconcat(baseExtraClassDeclaration, [{ - int64_t getBlockSize(); + int64_t getBlockSize() const; + + // Partial nested layout for an MMA intrinsic's matrix input/output inside + // a single subgroup. + // + // Note that this is just a container used by the following methods; it can + // contain both the shape and the order. + struct SingleSubgroupLayout { + SmallVector outer; + SmallVector thread; + SmallVector element; + }; + + // Returns the A/B/C matrix's partial nested layout shape inside a single + // subgroup. Shape at each outer/thread/element level is a 2-D value, + // following canonical matmul order--(M, K) for A, (K, N) for B, and + // (M, N) for C. + SingleSubgroupLayout getASingleSubgroupLayoutCount() const; + SingleSubgroupLayout getBSingleSubgroupLayoutCount() const; + SingleSubgroupLayout getCSingleSubgroupLayoutCount() const; + + // Returns the A/B/C matrix's partial nested layout order inside a single + // subgroup. 
Order at each outer/thread/element level is a 2-value + // permuation vector, following canonical matmul order--(M, K) for A, + // (K, N) for B, and (M, N) for C. + SingleSubgroupLayout getASingleSubgroupLayoutOrder() const; + SingleSubgroupLayout getBSingleSubgroupLayoutOrder() const; + SingleSubgroupLayout getCSingleSubgroupLayoutOrder() const; }]); } +//===----------------------------------------------------------------------===// +// MMA schedule Attributes +//===----------------------------------------------------------------------===// + +def IREEGPU_MmaScheduleAttr : AttrDef { + let mnemonic = "mma_schedule"; + let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; + + string description = [{ + A schedule of MMA intrinsic instruction and various levels of tile sizes + to solve a specific contraction problem. + }]; + + + let parameters = (ins + "::mlir::iree_compiler::IREE::GPU::MmaAttr":$intrinsic, + "int64_t":$subgroup_m_count, + "int64_t":$subgroup_n_count, + "int64_t":$subgroup_m_tile_count, + "int64_t":$subgroup_n_tile_count, + "int64_t":$subgroup_k_tile_count + ); + + let assemblyFormat = "`<` struct(params) `>`"; + + let extraClassDeclaration = [{ + // Returns the A/B/C matrix concrete layout targeting |contractOp|. + ::std::optional<::std::tuple> + getContractionLayout(::mlir::vector::ContractionOp contractOp) const; + }]; +} + + #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IREEGPUATTRS diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir index fc36a2414671..e2bf67ee7afe 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_vectorize_nd_extract_tests.mlir @@ -1,93 +1,91 @@ // RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmcpu-select-lowering-strategy, iree-llvmcpu-lower-executable-target)))' --split-input-file %s | FileCheck %s -module attributes {hal.device.targets = [#hal.device.target<"llvm-cpu", {executable_targets = [#hal.executable.target<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 64 : index, target_triple = "riscv64"}>]}>]} { - hal.executable private @main_dispatch_77 { - hal.executable.variant public @system_elf_riscv_64 target(<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 64 : index, target_triple = "riscv64"}>) { - hal.executable.export public @main_dispatch_77_generic_1x257x257x21 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @main_dispatch_77_generic_1x257x257x21() { - %c1115136 = arith.constant 1115136 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 2.000000e+00 : f32 - %cst_0 = arith.constant 0.000000e+00 : f32 - %cst_1 = arith.constant 1.600000e+01 : f32 - %c1_i32 = arith.constant 1 : i32 - %c32_i32 = arith.constant 32 : i32 - %cst_2 = arith.constant 1.000000e+00 : f32 - %c0_i32 = arith.constant 0 : i32 - %0 = 
hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1115136) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 33, 33, 21], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x33x33x21xf32> - %3 = tensor.empty() : tensor<1x257x257x21xf32> - %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%3 : tensor<1x257x257x21xf32>) { - ^bb0(%out: f32): - %5 = linalg.index 1 : index - %6 = linalg.index 0 : index - %7 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 257)>(%5, %6) - %8 = linalg.index 2 : index - %9 = linalg.index 3 : index - %10 = arith.index_cast %7 : index to i32 - %11 = arith.index_cast %8 : index to i32 - %12 = arith.uitofp %10 : i32 to f32 - %13 = arith.mulf %12, %cst : f32 - %14 = arith.addf %13, %cst_0 : f32 - %15 = arith.divf %14, %cst_1 : f32 - %16 = math.floor %15 : f32 - %17 = arith.subf %15, %16 : f32 - %18 = arith.fptosi %16 : f32 to i32 - %19 = arith.uitofp %11 : i32 to f32 - %20 = arith.mulf %19, %cst : f32 - %21 = arith.addf %20, %cst_0 : f32 - %22 = arith.divf %21, %cst_1 : f32 - %23 = math.floor %22 : f32 - %24 = arith.subf %22, %23 : f32 - %25 = arith.fptosi %23 : f32 to i32 - %26 = arith.addi %18, %c1_i32 : i32 - %27 = arith.cmpi slt, %18, %c0_i32 : i32 - %28 = arith.select %27, %c0_i32, %18 : i32 - %29 = arith.cmpi sgt, %18, %c32_i32 : i32 - %30 = arith.select %29, %c32_i32, %28 : i32 - %31 = arith.cmpi slt, %26, %c0_i32 : i32 - %32 = arith.select %31, %c0_i32, %26 : i32 - %33 = arith.cmpi sgt, %26, %c32_i32 : i32 - %34 = arith.select %33, %c32_i32, %32 : i32 - %35 = arith.index_cast %30 : i32 to index - %36 = arith.index_cast %34 : i32 to index - %37 = arith.addi %25, %c1_i32 : i32 - %38 = arith.cmpi slt, %25, %c0_i32 : i32 - %39 = arith.select %38, %c0_i32, %25 : i32 - %40 = arith.cmpi sgt, %25, %c32_i32 : i32 - %41 = arith.select %40, %c32_i32, %39 : i32 - %42 = arith.cmpi slt, %37, %c0_i32 : i32 - %43 = arith.select %42, %c0_i32, %37 : i32 - %44 = arith.cmpi sgt, %37, %c32_i32 : i32 - %45 = arith.select %44, %c32_i32, %43 : i32 - %46 = arith.index_cast %41 : i32 to index - %47 = arith.index_cast %45 : i32 to index - %extracted = tensor.extract %2[%c0, %35, %46, %9] : tensor<1x33x33x21xf32> - %extracted_3 = tensor.extract %2[%c0, %35, %47, %9] : tensor<1x33x33x21xf32> - %extracted_4 = tensor.extract %2[%c0, %36, %46, %9] : tensor<1x33x33x21xf32> - %extracted_5 = tensor.extract %2[%c0, %36, %47, %9] : tensor<1x33x33x21xf32> - %48 = arith.subf %cst_2, %24 : f32 - %49 = arith.mulf %extracted, %48 : f32 - %50 = arith.mulf %extracted_3, %24 : f32 - %51 = arith.addf %49, %50 : f32 - %52 = arith.mulf %extracted_4, %48 : f32 - %53 = arith.mulf %extracted_5, %24 : f32 - %54 = arith.addf %52, %53 : f32 - %55 = arith.subf %cst_2, %17 : f32 - %56 = arith.mulf %51, %55 : f32 - %57 = arith.mulf %54, %17 : f32 - %58 = arith.addf %56, %57 : f32 - linalg.yield %58 : f32 - } -> tensor<1x257x257x21xf32> - flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [1, 257, 257, 21], strides = [1, 1, 1, 1] : tensor<1x257x257x21xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @main_dispatch_77 { + hal.executable.variant public @system_elf_riscv_64 target(<"llvm-cpu", "system-elf-riscv_64", {cpu = "generic-rv64", 
cpu_features = "+m,+a,+f,+d,+v", data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", native_vector_size = 64 : index, target_triple = "riscv64"}>) { + hal.executable.export public @main_dispatch_77_generic_1x257x257x21 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @main_dispatch_77_generic_1x257x257x21() { + %c1115136 = arith.constant 1115136 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 2.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant 1.600000e+01 : f32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_2 = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1115136) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 33, 33, 21], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x33x33x21xf32> + %3 = tensor.empty() : tensor<1x257x257x21xf32> + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%3 : tensor<1x257x257x21xf32>) { + ^bb0(%out: f32): + %5 = linalg.index 1 : index + %6 = linalg.index 0 : index + %7 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 257)>(%5, %6) + %8 = linalg.index 2 : index + %9 = linalg.index 3 : index + %10 = arith.index_cast %7 : index to i32 + %11 = arith.index_cast %8 : index to i32 + %12 = arith.uitofp %10 : i32 to f32 + %13 = arith.mulf %12, %cst : f32 + %14 = arith.addf %13, %cst_0 : f32 + %15 = arith.divf %14, %cst_1 : f32 + %16 = math.floor %15 : f32 + %17 = arith.subf %15, %16 : f32 + %18 = arith.fptosi %16 : f32 to i32 + %19 = arith.uitofp %11 : i32 to f32 + %20 = arith.mulf %19, %cst : f32 + %21 = arith.addf %20, %cst_0 : f32 + %22 = arith.divf %21, %cst_1 : f32 + %23 = math.floor %22 : f32 + %24 = arith.subf %22, %23 : f32 + %25 = arith.fptosi %23 : f32 to i32 + %26 = arith.addi %18, %c1_i32 : i32 + %27 = arith.cmpi slt, %18, %c0_i32 : i32 + %28 = arith.select %27, %c0_i32, %18 : i32 + %29 = arith.cmpi sgt, %18, %c32_i32 : i32 + %30 = arith.select %29, %c32_i32, %28 : i32 + %31 = arith.cmpi slt, %26, %c0_i32 : i32 + %32 = arith.select %31, %c0_i32, %26 : i32 + %33 = arith.cmpi sgt, %26, %c32_i32 : i32 + %34 = arith.select %33, %c32_i32, %32 : i32 + %35 = arith.index_cast %30 : i32 to index + %36 = arith.index_cast %34 : i32 to index + %37 = arith.addi %25, %c1_i32 : i32 + %38 = arith.cmpi slt, %25, %c0_i32 : i32 + %39 = arith.select %38, %c0_i32, %25 : i32 + %40 = arith.cmpi sgt, %25, %c32_i32 : i32 + %41 = arith.select %40, %c32_i32, %39 : i32 + %42 = arith.cmpi slt, %37, %c0_i32 : i32 + %43 = arith.select %42, %c0_i32, %37 : i32 + %44 = arith.cmpi sgt, %37, %c32_i32 : i32 + %45 = arith.select %44, %c32_i32, %43 : i32 + %46 = arith.index_cast %41 : i32 to index + %47 = arith.index_cast %45 : i32 to index + %extracted = tensor.extract %2[%c0, %35, %46, %9] : tensor<1x33x33x21xf32> + %extracted_3 = tensor.extract %2[%c0, %35, %47, %9] : tensor<1x33x33x21xf32> + %extracted_4 
= tensor.extract %2[%c0, %36, %46, %9] : tensor<1x33x33x21xf32> + %extracted_5 = tensor.extract %2[%c0, %36, %47, %9] : tensor<1x33x33x21xf32> + %48 = arith.subf %cst_2, %24 : f32 + %49 = arith.mulf %extracted, %48 : f32 + %50 = arith.mulf %extracted_3, %24 : f32 + %51 = arith.addf %49, %50 : f32 + %52 = arith.mulf %extracted_4, %48 : f32 + %53 = arith.mulf %extracted_5, %24 : f32 + %54 = arith.addf %52, %53 : f32 + %55 = arith.subf %cst_2, %17 : f32 + %56 = arith.mulf %51, %55 : f32 + %57 = arith.mulf %54, %17 : f32 + %58 = arith.addf %56, %57 : f32 + linalg.yield %58 : f32 + } -> tensor<1x257x257x21xf32> + flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [1, 257, 257, 21], strides = [1, 1, 1, 1] : tensor<1x257x257x21xf32> -> !flow.dispatch.tensor> + return } } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index b335b716d560..962238bc4d1c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -124,6 +124,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Common:TransformDialectInterpreterPass", "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses", + "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 6c294a0b7550..75dda889a3b1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -167,6 +167,7 @@ iree_cc_library( MLIRVectorTransforms iree::compiler::Codegen::Common iree::compiler::Codegen::Common::GPU::CommonGPUPasses + iree::compiler::Codegen::Common::GPU::GPUHeuristics iree::compiler::Codegen::Common::TransformDialectInterpreterPass iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 533027132041..c3fc0ab63e8e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -11,22 +11,25 @@ #include #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include 
"mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -286,97 +289,107 @@ getTensorCorePipeline(Type elementType) { static LogicalResult setVectorDistributionConfig(mlir::FunctionOpInterface entryPoint, linalg::LinalgOp op, const TargetInfo &targetInfo) { - if (!linalg::isaContractionOpInterface(op) || op.getNumParallelLoops() < 2) { - return failure(); - } - - FailureOr maybeMmaTypes = getSupportedMmaTypes(entryPoint); - if (failed(maybeMmaTypes)) { + if (!isMatmulOrBatchMatmul(op)) { return failure(); } - // Currently only applies when there is a supported mma operation we can map - // to. - // TODO: Reuse this pipeline for SIMT based flows. - std::optional maybeMmaAttr = - getCompatibleMmaAttr(*maybeMmaTypes, op); - if (!maybeMmaAttr) { + FailureOr mmaKinds = getSupportedMmaTypes(entryPoint); + if (failed(mmaKinds)) { return failure(); } - IREE::GPU::MmaAttr mmaAttr = *maybeMmaAttr; - // This pipeline needs to know the subgroup size to know how to distribute to - // virtual lane ids. + // This pipeline needs to know the subgroup size for distributing to virtual + // lane IDs. if (targetInfo.supportedSubgroupSizes.empty()) { return failure(); } - - int64_t numLoops = op.getNumLoops(); - - SmallVector workgroupTileSizes(numLoops, 0); - - // Arbitrary starter heuristics. - int64_t maxMSize = 128; - int64_t maxNSize = 128; - int64_t maxKSize = 32; + const int64_t subgroupSize = targetInfo.supportedSubgroupSizes.front(); SmallVector bounds = op.getStaticLoopRanges(); FailureOr contractionDims = mlir::linalg::inferContractionDims(op); assert(succeeded(contractionDims) && "Could not infer contraction dims"); + // TODO: Relax this condition to strictly alignment requirements. + if (contractionDims->k.size() != 1 || contractionDims->m.size() != 1 || + contractionDims->n.size() != 1) { + return failure(); + } + int64_t mDim = contractionDims->m[0]; int64_t nDim = contractionDims->n[0]; int64_t kDim = contractionDims->k[0]; - int64_t problemMSize = bounds[mDim]; - int64_t problemNSize = bounds[nDim]; - int64_t problemKSize = bounds[kDim]; + Value lhs = op.getDpsInputOperand(0)->get(); + Value rhs = op.getDpsInputOperand(1)->get(); + Value init = op.getDpsInitOperand(0)->get(); - auto getTileSize = [](int64_t problemSize, int64_t maxSize) { - int64_t tileSize = maxSize; - // The static verification that the linalg op is statically aligned to the - // particular mma type guarantees that this will be at least the minimum - // tile size. - // TODO: Allow unaligned and dynamic cases once masking is supported by - // distribution. 
- while (problemSize % tileSize != 0) - tileSize /= 2; - return tileSize; - }; + Type lhsElemType = getElementTypeOrSelf(lhs); + Type rhsElemType = getElementTypeOrSelf(rhs); + Type initElemType = getElementTypeOrSelf(init); + + GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], + lhsElemType, rhsElemType, initElemType}; - int64_t mTile = getTileSize(problemMSize, maxMSize); - int64_t nTile = getTileSize(problemNSize, maxNSize); - int64_t kTile = getTileSize(problemKSize, maxKSize); + auto mmaAttrs = llvm::to_vector(mmaKinds->getAsRange()); + SmallVector intrinsics; + intrinsics.reserve(mmaKinds->size()); + for (auto mma : mmaAttrs) { + auto [mSize, nSize, kSize] = mma.getMNKShape(); + auto [aType, bType, cType] = mma.getABCElementTypes(); + intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType); + } - // Get the shape of the [warp-synchronous] mma operation. - auto [minMSize, minNSize, minKSize] = mmaAttr.getMNKShape(); + // Note that the following heuristic seeds are just placeholder values. + // We need to clean this up and make the seeds adjust to different targets. + // See https://github.com/openxla/iree/issues/16341 for details. + GPUMMAHeuristicSeeds seeds{/*bestSubgroupCountPerWorkgroup=*/4, + /*bestMNTileCountPerSubgroup=*/8, + /*bestKTileCountPerSubgroup=*/2}; - // HACK: This is a single workgroup... - mTile = std::min(mTile, minMSize); - nTile = std::min(nTile, minNSize); + std::optional schedule = + deduceMMASchedule(problem, intrinsics, seeds); + if (!schedule) { + return failure(); + } - // Following the LLVMGPU convention of keeping all of the tile sizes in one - // list. - workgroupTileSizes[mDim] = mTile; - workgroupTileSizes[nDim] = nTile; - workgroupTileSizes[kDim] = kTile; + std::array workgroupSize{schedule->nWarpCount * subgroupSize, + schedule->mWarpCount, 1}; + + SmallVector workgroupTileSizes(op.getNumLoops(), 0); + // Tile all batch dimensions with unit size. + for (int64_t batch : contractionDims->batch) { + workgroupTileSizes[batch] = 1; + } + // Compute the M/N dimension tile sizes by multiplying the subgroup count, + // tile count, and intrinsic size. + workgroupTileSizes[mDim] = + schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + workgroupTileSizes[nDim] = + schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + + // Follow the LLVMGPU convention of keeping all of the tile sizes in one list. + workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize; TileSizesListType tileSizes; - // HACK: need proper heuristics for workgroup size, but for now the pipeline - // is single subgroup. tileSizes.push_back(workgroupTileSizes); - SmallVector workgroupSize(3, 1); // (X, Y, Z) - int64_t subgroupSize = targetInfo.supportedSubgroupSizes.front(); - workgroupSize[0] = subgroupSize; + // Attach the MMA schedule as an attribute to the entry point export function + // for later access in the pipeline.
+ MLIRContext *context = op.getContext(); + auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( + context, mmaAttrs[schedule->index], schedule->mWarpCount, + schedule->nWarpCount, schedule->mTileCount, schedule->nTileCount, + schedule->kTileCount); + SmallVector attrs; + attrs.emplace_back( + StringAttr::get(context, IREE::GPU::MMAScheduleAttr::getMnemonic()), + scheduleAttr); + auto configDict = DictionaryAttr::get(context, attrs); return setOpConfigAndEntryPointFnTranslation( entryPoint, op, tileSizes, IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute, - workgroupSize, /*subgroupSize=*/subgroupSize); - - return success(); + workgroupSize, subgroupSize, configDict); } static LogicalResult setContractConfig(mlir::FunctionOpInterface entryPoint, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp index 12dc371ea864..cc3701b487c9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp @@ -10,18 +10,20 @@ #include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h" #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" -#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -29,6 +31,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#define DEBUG_TYPE "iree-llvmgpu-vector-distribute" + using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension; using LayoutDimensionAttr = mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr; @@ -49,10 +53,13 @@ namespace { // setting for other problems like reductions is TODO. 
class ContractionVectorLayoutOptions : public VectorLayoutOptions { public: - ContractionVectorLayoutOptions(Operation *root, ArrayAttr types, - ArrayRef workgroupSize, Value laneId) - : VectorLayoutOptions(root), mmaTypes(types), - workgroupSize(workgroupSize), patterns(root->getContext()) { + ContractionVectorLayoutOptions(Operation *root, + ArrayRef workgroupSize, + IREE::GPU::MMAScheduleAttr schedule, + Value laneId, bool printLayout) + : VectorLayoutOptions(root, /*fullConversion=*/!printLayout), + workgroupSize(workgroupSize), schedule(schedule), + printLayout(printLayout), patterns(root->getContext()) { populateGPUDistributionPatterns(patterns); populateGPUDistributionLayoutAttrPatterns(laneId, patterns); populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns); @@ -80,24 +87,33 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { void setContractionAnchor(MLIRContext *context, VectorLayoutAnalysis &analysis, vector::ContractionOp contract) { - std::optional maybeMmaType = - getCompatibleMmaAttr(mmaTypes, contract); // TODO: Add SIMT fallback. - assert(maybeMmaType && "incompatible contraction op"); + assert(schedule && "incompatible contraction op"); - auto mmaType = *maybeMmaType; - auto maybeLayouts = mmaType.getContractionLayout(contract); - assert(maybeMmaType && "mma layout type must not be opaque"); + auto layouts = schedule.getContractionLayout(contract); + assert(layouts && "cannot get concrete layout for contraction"); - auto [aLayout, bLayout, cLayout] = *maybeLayouts; + auto [aLayout, bLayout, cLayout] = *layouts; analysis.setAnchor(contract.getLhs(), aLayout); analysis.setAnchor(contract.getRhs(), bLayout); analysis.setAnchor(contract.getAcc(), cLayout); analysis.setAnchor(contract.getResult(), cLayout); + contract->setAttr("iree.amdgpu.mfma", schedule.getIntrinsic()); + if (printLayout) { + llvm::outs() << "contract A vector layout: " << aLayout << "\n"; + llvm::outs() << "contract B vector layout: " << bLayout << "\n"; + llvm::outs() << "contract C vector layout: " << cLayout << "\n"; + } + LLVM_DEBUG({ + llvm::dbgs() << "chosen a layout: " << aLayout << "\n"; + llvm::dbgs() << "chosen b layout: " << bLayout << "\n"; + llvm::dbgs() << "chosen c layout: " << cLayout << "\n"; + llvm::dbgs() << "anchor set on contract: " << contract << "\n"; + }); - if (isa(mmaType)) { + if (isa(schedule.getIntrinsic())) { if (!populatedMfma) { - populateAMDGPUDistributionPatterns(patterns); + populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); populatedMfma = true; } } else { @@ -134,6 +150,23 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { void setTransferReadAnchor(MLIRContext *context, VectorLayoutAnalysis &analysis, vector::TransferReadOp transfer) { + + // Get the forward slice of the transfer to approximate whether it will take + // the layout of a contraction instead. Transfer_read ops used directly by a + // contraction (i.e. without a copy to shared memory in between) should take + // the layout of the contraction op. This is common for cases where the + // initial values of the accumulator in a linalg.matmul is read from memory + // instead of just being a zerofill. + SetVector forwardSlice; + ForwardSliceOptions options; + getForwardSlice(transfer.getResult(), &forwardSlice, options); + + if (llvm::any_of(forwardSlice, [](Operation *op) { + return llvm::isa(op); + })) { + return; + } + // TODO: Support masking. 
if (transfer.getMask()) { return; @@ -254,10 +287,16 @@ class ContractionVectorLayoutOptions : public VectorLayoutOptions { context, subgroupCounts, order, batchSizes, order, outerSizes, order, threadCounts, order, elementSizes, order, subgroupBasis, threadBasis); analysis.setAnchor(transfer.getResult(), layout); + if (printLayout) { + llvm::outs() << "transfer '" << transfer << "' vector layout: " << layout + << "\n"; + } } - ArrayAttr mmaTypes; SmallVector workgroupSize; + IREE::GPU::MMAScheduleAttr schedule; + // Whether to print the chosen layout for testing purposes + bool printLayout; bool populatedMfma = false; RewritePatternSet patterns; @@ -276,38 +315,8 @@ struct LLVMGPUVectorDistributePass void runOnOperation() override { auto func = getOperation(); - FailureOr maybeSupportedTypes = - getSupportedMmaTypes(llvm::cast(func)); - // TODO: Support FMA fallback. Contractions always benefit from an anchoring - // layout because they do implicit shuffles, or broadcast when loading data. - if (failed(maybeSupportedTypes)) { - func->emitError() << "Failed to collect the set of supported mma types " - "for vector distribution"; - return signalPassFailure(); - } - - std::optional maybeSubgroupSize = std::nullopt; - if (func->hasAttr("subgroup_size")) { - maybeSubgroupSize = - llvm::cast(func->getAttr("subgroup_size")).getInt(); - } else { - maybeSubgroupSize = getSubgroupSize(func); - } - - if (!maybeSubgroupSize) { - func.emitError() << "subgroup size required for vector distribution"; - return signalPassFailure(); - } - - OpBuilder builder(func); - builder.setInsertionPointToStart(&func.getFunctionBody().front()); - SmallVector threadGrid = { - builder.createOrFold(func.getLoc(), gpu::Dimension::x), - builder.createOrFold(func.getLoc(), gpu::Dimension::y), - builder.createOrFold(func.getLoc(), - gpu::Dimension::z)}; std::array workgroupSize; - if (func->hasAttr("subgroup_size")) { + if (func->hasAttr("workgroup_size")) { auto tmpSizes = llvm::cast(func->getAttr("workgroup_size")).getValue(); for (auto [i, size] : llvm::enumerate(tmpSizes)) { @@ -316,43 +325,40 @@ struct LLVMGPUVectorDistributePass } else { workgroupSize = getWorkgroupSize(func); } + + llvm::StringLiteral scheduleAttrName = + IREE::GPU::MMAScheduleAttr::getMnemonic(); + auto scheduleAttr = + func->getAttrOfType(scheduleAttrName); + if (!scheduleAttr) { + DictionaryAttr configDict = getTranslationInfo(func).getConfiguration(); + scheduleAttr = dyn_cast_or_null( + configDict.get(scheduleAttrName)); + } + AffineExpr x, y, z; bindSymbols(func.getContext(), x, y, z); // Construct the expression for linearizing the thread indices. AffineExpr linearId = x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - AffineExpr laneId = linearId % *maybeSubgroupSize; - - // This all needs some kind of simplification; the arithmetic it produces - // doest not get folded away as nicely as it could. - AffineMap idMap = AffineMap::getMultiDimIdentityMap(2, func.getContext()); - - // Clamp the thread indices to the workgroup sizes. 
- OpFoldResult c0 = - builder.createOrFold(func.getLoc(), 0); - threadGrid[0] = affine::makeComposedFoldedAffineMax( - builder, func.getLoc(), idMap, {threadGrid[0], c0}); - threadGrid[1] = affine::makeComposedFoldedAffineMax( - builder, func.getLoc(), idMap, {threadGrid[1], c0}); - threadGrid[2] = affine::makeComposedFoldedAffineMax( - builder, func.getLoc(), idMap, {threadGrid[2], c0}); - - OpFoldResult dimX = builder.getIndexAttr(workgroupSize[0] - 1); - OpFoldResult dimY = builder.getIndexAttr(workgroupSize[1] - 1); - OpFoldResult dimZ = builder.getIndexAttr(workgroupSize[2] - 1); - threadGrid[0] = affine::makeComposedFoldedAffineMin( - builder, func.getLoc(), idMap, {threadGrid[0], dimX}); - threadGrid[1] = affine::makeComposedFoldedAffineMin( - builder, func.getLoc(), idMap, {threadGrid[1], dimY}); - threadGrid[2] = affine::makeComposedFoldedAffineMin( - builder, func.getLoc(), idMap, {threadGrid[2], dimZ}); - Value laneVal = affine::makeComposedAffineApply(builder, func.getLoc(), - laneId, threadGrid); - - ContractionVectorLayoutOptions options(func, *maybeSupportedTypes, - workgroupSize, laneVal); - // TODO: This should return failure when distribution fails for any op. - distributeVectorOps(func, options.getPatterns(), options); + + OpBuilder builder(func); + builder.setInsertionPointToStart(&func.getFunctionBody().front()); + SmallVector threadGrid = { + builder.createOrFold(func.getLoc(), gpu::Dimension::x), + builder.createOrFold(func.getLoc(), gpu::Dimension::y), + builder.createOrFold(func.getLoc(), + gpu::Dimension::z)}; + + Value linearThreadIdVal = affine::makeComposedAffineApply( + builder, func.getLoc(), linearId, threadGrid); + + ContractionVectorLayoutOptions options(func, workgroupSize, scheduleAttr, + linearThreadIdVal, testLayout); + if (failed(distributeVectorOps(func, options.getPatterns(), options))) { + func->emitOpError() << "failed to distribute"; + return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 70e465c2ec01..d38027c4285c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -134,7 +134,6 @@ static void addGPUVectorizationPasses(OpPassManager &pm) { options.vectorizeGatherAccesses = true; options.enableCleanup = false; options.foldCastIntoContract = true; - options.maxVectorSize = 4096; pm.addNestedPass(createGenericVectorizationPass(options)); pm.addNestedPass(createOptimizeTensorInsertExtractSlicesPass()); pm.addNestedPass(createCanonicalizerPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index 2b3dc7cf4362..dc6b452755a2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -93,6 +93,11 @@ def LLVMGPUVectorDistribute : InterfacePass<"iree-llvmgpu-vector-distribute", "mlir::FunctionOpInterface"> { let summary = "Pass to distribute vectorized functions."; let constructor = "mlir::iree_compiler::createLLVMGPUVectorDistribute()"; + let options = [ + Option<"testLayout", "test-layout", "bool", /*default=*/"false", + "Annotate vector ops with deduced layouts without real conversion " + "for testing purposes"> + ]; } def LLVMGPUVectorLowering : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp 
b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 63f9086e4a03..22ec40e2d594 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -1501,9 +1501,10 @@ void transform_dialect::PackSharedMemoryAllocOp::getEffects( transform::modifiesPayload(effects); } -class TestVectorLayoutOptions : public VectorLayoutOptions { +class TransformVectorLayoutOptions : public VectorLayoutOptions { public: - TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {} + TransformVectorLayoutOptions(Operation *root, bool fullConversion) + : VectorLayoutOptions(root, fullConversion) {} void setAnchorOps(VectorLayoutAnalysis &analysis) override { setAnchorOpsFromAttributes(analysis, root); @@ -1515,7 +1516,7 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne( transform::TransformRewriter &rewriter, mlir::FunctionOpInterface target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - TestVectorLayoutOptions options(target); + TransformVectorLayoutOptions options(target, !getTestConversion()); RewritePatternSet patterns(target.getContext()); rewriter.setInsertionPointToStart(&target.getFunctionBody().front()); @@ -1527,7 +1528,9 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne( populateGPUReductionDistributionPatterns(patterns); populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns); populateAMDGPUDistributionPatterns(patterns); - distributeVectorOps(target, patterns, options); + if (failed(distributeVectorOps(target, patterns, options))) { + return emitDefaultSilenceableFailure(target); + } return DiagnosedSilenceableFailure::success(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 801f1f8ceea3..77280e452620 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -706,10 +706,14 @@ def AMDGPUDistributeVectorsOp : This transform does not consume the target handle and always return success. }]; - let arguments = (ins TransformHandleTypeInterface:$target); + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$test_conversion); let results = (outs); - let assemblyFormat = [{ $target attr-dict `:` type($target)}]; + let assemblyFormat = [{ + $target (`test_conversion` $test_conversion^)? 
+ attr-dict `:` type($target) + }]; let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; let extraClassDeclaration = [{ diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index a381431ff159..7f7aba6fd49a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -65,8 +65,8 @@ iree_lit_test_suite( "transform_vector_to_mma.mlir", "transpose_pipeline_test.mlir", "ukernel_pipeline_transform.mlir", - "vector_distribute.mlir", - "vector_distribution_pipeline_rocm.mlir", + "vector_distribute_conversion.mlir", + "vector_distribute_layout.mlir", "vector_lowering.mlir", "vector_to_gpu.mlir", "workgroup_specialization_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index f50179cef57b..6ecee2e0293f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -61,8 +61,8 @@ iree_lit_test_suite( "transform_vector_to_mma.mlir" "transpose_pipeline_test.mlir" "ukernel_pipeline_transform.mlir" - "vector_distribute.mlir" - "vector_distribution_pipeline_rocm.mlir" + "vector_distribute_conversion.mlir" + "vector_distribute_layout.mlir" "vector_lowering.mlir" "vector_to_gpu.mlir" "workgroup_specialization_pipeline_test.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel index dc64ddbfca3d..ece13b17239f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel @@ -18,6 +18,8 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "config_vector_distribute.mlir", + "pipeline_vector_distribute.mlir", "pipeline_warp_reduction.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt index d56c2598f975..8861e4200a58 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt @@ -14,6 +14,8 @@ iree_lit_test_suite( NAME lit SRCS + "config_vector_distribute.mlir" + "pipeline_vector_distribute.mlir" "pipeline_warp_reduction.mlir" TOOLS FileCheck diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribution_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir similarity index 63% rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribution_pipeline_rocm.mlir rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir index ed4469098176..4f739e19c9af 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribution_pipeline_rocm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir @@ -1,5 +1,9 @@ // RUN: iree-opt --split-input-file --iree-codegen-llvmgpu-use-vector-distribution \ -// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy, iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s +// RUN: 
--pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy)))" %s | FileCheck %s + +// TODO: This test is still using the legacy LLVMGPU kernel config. This needs +// to be migrated to the rocdl heuristics, but for now is just physically +// located here. #pipeline_layout = #hal.pipeline.layout, - #iree_gpu.mfma_layout] + mma_intrinsics = [] }>) { hal.executable.export @matmul_256x256x256 layout(#pipeline_layout) { ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): @@ -20,8 +23,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", { } builtin.module { func.func @matmul_256x256x256() { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> @@ -37,15 +40,6 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", { } } -// Basic pipeline test to make sure it generates the instructions we expect. - -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info -// CHECK-LABEL: hal.executable.export public @matmul_256x256x256 -// CHECK-SAME: subgroup_size = 64 -// CHECK-SAME: translation_info = #[[$TRANSLATION]] -// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index] -// CHECK-LABEL: func.func @matmul_256x256x256 -// CHECK-COUNT: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}}) -> (vector<1x1x4xf32>) -// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> -// CHECK: scf.yield %{{.*}} : vector<1x1x4xf32> -// CHECK-COUNT-4: vector.store {{.*}} : memref<256x256xf32, #hal.descriptor_type>, vector<1xf32> +// Check that we do not use the distribute pipeline if there are no supported +// intrinsics. 
+// CHECK-NOT: iree_codegen.translation_info, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> +hal.executable @matmul_256x256x256 { +hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", { + target_arch = "gfx940", + mma_intrinsics = [#iree_gpu.mfma_layout, + #iree_gpu.mfma_layout] + }>) { + hal.executable.export @matmul_256x256x256 layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @matmul_256x256x256() { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> + %5 = tensor.empty() : tensor<256x256xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> + %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> + return + } + } +} +} + +// Basic pipeline test to make sure it generates the instructions we expect. + +// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, +// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 4, subgroup_k_tile_count = 2> + +// CHECK-LABEL: hal.executable.export public @matmul_256x256x256 +// CHECK-SAME: subgroup_size = 64 +// CHECK-SAME: translation_info = #[[$TRANSLATION]] +// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index] + +// CHECK-LABEL: func.func @matmul_256x256x256 +// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}}) -> (vector<2x4x1x1x1x4xf32>) +// Each subgroup handles 2 * 4 tiles, and for each tile we accumulate 2 times +// along the K dimension. So in total 16 mfma ops. 
+// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> +// CHECK: scf.yield %{{.+}} : vector<2x4x1x1x1x4xf32> +// CHECK-COUNT-8: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<256x256xf32, #hal.descriptor_type> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir index a3f5d7072007..8710f7696a57 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_contraction_distribution.mlir @@ -43,7 +43,7 @@ builtin.module attributes { transform.with_named_sequence } { } transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op + transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op transform.yield } } @@ -84,7 +84,7 @@ builtin.module attributes { transform.with_named_sequence } { } transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op + transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op transform.yield } } @@ -132,7 +132,7 @@ builtin.module attributes { transform.with_named_sequence } { } transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op + transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op transform.yield } } @@ -179,7 +179,7 @@ builtin.module attributes { transform.with_named_sequence } { } transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op + transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op transform.yield } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir index 0eaefd1a3738..717d2facabad 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma_transform_spec.mlir @@ -158,7 +158,7 @@ module attributes { transform.with_named_sequence } { transform.iree.set_contraction_layout_attributes %contracts, %layout16x16x16 : !transform.any_op, !transform.any_param %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op + 
transform.iree.amdgpu_distribute_vectors %distribute_func test_conversion : !transform.any_op transform.apply_patterns to %distribute_func { transform.apply_patterns.canonicalization diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir index bc4bc935325f..554e9fbf7770 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir @@ -2,32 +2,29 @@ // RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy, iree-llvmgpu-lower-executable-target,canonicalize)))' \ // RUN: %s | FileCheck %s -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable private @conv2d_1x230x230x3_7x7x3x64_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @conv2d_1x230x230x3_7x7x3x64 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @conv2d_1x230x230x3_7x7x3x64() { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 230, 230, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x230x230x3xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [7, 7, 3, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<7x7x3x64xf32> - %5 = tensor.empty() : tensor<1x112x112x64xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> - %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%3, %4 : tensor<1x230x230x3xf32>, tensor<7x7x3x64xf32>) outs(%6 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> - flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 64], strides = [1, 1, 1, 1] : tensor<1x112x112x64xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @conv2d_1x230x230x3_7x7x3x64_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @conv2d_1x230x230x3_7x7x3x64 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, 
%arg2, %arg3, %arg4, %arg5, %arg6, %arg7 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @conv2d_1x230x230x3_7x7x3x64() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 230, 230, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x230x230x3xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [7, 7, 3, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<7x7x3x64xf32> + %5 = tensor.empty() : tensor<1x112x112x64xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> + %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%3, %4 : tensor<1x230x230x3xf32>, tensor<7x7x3x64xf32>) outs(%6 : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [1, 112, 112, 64], strides = [1, 1, 1, 1] : tensor<1x112x112x64xf32> -> !flow.dispatch.tensor> + return } } } @@ -46,34 +43,31 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}>], legacy_sync}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable private @conv_nchw_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @conv_nchw ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @conv_nchw() { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 230, 230, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x4x66x66xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [7, 7, 3, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<320x4x3x3xf32> - %5 = tensor.empty() : tensor<2x320x64x64xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x320x64x64xf32>) -> tensor<2x320x64x64xf32> - %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, 
strides = dense<1> : vector<2xi64>} - ins(%3, %4 : tensor<2x4x66x66xf32>, tensor<320x4x3x3xf32>) - outs(%6 : tensor<2x320x64x64xf32>) -> tensor<2x320x64x64xf32> - flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 320, 64, 64], strides = [1, 1, 1, 1] : tensor<2x320x64x64xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @conv_nchw_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @conv_nchw ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @conv_nchw() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 230, 230, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x4x66x66xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [7, 7, 3, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<320x4x3x3xf32> + %5 = tensor.empty() : tensor<2x320x64x64xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x320x64x64xf32>) -> tensor<2x320x64x64xf32> + %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} + ins(%3, %4 : tensor<2x4x66x66xf32>, tensor<320x4x3x3xf32>) + outs(%6 : tensor<2x320x64x64xf32>) -> tensor<2x320x64x64xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 320, 64, 64], strides = [1, 1, 1, 1] : tensor<2x320x64x64xf32> -> !flow.dispatch.tensor> + return } } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir index 2a59b3594bd6..17f3a6fe2781 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir @@ -8,77 +8,74 @@ // RUN: --iree-codegen-transform-dialect-library=%p/transform_dialect_codegen_foreach_to_gpu_spec.mlir@__transform_main | \ // RUN: FileCheck %s --check-prefix=FOREACH-TO-GPU -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}>]}> #pipeline_layout = #hal.pipeline.layout]>]> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable private @matmul_static_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @matmul_static_dispatch_0 ordinal(0) layout(#pipeline_layout){ - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 - hal.return %x, %y, %z : index, index, index - } 
- builtin.module { - func.func @matmul_static_dispatch_0() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<250x500xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<500x1020xf32> +hal.executable private @matmul_static_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @matmul_static_dispatch_0 ordinal(0) layout(#pipeline_layout){ + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @matmul_static_dispatch_0() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<250x500xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<500x1020xf32> - %50 = tensor.empty() : tensor<250x1020xf32> - %cst = arith.constant 0.000000e+00 : f32 - %5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32> + %50 = tensor.empty() : tensor<250x1020xf32> + %cst = arith.constant 0.000000e+00 : f32 + %5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32> - // CHECK: memref.assume_alignment %{{.*}}, 64 : memref<250x1020xf32, #hal.descriptor_type> - // CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type>) - // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32, #hal.descriptor_type>, memref<500x1020xf32, #hal.descriptor_type>) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type>) - // CHECK-NEXT: return + // CHECK: memref.assume_alignment %{{.*}}, 64 : memref<250x1020xf32, #hal.descriptor_type> + // CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type>) + // CHECK-NEXT: linalg.matmul{{.*}}ins(%{{.*}} : memref<250x500xf32, #hal.descriptor_type>, memref<500x1020xf32, #hal.descriptor_type>) outs(%{{.*}} : memref<250x1020xf32, #hal.descriptor_type>) + // CHECK-NEXT: return - // workgroup_size is explicitly set to [10, 11]. 
- // FOREACH-TO-GPU-DAG: hal.executable.export {{.*}}{subgroup_size = 32 : index, translation_info = #translation, workgroup_size = [10 : index, 11 : index, 1 : index]} - // FOREACH-TO-GPU-DAG: %[[C0:.*]] = arith.constant 0 : index - // FOREACH-TO-GPU-DAG: %[[C1:.*]] = arith.constant 1 : index - // FOREACH-TO-GPU-DAG: %[[C5:.*]] = arith.constant 5 : index - // FOREACH-TO-GPU-DAG: %[[C7:.*]] = arith.constant 7 : index - // FOREACH-TO-GPU-DAG: %[[C9:.*]] = arith.constant 9 : index - // FOREACH-TO-GPU-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32 - // FOREACH-TO-GPU: %[[TIDX:.*]] = gpu.thread_id x - // FOREACH-TO-GPU: %[[TIDY:.*]] = gpu.thread_id y - // - // Fill is tiled by 5x1 with thread_dim_mapping = [1, 0, 2], predicate appropriately. - // FOREACH-TO-GPU: %[[LT1:.*]] = arith.cmpi ult, %[[TIDX]], %[[C1]] : index - // FOREACH-TO-GPU: %[[LT5:.*]] = arith.cmpi ult, %[[TIDY]], %[[C5]] : index - // FOREACH-TO-GPU: %[[COND:.*]] = arith.andi %[[LT1]], %[[LT5]] : i1 - // FOREACH-TO-GPU: scf.if %[[COND]] { - // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDY]]] - // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDX]]] - // FOREACH-TO-GPU: linalg.fill - // FOREACH-TO-GPU: } - // FOREACH-TO-GPU: gpu.barrier - // - // Matmul is tiled by 7x9 with identity (omitted) thread_dim_mapping, predicate appropriately. - // FOREACH-TO-GPU: %[[LT7:.*]] = arith.cmpi ult, %[[TIDX]], %[[C7]] : index - // FOREACH-TO-GPU: %[[LT9:.*]] = arith.cmpi ult, %[[TIDY]], %[[C9]] : index - // FOREACH-TO-GPU: %[[COND2:.*]] = arith.andi %[[LT7]], %[[LT9]] : i1 - // FOREACH-TO-GPU: scf.if %[[COND2]] { - // FOREACH-TO-GPU: affine.min #{{.*}}()[%[[TIDX]]] - // FOREACH-TO-GPU: affine.min #{{.*}}()[%[[TIDY]]] - // FOREACH-TO-GPU-DAG: affine.apply #{{.*}}()[%[[TIDX]]] - // FOREACH-TO-GPU-DAG: %[[svA:.*]] = memref.subview {{.*}} : memref<250x500xf32{{.*}}> to memref to memref<500x?xf32 - // FOREACH-TO-GPU-DAG: %[[svC:.*]] = memref.subview {{.*}} : memref<250x1020xf32{{.*}}> to memref, memref<500x?xf32{{.*}}>) outs(%[[svC]] : memref) - // FOREACH-TO-GPU: } - // FOREACH-TO-GPU: gpu.barrier - // - %6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor> - return - } + // workgroup_size is explicitly set to [10, 11]. + // FOREACH-TO-GPU-DAG: hal.executable.export {{.*}}{subgroup_size = 32 : index, translation_info = #translation, workgroup_size = [10 : index, 11 : index, 1 : index]} + // FOREACH-TO-GPU-DAG: %[[C0:.*]] = arith.constant 0 : index + // FOREACH-TO-GPU-DAG: %[[C1:.*]] = arith.constant 1 : index + // FOREACH-TO-GPU-DAG: %[[C5:.*]] = arith.constant 5 : index + // FOREACH-TO-GPU-DAG: %[[C7:.*]] = arith.constant 7 : index + // FOREACH-TO-GPU-DAG: %[[C9:.*]] = arith.constant 9 : index + // FOREACH-TO-GPU-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32 + // FOREACH-TO-GPU: %[[TIDX:.*]] = gpu.thread_id x + // FOREACH-TO-GPU: %[[TIDY:.*]] = gpu.thread_id y + // + // Fill is tiled by 5x1 with thread_dim_mapping = [1, 0, 2], predicate appropriately. 
+ // FOREACH-TO-GPU: %[[LT1:.*]] = arith.cmpi ult, %[[TIDX]], %[[C1]] : index + // FOREACH-TO-GPU: %[[LT5:.*]] = arith.cmpi ult, %[[TIDY]], %[[C5]] : index + // FOREACH-TO-GPU: %[[COND:.*]] = arith.andi %[[LT1]], %[[LT5]] : i1 + // FOREACH-TO-GPU: scf.if %[[COND]] { + // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDY]]] + // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDX]]] + // FOREACH-TO-GPU: linalg.fill + // FOREACH-TO-GPU: } + // FOREACH-TO-GPU: gpu.barrier + // + // Matmul is tiled by 7x9 with identity (omitted) thread_dim_mapping, predicate appropriately. + // FOREACH-TO-GPU: %[[LT7:.*]] = arith.cmpi ult, %[[TIDX]], %[[C7]] : index + // FOREACH-TO-GPU: %[[LT9:.*]] = arith.cmpi ult, %[[TIDY]], %[[C9]] : index + // FOREACH-TO-GPU: %[[COND2:.*]] = arith.andi %[[LT7]], %[[LT9]] : i1 + // FOREACH-TO-GPU: scf.if %[[COND2]] { + // FOREACH-TO-GPU: affine.min #{{.*}}()[%[[TIDX]]] + // FOREACH-TO-GPU: affine.min #{{.*}}()[%[[TIDY]]] + // FOREACH-TO-GPU-DAG: affine.apply #{{.*}}()[%[[TIDX]]] + // FOREACH-TO-GPU-DAG: %[[svA:.*]] = memref.subview {{.*}} : memref<250x500xf32{{.*}}> to memref to memref<500x?xf32 + // FOREACH-TO-GPU-DAG: %[[svC:.*]] = memref.subview {{.*}} : memref<250x1020xf32{{.*}}> to memref, memref<500x?xf32{{.*}}>) outs(%[[svC]] : memref) + // FOREACH-TO-GPU: } + // FOREACH-TO-GPU: gpu.barrier + // + %6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor> + return } } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir index fb45c206925c..0dd78f5ac9ae 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir @@ -74,7 +74,6 @@ // CHECK: %[[IV_NEXT:.*]] = llvm.mul %[[IV]], %[[C8192]] : i64 #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}> hal.executable private @matmul_dispatch_0 { hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { hal.executable.export public @matmul_dispatch_0_matmul_2560x2560x2560 ordinal(0) layout(#pipeline_layout) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir index d10e4080962b..b16675f1121c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir @@ -402,7 +402,6 @@ hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb", {t #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer>]>]> -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}> hal.executable @reduction_2d_trailing_elementwise_static_dispatch_0 { hal.executable.variant 
public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir index 31bb131bf0d2..136285c96199 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir @@ -19,41 +19,37 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable private @batch_matmul_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @batch_matmul_dispatch_0_generic_128x80x320x32_f32 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device): - %x, %y, %z = flow.dispatch.workgroup_count_from_slice - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @batch_matmul_dispatch_0_generic_128x80x320x32_f32() { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [128, 80, 32], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<128x80x32xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [128, 32, 320], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<128x32x320xf32> - %5 = tensor.empty() : tensor<128x80x320xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> - %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<128x80x32xf32>, tensor<128x32x320xf32>) outs(%6 : tensor<128x80x320xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %8 = arith.mulf %in, %in_0 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<128x80x320xf32> - flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [128, 80, 320], strides = [1, 1, 1] : tensor<128x80x320xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @batch_matmul_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @batch_matmul_dispatch_0_generic_128x80x320x32_f32 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @batch_matmul_dispatch_0_generic_128x80x320x32_f32() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : 
!flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [128, 80, 32], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<128x80x32xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [128, 32, 320], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<128x32x320xf32> + %5 = tensor.empty() : tensor<128x80x320xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> + %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<128x80x32xf32>, tensor<128x32x320xf32>) outs(%6 : tensor<128x80x320xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %8 = arith.mulf %in, %in_0 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<128x80x320xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [128, 80, 320], strides = [1, 1, 1] : tensor<128x80x320xf32> -> !flow.dispatch.tensor> + return } } } } - // CHECK: transform.named_sequence // CHECK: transform.iree.register_match_callbacks // CHECK: %[[MATCH:.+]]:2 = transform.iree.match_callback failures(propagate) "batch_matmul" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir index b4d2087b9845..a856da6365fc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir @@ -1,30 +1,27 @@ // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy, iree-llvmgpu-lower-executable-target, fold-memref-alias-ops, canonicalize, cse)))" %s | FileCheck %s -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_dispatch_0_generic_4096x4096 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @transpose_dispatch_0_generic_4096x4096() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x4096xf32> - %3 = tensor.empty() : tensor<4096x4096xf32> - %4 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>, 
affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<4096x4096xf32>) outs(%3 : tensor<4096x4096xf32>) { - ^bb0(%arg0: f32, %arg1: f32): - linalg.yield %arg0 : f32 - } -> tensor<4096x4096xf32> - flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : tensor<4096x4096xf32> -> !flow.dispatch.tensor> - return - } +hal.executable @transpose_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_dispatch_0_generic_4096x4096 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_dispatch_0_generic_4096x4096() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<4096x4096xf32> + %3 = tensor.empty() : tensor<4096x4096xf32> + %4 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<4096x4096xf32>) outs(%3 : tensor<4096x4096xf32>) { + ^bb0(%arg0: f32, %arg1: f32): + linalg.yield %arg0 : f32 + } -> tensor<4096x4096xf32> + flow.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : tensor<4096x4096xf32> -> !flow.dispatch.tensor> + return } } } @@ -59,34 +56,31 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_single_operand_dispatch_0_generic_768x2048 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @transpose_single_operand_dispatch_0_generic_768x2048() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 768], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x768xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : !flow.dispatch.tensor> -> 
tensor<768x2048xf32> - %5 = tensor.empty() : tensor<768x2048xf32> - %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<2048x768xf32>, tensor<768x2048xf32>) outs(%5 : tensor<768x2048xf32>) { - ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): - %7 = arith.addf %arg0, %arg1 : f32 - linalg.yield %7 : f32 - } -> tensor<768x2048xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : tensor<768x2048xf32> -> !flow.dispatch.tensor> - return - } +hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_single_operand_dispatch_0_generic_768x2048 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_single_operand_dispatch_0_generic_768x2048() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 768], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x768xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<768x2048xf32> + %5 = tensor.empty() : tensor<768x2048xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<2048x768xf32>, tensor<768x2048xf32>) outs(%5 : tensor<768x2048xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<768x2048xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [768, 2048], strides = [1, 1] : tensor<768x2048xf32> -> !flow.dispatch.tensor> + return } } } @@ -124,34 +118,31 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_3d_no_dispatch_0_generic_768x2048x1024 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_3d_no_dispatch_0_generic_768x2048x1024 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func 
@transpose_3d_no_dispatch_0_generic_768x2048x1024() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2048, 768, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2048x768x1024xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<768x2048x1024xf32> - %5 = tensor.empty() : tensor<768x2048x1024xf32> - %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<2048x768x1024xf32>, tensor<768x2048x1024xf32>) outs(%5 : tensor<768x2048x1024xf32>) { - ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): - %7 = arith.addf %arg0, %arg1 : f32 - linalg.yield %7 : f32 - } -> tensor<768x2048x1024xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : tensor<768x2048x1024xf32> -> !flow.dispatch.tensor> - return - } +hal.executable @transpose_3d_no_dispatch_0_generic_768x2048x1024 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_3d_no_dispatch_0_generic_768x2048x1024 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_no_dispatch_0_generic_768x2048x1024() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2048, 768, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2048x768x1024xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<768x2048x1024xf32> + %5 = tensor.empty() : tensor<768x2048x1024xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<2048x768x1024xf32>, tensor<768x2048x1024xf32>) outs(%5 : tensor<768x2048x1024xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<768x2048x1024xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [768, 2048, 1024], strides = [1, 1, 1] : tensor<768x2048x1024xf32> -> !flow.dispatch.tensor> + return } } } @@ -164,34 +155,31 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- 
-#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_3d_yes_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @transpose_3d_yes_dispatch_0_generic_10x768x2048() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x2048x768xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> - %5 = tensor.empty() : tensor<10x768x2048xf32> - %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x2048x768xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x768x2048xf32>) { - ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): - %7 = arith.addf %arg0, %arg1 : f32 - linalg.yield %7 : f32 - } -> tensor<10x768x2048xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : tensor<10x768x2048xf32> -> !flow.dispatch.tensor> - return - } +hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_3d_yes_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_yes_dispatch_0_generic_10x768x2048() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x2048x768xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 
0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> + %5 = tensor.empty() : tensor<10x768x2048xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x2048x768xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x768x2048xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<10x768x2048xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : tensor<10x768x2048xf32> -> !flow.dispatch.tensor> + return } } } @@ -229,34 +217,31 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @transpose_3d_trans_out_dispatch_0_generic_10x2048x768() { - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> - %5 = tensor.empty() : tensor<10x2048x768xf32> - %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x768x2048xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x2048x768xf32>) { - ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): - %7 = arith.addf %arg0, %arg1 : f32 - linalg.yield %7 : f32 - } -> tensor<10x2048x768xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : tensor<10x2048x768xf32> -> !flow.dispatch.tensor> - return - } +hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 ordinal(0) layout(#pipeline_layout) 
{ + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_trans_out_dispatch_0_generic_10x2048x768() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [10, 768, 2048], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x768x2048xf32> + %5 = tensor.empty() : tensor<10x2048x768xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %4 : tensor<10x768x2048xf32>, tensor<10x768x2048xf32>) outs(%5 : tensor<10x2048x768xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %7 = arith.addf %arg0, %arg1 : f32 + linalg.yield %7 : f32 + } -> tensor<10x2048x768xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [10, 2048, 768], strides = [1, 1, 1] : tensor<10x2048x768xf32> -> !flow.dispatch.tensor> + return } } } @@ -297,55 +282,52 @@ module attributes {hal.device.targets = [#device_target_cuda]} { // ----- -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>]}> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable @transpose_3d_diff_dispatch_0_generic_10x768x2048 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @transpose_3d_diff_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 - hal.return %x, %y, %z : index, index, index +hal.executable @transpose_3d_diff_dispatch_0_generic_10x768x2048 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @transpose_3d_diff_dispatch_0_generic_10x768x2048 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @transpose_3d_diff_dispatch_0_generic_10x768x2048() { + %c256 = arith.constant 256 : index + %c10 = arith.constant 10 : index + %c768 = arith.constant 768 : index + %c2048 = arith.constant 2048 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = 
hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_count_x = hal.interface.workgroup.count[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + scf.for %arg0 = %workgroup_id_z to %c10 step %workgroup_count_z { + scf.for %arg1 = %workgroup_id_y to %c768 step %workgroup_count_y { + %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x] + %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x] + scf.for %arg2 = %3 to %c2048 step %4 { + %5 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg2, %arg1], sizes = [1, %c256, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x?x1xf32> + %6 = flow.dispatch.tensor.load %1, offsets = [%arg2, %arg1, %arg0], sizes = [%c256, 1, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor + %7 = tensor.empty() : tensor<1x1x256xf32> + %8 = tensor.cast %5 : tensor<1x?x1xf32> to tensor<1x256x1xf32> + %9 = tensor.cast %6 : tensor to tensor<256x1x1xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d2, d1, d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %9 : tensor<1x256x1xf32>, tensor<256x1x1xf32>) outs(%7 : tensor<1x1x256xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %12 = arith.addf %arg3, %arg4 : f32 + linalg.yield %12 : f32 + } -> tensor<1x1x256xf32> + %11 = tensor.cast %10 : tensor<1x1x256xf32> to tensor<1x1x?xf32> + flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 1, %c256], strides = [1, 1, 1] : tensor<1x1x?xf32> -> !flow.dispatch.tensor> + } + } } - builtin.module { - func.func @transpose_3d_diff_dispatch_0_generic_10x768x2048() { - %c256 = arith.constant 256 : index - %c10 = arith.constant 10 : index - %c768 = arith.constant 768 : index - %c2048 = arith.constant 2048 : index - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %workgroup_count_x = hal.interface.workgroup.count[0] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %workgroup_count_y = hal.interface.workgroup.count[1] : index - %workgroup_id_z = hal.interface.workgroup.id[2] : index - %workgroup_count_z = hal.interface.workgroup.count[2] : index - scf.for %arg0 = %workgroup_id_z to %c10 step %workgroup_count_z { - scf.for %arg1 = %workgroup_id_y to %c768 step %workgroup_count_y { - %3 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_id_x] - %4 = affine.apply affine_map<()[s0] -> (s0 * 256)>()[%workgroup_count_x] - scf.for %arg2 = %3 to %c2048 step %4 { - %5 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg2, 
%arg1], sizes = [1, %c256, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x?x1xf32> - %6 = flow.dispatch.tensor.load %1, offsets = [%arg2, %arg1, %arg0], sizes = [%c256, 1, 1], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor - %7 = tensor.empty() : tensor<1x1x256xf32> - %8 = tensor.cast %5 : tensor<1x?x1xf32> to tensor<1x256x1xf32> - %9 = tensor.cast %6 : tensor to tensor<256x1x1xf32> - %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d2, d1, d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %9 : tensor<1x256x1xf32>, tensor<256x1x1xf32>) outs(%7 : tensor<1x1x256xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %12 = arith.addf %arg3, %arg4 : f32 - linalg.yield %12 : f32 - } -> tensor<1x1x256xf32> - %11 = tensor.cast %10 : tensor<1x1x256xf32> to tensor<1x1x?xf32> - flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 1, %c256], strides = [1, 1, 1] : tensor<1x1x?xf32> -> !flow.dispatch.tensor> - } - } - } - return - } + return } } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir deleted file mode 100644 index 5f94f6e3ca6b..000000000000 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute.mlir +++ /dev/null @@ -1,109 +0,0 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s - -builtin.module attributes { hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", { - mma_intrinsics = [#iree_gpu.mfma_layout, - #iree_gpu.mfma_layout], - target_arch = "gfx940", - ukernels = "none"}>} { - func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, - %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, - %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) - attributes {subgroup_size = 64, workgroup_size = [64, 1, 1]} { - %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> - %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> - %cst = arith.constant 0.000000e+00 : f16 - %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32> - %c32 = arith.constant 32 : index - %c256 = arith.constant 256 : index - %c0 = arith.constant 0 : index - %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) { - %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> - %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> - vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> - gpu.barrier - vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> - gpu.barrier - %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> - %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, 
vector<32x16xf16> - %10 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32> - scf.yield %10 : vector<16x16xf32> - } - vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> - memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> - memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> - return - } -} - -// CHECK-LABEL: func.func @matmul_256x256x256 -// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4xf32> -// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space> -// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space> -// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x4xf32>) -// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16> -// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} : memref<256x16xf16, {{.*}}>, vector<1x8xf16> -// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space> -// CHECK: vector.transfer_write %[[RLOAD]], %[[RHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<32x16xf16, #gpu.address_space> -// CHECK: gpu.barrier -// CHECK-COUNT-2: vector.load %[[LHS_ALLOC]]{{.*}} : memref<16x32xf16, #gpu.address_space>, vector<4xf16> -// CHECK-COUNT-8: vector.load %[[RHS_ALLOC]]{{.*}} : memref<32x16xf16, #gpu.address_space>, vector<1xf16> -// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<4xf32> to vector<1x1x4xf32> -// CHECK: scf.yield %[[BCAST]] : vector<1x1x4xf32> -// CHECK-COUNT-4: vector.store {{.*}} : memref<16x16xf32{{.*}}>, vector<1xf32> - -// ----- - -builtin.module attributes { hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", { - mma_intrinsics = [#iree_gpu.mfma_layout, - #iree_gpu.mfma_layout], - target_arch = "gfx940", - ukernels = "none"}>} { - func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, - %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, - %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) - attributes {subgroup_size = 64, workgroup_size = [64, 1, 1]} { - %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> - %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> - %cst = arith.constant 0.000000e+00 : f16 - %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32> - %c32 = arith.constant 32 : index - %c256 = arith.constant 256 : index - %c0 = arith.constant 0 : index - %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) { - %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> - %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> - 
vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> - gpu.barrier - vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> - gpu.barrier - %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> - %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> - %10 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32> - scf.yield %10 : vector<16x16xf32> - } - vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> - memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> - memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> - return - } -} - -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)> - -// CHECK-LABEL: func.func @matmul_256x256x256 -// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4xf32> -// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space> -// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space> -// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x4xf32>) -// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16> -// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16> -// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space> -// CHECK: vector.transfer_write %[[RLOAD]], %[[RHS_ALLOC]]{{.*}} : vector<8x1xf16>, memref<32x16xf16, #gpu.address_space> -// CHECK: gpu.barrier -// CHECK-COUNT-2: vector.load %[[LHS_ALLOC]]{{.*}} : memref<16x32xf16, #gpu.address_space>, vector<4xf16> -// CHECK-COUNT-8: vector.load %[[RHS_ALLOC]]{{.*}} : memref<32x16xf16, #gpu.address_space>, vector<1xf16> -// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> -// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<4xf32> to vector<1x1x4xf32> -// CHECK: scf.yield %[[BCAST]] : vector<1x1x4xf32> -// CHECK-COUNT-4: vector.store {{.*}} : memref<16x16xf32{{.*}}>, vector<1xf32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir new file mode 100644 index 000000000000..e48b5f3af120 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_conversion.mlir @@ -0,0 +1,119 @@ +// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute, canonicalize, cse))' -split-input-file %s | FileCheck %s + +func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) + 
attributes { + mma_schedule = #iree_gpu.mma_schedule, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> + %cst = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32> + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c0 = arith.constant 0 : index + %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) { + %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> + %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> + gpu.barrier + vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> + gpu.barrier + %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> + %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + %10 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32> + scf.yield %10 : vector<16x16xf32> + } + vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> + memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> + memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> + return +} + +// CHECK-LABEL: func.func @matmul_256x256x256 +// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32> +// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space> +// CHECK: %[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space> +// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>) +// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16> +// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} : memref<256x16xf16, {{.*}}>, vector<1x8xf16> +// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space> +// CHECK: vector.transfer_write %[[RLOAD]], %[[RHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<32x16xf16, #gpu.address_space> +// CHECK: gpu.barrier +// CHECK-COUNT-2: vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<1x4xf16> +// CHECK-COUNT-2: vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<4x1xf16> +// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : 
vector<4xf16>, vector<4xf16>, vector<4xf32> +// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32> +// CHECK: scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32> +// CHECK: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}> + +// ----- + +func.func @matmul_256x256x256(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) + attributes { + mma_schedule = #iree_gpu.mma_schedule, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> + %cst = arith.constant 0.000000e+00 : f16 + %cst_f32 = arith.constant 0.000000e+00 : f32 + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c0 = arith.constant 0 : index + %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} + : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x16xf32> + %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) { + %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> + %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> + gpu.barrier + vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> + gpu.barrier + %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> + %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + %10 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32> + scf.yield %10 : vector<16x16xf32> + } + vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> + memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> + memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> + return +} + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 64)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: func.func @matmul_256x256x256 +// CHECK: %[[TIDX:.+]] = gpu.thread_id x +// CHECK: %[[TIDY:.+]] = gpu.thread_id y +// CHECK: %[[TIDZ:.+]] = gpu.thread_id z +// CHECK: %[[LIN_ID:.+]] = affine.apply #[[$MAP]]()[%[[TIDX]], %[[TIDY]], %[[TIDZ]]] +// CHECK: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space> +// CHECK: 
%[[LHS_ALLOC:.+]] = memref.alloc() : memref<16x32xf16, #gpu.address_space> +// CHECK: affine.delinearize_index %[[LIN_ID]] +// CHECK: %[[INIT_READ:.+]] = vector.transfer_read %{{.*}} memref<16x16xf32, {{.*}}>, vector<4x1xf32> +// CHECK: %[[INIT_TRANSP:.+]] = vector.transpose %[[INIT_READ]], [1, 0] +// CHECK: %[[INIT:.+]] = vector.insert_strided_slice %[[INIT_TRANSP]] +// CHECK: scf.for {{.*}} = %c0 to %c256 step %c32 iter_args({{.*}} = %[[INIT]]) -> (vector<1x1x1x1x1x4xf32>) +// CHECK: %[[LLOAD:.+]] = vector.transfer_read {{.*}} : memref<16x256xf16, {{.*}}>, vector<1x8xf16> +// CHECK: %[[RLOAD:.+]] = vector.transfer_read {{.*}} permutation_map = #[[$MAP1]]} : memref<16x256xf16, {{.*}}>, vector<8x1xf16> +// CHECK: vector.transfer_write %[[LLOAD]], %[[LHS_ALLOC]]{{.*}} : vector<1x8xf16>, memref<16x32xf16, #gpu.address_space> +// CHECK: vector.transfer_write %[[RLOAD]], %[[RHS_ALLOC]]{{.*}} : vector<8x1xf16>, memref<32x16xf16, #gpu.address_space> +// CHECK: gpu.barrier +// CHECK-COUNT-2: vector.transfer_read %[[LHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<1x4xf16> +// CHECK-COUNT-2: vector.transfer_read %[[RHS_ALLOC]][{{.+}}], %{{.+}} {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<4x1xf16> +// CHECK-COUNT-2: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> +// CHECK: %[[BCAST:.+]] = vector.broadcast {{.*}} : vector<1x1x1x4xf32> to vector<1x1x1x1x1x4xf32> +// CHECK: scf.yield %[[BCAST]] : vector<1x1x1x1x1x4xf32> +// CHECK: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf32>, memref<16x16xf32{{.*}}> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir new file mode 100644 index 000000000000..543f22e616e7 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_distribute_layout.mlir @@ -0,0 +1,193 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-llvmgpu-vector-distribute{test-layout}, canonicalize, cse))' %s | FileCheck %s + +func.func @matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { + mma_schedule = #iree_gpu.mma_schedule< + intrinsic = #iree_gpu.mfma_layout, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32> + return %0 : vector<96x64xf32> +} + +// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> +// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], 
outers_per_batch = [1, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> +// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> + +// ----- + +func.func @matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16xf16>, %init: vector<96x64xf32>) -> vector<96x64xf32> attributes { + mma_schedule = #iree_gpu.mma_schedule< + intrinsic = #iree_gpu.mfma_layout, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %0 = vector.contract { + indexing_maps = [affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, affine_map<(m, n, d2) -> (m, n)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf32> + return %0 : vector<96x64xf32> +} + +// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> +// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 2], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 4], +// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> +// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [3, 2], outers_per_batch = [4, 1], threads_per_outer = [2, 32], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [2, 32]> + +// ----- + +func.func @matmul_192x64x16_mmt_multisubgroup(%lhs: vector<192x16xf16>, %rhs: vector<16x64xf16>, %init: vector<192x64xf32>) -> vector<192x64xf32> attributes { + mma_schedule = #iree_gpu.mma_schedule< + intrinsic = #iree_gpu.mfma_layout, + subgroup_m_count = 2, subgroup_n_count = 1, subgroup_m_tile_count = 3, subgroup_n_tile_count = 2, subgroup_k_tile_count = 2>, + workgroup_size = [64, 2, 1]} { + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %lhs, %rhs, %init : vector<192x16xf16>, vector<16x64xf16> into vector<192x64xf32> + return %0 : vector<192x64xf32> +} + +// CHECK: contract A vector 
layout: #iree_vector_ext.nested_layout, #hal.descriptor_type>, + %rhs: memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) + attributes { + mma_schedule = #iree_gpu.mma_schedule, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> + %cst = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<0.000000e+00> : vector<16x16xf32> + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c0 = arith.constant 0 : index + %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_1) -> (vector<16x16xf32>) { + %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> + %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : memref<256x16xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> + gpu.barrier + vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> + gpu.barrier + %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> + %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + %10 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32> + scf.yield %10 : vector<16x16xf32> + } + vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> + memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> + memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> + return +} + +// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [16, 4]> +// CHECK: transfer '{{.+}} memref<256x16xf16{{.+}}>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [32, 2], elements_per_thread = [1, 8], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [32, 2]> + +// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], 
batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> +// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> +// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> + +// ----- + +func.func @matmul_16x16x256_read_permute(%lhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %rhs: memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, + %out: memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>) + attributes { + mma_schedule = #iree_gpu.mma_schedule, + subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 2>, + workgroup_size = [64, 1, 1]} { + %alloc = memref.alloc() : memref<32x16xf16, #gpu.address_space> + %alloc_0 = memref.alloc() : memref<16x32xf16, #gpu.address_space> + %cst = arith.constant 0.000000e+00 : f16 + %cst_f32 = arith.constant 0.000000e+00 : f32 + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c0 = arith.constant 0 : index + %init_acc = vector.transfer_read %out[%c0, %c0], %cst_f32 {in_bounds = [true, true]} + : memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x16xf32> + %5 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %init_acc) -> (vector<16x16xf32>) { + %6 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<16x32xf16> + %7 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<16x256xf16, strided<[256, 1], offset: ?>, #hal.descriptor_type>, vector<32x16xf16> + vector.transfer_write %6, %alloc_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x32xf16>, memref<16x32xf16, #gpu.address_space> + gpu.barrier + vector.transfer_write %7, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x16xf16>, memref<32x16xf16, #gpu.address_space> + gpu.barrier + %8 = vector.transfer_read %alloc_0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x32xf16, #gpu.address_space>, vector<16x32xf16> + %9 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x16xf16, #gpu.address_space>, vector<32x16xf16> + %10 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + %8, %9, %arg1 : vector<16x32xf16>, vector<32x16xf16> 
into vector<16x16xf32> + scf.yield %10 : vector<16x16xf32> + } + vector.transfer_write %5, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[256, 1], offset: ?>, #hal.descriptor_type> + memref.dealloc %alloc_0 : memref<16x32xf16, #gpu.address_space> + memref.dealloc %alloc : memref<32x16xf16, #gpu.address_space> + return +} + +// CHECK-NOT: transfer '{{.+}} memref<16x16xf16{{.+}}>, vector<16x16xf16>' vector layout +// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}>, vector<16x32xf16>' vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 8], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [16, 4]> +// CHECK: transfer '{{.+}} memref<16x256xf16{{.+}}storage_buffer>>, vector<32x16xf16>' vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [8, 1], +// CHECK-SAME: subgroup_order = [1, 0], batch_order = [1, 0], outer_order = [1, 0], thread_order = [1, 0], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> + +// CHECK: contract A vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 2], outers_per_batch = [1, 1], threads_per_outer = [16, 4], elements_per_thread = [1, 4], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [1, 0], element_order = [0, 1], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> +// CHECK: contract B vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [2, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> +// CHECK: contract C vector layout: #iree_vector_ext.nested_layout< +// CHECK-SAME: subgroups_per_workgroup = [1, 1], batches_per_subgroup = [1, 1], outers_per_batch = [1, 1], threads_per_outer = [4, 16], elements_per_thread = [4, 1], +// CHECK-SAME: subgroup_order = [0, 1], batch_order = [0, 1], outer_order = [0, 1], thread_order = [0, 1], element_order = [1, 0], +// CHECK-SAME: subgroup_basis = [1, 1], thread_basis = [4, 16]> + diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/workgroup_specialization_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/workgroup_specialization_pipeline_test.mlir index 8a4e5a1d398b..c93c4aa292a1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/workgroup_specialization_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/workgroup_specialization_pipeline_test.mlir @@ -1,37 +1,36 @@ // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy, iree-llvmgpu-lower-executable-target)))" %s | FileCheck %s -module attributes {hal.device.targets = [#hal.device.target<"cuda", {executable_targets = [#hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch 
= "sm_80"}>]}>]} { - hal.executable private @forward_dispatch_116 { - hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>) { - hal.executable.export public @forward_dispatch_116_matmul_128x30522x768 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @forward_dispatch_116_matmul_128x30522x768() { - %c512 = arith.constant 512 : index - %c786944 = arith.constant 786944 : index - %c265458176 = arith.constant 265458176 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c786944) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c265458176) : !flow.dispatch.tensor> - %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 768], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<128x768xf32> - %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 30522], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<768x30522xf32> - %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [30522], strides = [1] : !flow.dispatch.tensor> -> tensor<30522xf32> - %7 = tensor.empty() : tensor<128x30522xf32> - %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x30522xf32>) -> tensor<128x30522xf32> - %9 = linalg.matmul ins(%4, %5 : tensor<128x768xf32>, tensor<768x30522xf32>) outs(%8 : tensor<128x30522xf32>) -> tensor<128x30522xf32> - %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %6 : tensor<128x30522xf32>, tensor<30522xf32>) outs(%7 : tensor<128x30522xf32>) { - ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): - %11 = arith.addf %arg0, %arg1 : f32 - linalg.yield %11 : f32 - } -> tensor<128x30522xf32> - flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 30522], strides = [1, 1] : tensor<128x30522xf32> -> !flow.dispatch.tensor> - return - } + +hal.executable private @forward_dispatch_116 { + hal.executable.variant public @cuda_nvptx_fb target(<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>) { + hal.executable.export public @forward_dispatch_116_matmul_128x30522x768 ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @forward_dispatch_116_matmul_128x30522x768() { + %c512 = arith.constant 512 : index + %c786944 = arith.constant 786944 : index + %c265458176 = arith.constant 265458176 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) 
binding(0) type(storage_buffer) alignment(64) offset(%c786944) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c265458176) : !flow.dispatch.tensor> + %3 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 768], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<128x768xf32> + %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [768, 30522], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<768x30522xf32> + %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [30522], strides = [1] : !flow.dispatch.tensor> -> tensor<30522xf32> + %7 = tensor.empty() : tensor<128x30522xf32> + %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x30522xf32>) -> tensor<128x30522xf32> + %9 = linalg.matmul ins(%4, %5 : tensor<128x768xf32>, tensor<768x30522xf32>) outs(%8 : tensor<128x30522xf32>) -> tensor<128x30522xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %6 : tensor<128x30522xf32>, tensor<30522xf32>) outs(%7 : tensor<128x30522xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %11 = arith.addf %arg0, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<128x30522xf32> + flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 30522], strides = [1, 1] : tensor<128x30522xf32> -> !flow.dispatch.tensor> + return } } } @@ -57,33 +56,30 @@ module attributes {hal.device.targets = [#hal.device.target<"cuda", {executable_ #map = affine_map<(d0) -> (d0)> #pipeline_layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]> #executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_60"}> -#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}> -module attributes {hal.device.targets = [#device_target_cuda]} { - hal.executable private @vectorized_dispatch_0 { - hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { - hal.executable.export public @vectorized_dispatch_0_generic_102401 ordinal(0) layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @vectorized_dispatch_0_generic_102401() { - %c0 = arith.constant 0 : index - %cst = arith.constant -3.000000e+00 : f32 - %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [102401], strides = [1] : !flow.dispatch.tensor> -> tensor<102401xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [102401], strides = [1] : !flow.dispatch.tensor> -> tensor<102401xf32> - %5 = tensor.empty() : tensor<102401xf32> - %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : tensor<102401xf32>, tensor<102401xf32>) outs(%5 : tensor<102401xf32>) { - 
^bb0(%in: f32, %in_0: f32, %out: f32): - %7 = math.fma %cst, %in, %in_0 : f32 - linalg.yield %7 : f32 - } -> tensor<102401xf32> - flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [102401], strides = [1] : tensor<102401xf32> -> !flow.dispatch.tensor> - return - } +hal.executable private @vectorized_dispatch_0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) { + hal.executable.export public @vectorized_dispatch_0_generic_102401 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index): + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1 + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @vectorized_dispatch_0_generic_102401() { + %c0 = arith.constant 0 : index + %cst = arith.constant -3.000000e+00 : f32 + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [102401], strides = [1] : !flow.dispatch.tensor> -> tensor<102401xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [102401], strides = [1] : !flow.dispatch.tensor> -> tensor<102401xf32> + %5 = tensor.empty() : tensor<102401xf32> + %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : tensor<102401xf32>, tensor<102401xf32>) outs(%5 : tensor<102401xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %7 = math.fma %cst, %in, %in_0 : f32 + linalg.yield %7 : f32 + } -> tensor<102401xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [102401], strides = [1] : tensor<102401xf32> -> !flow.dispatch.tensor> + return } } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index 8401c1d53f70..1d99e3394006 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -89,6 +89,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common", "//compiler/src/iree/compiler/Codegen/Common:TransformDialectInterpreterPass", "//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses", + "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU", "//compiler/src/iree/compiler/Codegen/Transforms", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 1515b03ef677..28a100d15f1c 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -139,6 +139,7 @@ iree_cc_library( MLIRVectorTransforms iree::compiler::Codegen::Common iree::compiler::Codegen::Common::GPU::CommonGPUPasses + iree::compiler::Codegen::Common::GPU::GPUHeuristics iree::compiler::Codegen::Common::TransformDialectInterpreterPass iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::TransformStrategies::GPU diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 
2ca6c05ba393..dec463690980 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -7,10 +7,12 @@ #include "iree/compiler/Codegen/SPIRV/KernelConfig.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/SPIRV/Utils.h" #include "iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "llvm/ADT/ArrayRef.h" @@ -37,8 +39,8 @@ using llvm::divideCeil; using llvm::APIntOps::GreatestCommonDivisor; -// The default number of tiles along K dimension to use per workgroup. -constexpr unsigned numTilesPerSubgroupDimK = 2; +// The default number of tiles along K dimension to use per subgroup/workgroup. +constexpr unsigned numKTilesPerSubgroup = 2; constexpr int kMaxVectorNumBits = 128; @@ -55,28 +57,6 @@ using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline; // Utility Functions //===----------------------------------------------------------------------===// -bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { - // (Batch) matmul should be a reduction op with 2/3 parallel dimensions. - if (!linalg::isaContractionOpInterface(linalgOp) || - !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) - return false; - - // Also exclude the case of matvec, which has only one non-unit parallel dim. - // They should go down different pipelines. - int nonUnitParallelDimCount = 0; - SmallVector bounds = linalgOp.getStaticLoopRanges(); - FailureOr contractionDims = - mlir::linalg::inferContractionDims(linalgOp); - assert(succeeded(contractionDims) && "Could not infer contraction dims"); - for (auto mDim : contractionDims->m) { - nonUnitParallelDimCount += bounds[mDim] != 1; - } - for (auto nDim : contractionDims->n) { - nonUnitParallelDimCount += bounds[nDim] != 1; - } - return nonUnitParallelDimCount > 1; -} - // Check if the given linalg op is fused with another op that may result // in too much shared memory usage. static bool fusedOpMayUseExtraSharedMemory(linalg::LinalgOp matmul) { @@ -862,118 +842,6 @@ bool needToPrmoteCForCooperativeMatrix(linalg::LinalgOp matmulOp) { return true; // Be conservative. } -struct CooperativeMatrixSize { - int64_t mSize; // Native cooperative matrix size along M dimension - int64_t nSize; // Native cooperative matrix size along N dimension - int64_t kSize; // Native cooperative matrix size along K dimension - int64_t mWarpCount; // # subgroups along M dimension - int64_t nWarpCount; // # subgroups along N dimension - int64_t mTileCount; // # tiles per subgroup along M dimension - int64_t nTileCount; // # tiles per subgroup along N dimension - int64_t kTileCount; // # tiles along K dimension -}; - -/// Returns the cooperative matrix (M, N, K) sizes that are supported by the -/// target environment and match the given parameters. 
-static std::optional -getCooperativeMatrixSize(spirv::ResourceLimitsAttr resourceLimits, - const unsigned numSubgroupsPerWorkgroup, - const unsigned numMNTilesPerSubgroup, Type aType, - Type bType, Type cType, int64_t m, int64_t n, - int64_t k) { - auto properties = - resourceLimits.getCooperativeMatrixPropertiesKhr() - .getAsRange(); - for (auto property : properties) { - if (property.getAType() != aType || property.getBType() != bType || - property.getCType() != cType || property.getResultType() != cType || - property.getScope().getValue() != spirv::Scope::Subgroup) { - continue; // Cannot use this cooperative matrix configuration - } - - const unsigned matmulM = property.getMSize(); - const unsigned matmulN = property.getNSize(); - const unsigned matmulK = property.getKSize(); - if (m % matmulM != 0 || n % matmulN != 0 || k % matmulK != 0) - continue; - - uint64_t nTotalTileCount = n / matmulN; - uint64_t mTotalTileCount = m / matmulM; - - uint64_t remainingWarps = numSubgroupsPerWorkgroup; - uint64_t remainingTiles = numMNTilesPerSubgroup; - // Assign more warps to the M dimension (used later) to balance thread - // counts along X and Y dimensions. - uint64_t warpSqrt = 1ull << (divideCeil(llvm::Log2_64(remainingWarps), 2)); - uint64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); - - int64_t mWarpCount = 0, nWarpCount = 0; - int64_t mTileCount = 0, nTileCount = 0; - - // See if the square root can divide mTotalTileCount. If so it means we can - // distribute to both dimensions evenly. Otherwise, try to distribute to N - // and then M. - if (mTotalTileCount > (warpSqrt * tileSqrt) && - mTotalTileCount % (warpSqrt * tileSqrt) == 0) { - mWarpCount = warpSqrt; - mTileCount = tileSqrt; - - remainingWarps /= warpSqrt; - remainingTiles /= tileSqrt; - - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - } else { - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - remainingTiles /= nTileCount; - - APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingWarps)); - mWarpCount = mGCD.getSExtValue(); - mTotalTileCount /= mWarpCount; - remainingWarps /= mWarpCount; - - mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingTiles)); - mTileCount = mGCD.getSExtValue(); - } - - const uint64_t kTotalTileCount = k / matmulK; - APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount), - APInt(64, numTilesPerSubgroupDimK)); - int64_t kTileCount = kGCD.getSExtValue(); - - LLVM_DEBUG({ - llvm::dbgs() << "chosen cooperative matrix configuration:\n"; - llvm::dbgs() << " (M, N, K) size = (" << matmulM << ", " << matmulN - << ", " << matmulK << ")\n"; - llvm::dbgs() << " (M, N) subgroup count = (" << mWarpCount << ", " - << nWarpCount << ")\n"; - llvm::dbgs() << " (M, N, K) tile count per subgroup = (" << mTileCount - << ", " << nTileCount << ", " << kTileCount << ")\n"; - }); - return CooperativeMatrixSize{matmulM, matmulN, matmulK, - mWarpCount, nWarpCount, mTileCount, - nTileCount, kTileCount}; - } - return 
std::nullopt; -} - namespace detail { LogicalResult setCooperativeMatrixConfig( @@ -1022,12 +890,29 @@ LogicalResult setCooperativeMatrixConfig( return llvm::cast(v.getType()).getElementType(); }; + Type lhsElem = getElementType(lhs); + Type rhsElem = getElementType(rhs); + Type initElem = getElementType(init); + GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem); + spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits(); - std::optional coopMatSize = getCooperativeMatrixSize( - limits, numSubgroupsPerWorkgroup, numMNTilesPerSubgroup, - getElementType(lhs), getElementType(rhs), getElementType(init), dimM, - dimN, dimK); - if (!coopMatSize) + auto properties = + limits.getCooperativeMatrixPropertiesKhr() + .getAsRange(); + + SmallVector intrinsics; + intrinsics.reserve(limits.getCooperativeMatrixPropertiesKhr().size()); + for (auto p : properties) { + intrinsics.emplace_back(p.getMSize(), p.getNSize(), p.getKSize(), + p.getAType(), p.getBType(), p.getCType()); + } + + GPUMMAHeuristicSeeds seeds{numSubgroupsPerWorkgroup, numMNTilesPerSubgroup, + numKTilesPerSubgroup}; + + std::optional schedule = + deduceMMASchedule(problem, intrinsics, seeds); + if (!schedule) return failure(); auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; @@ -1040,36 +925,34 @@ LogicalResult setCooperativeMatrixConfig( subgroupSize = *minSize; } - std::array workgroupSize{coopMatSize->nWarpCount * *subgroupSize, - coopMatSize->mWarpCount, 1}; + std::array workgroupSize{schedule->nWarpCount * *subgroupSize, + schedule->mWarpCount, 1}; SmallVector vectorSizes(kIndex + 1, 0); if (isBM) vectorSizes[bIndex] = 1; - vectorSizes[mIndex] = coopMatSize->mSize; - vectorSizes[nIndex] = coopMatSize->nSize; - vectorSizes[kIndex] = coopMatSize->kSize; + vectorSizes[mIndex] = schedule->mSize; + vectorSizes[nIndex] = schedule->nSize; + vectorSizes[kIndex] = schedule->kSize; SmallVector subgroupTileSizes(lastParallelDim + 1, 0); if (isBM) subgroupTileSizes[bIndex] = 1; - subgroupTileSizes[mIndex] = coopMatSize->mTileCount * vectorSizes[mIndex]; - subgroupTileSizes[nIndex] = coopMatSize->nTileCount * vectorSizes[nIndex]; + subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex]; + subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); if (isBM) workgroupTileSizes[bIndex] = 1; - workgroupTileSizes[mIndex] = - coopMatSize->mWarpCount * subgroupTileSizes[mIndex]; - workgroupTileSizes[nIndex] = - coopMatSize->nWarpCount * subgroupTileSizes[nIndex]; + workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex]; + workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex]; // Also create one level for reduction. This is needed because of // SPIRVTileAndPromotePass requires it. // TODO(#10499): Consolidate tiling configuration across different pipelines. SmallVector reductionTileSizes; reductionTileSizes.append(kIndex, 0); - reductionTileSizes.push_back(coopMatSize->kTileCount * coopMatSize->kSize); + reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize); TileSizesListType tileSizes; tileSizes.reserve(3); @@ -1081,7 +964,7 @@ LogicalResult setCooperativeMatrixConfig( // Don't do multibuffering if the inner reduction loop is folded out. 
auto pipelineDepth = softwarePipelineDepth; auto storeStage = softwarePipelineStoreStage; - if (coopMatSize->kTileCount <= 1) { + if (schedule->kTileCount <= 1) { pipelineDepth = 0; storeStage = 0; } diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel index b2d6f6f9f6e0..733142879ae5 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel @@ -66,3 +66,19 @@ iree_compiler_cc_library( "@llvm-project//mlir:ViewLikeInterface", ], ) + +iree_compiler_cc_library( + name = "VectorOpUtils", + srcs = [ + "VectorOpUtils.cpp", + ], + hdrs = [ + "VectorOpUtils.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:VectorDialect", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt index ef6d8d2a5696..a99d7b3e29af 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt @@ -62,4 +62,19 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + VectorOpUtils + HDRS + "VectorOpUtils.h" + SRCS + "VectorOpUtils.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRSupport + MLIRVectorDialect + PUBLIC +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp index 304ceea3dbd0..624af259254f 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp @@ -73,7 +73,8 @@ LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp, // * Consider transpose + reductions. // * Consider input and output transposes. static SmallVector -computeTransposeInfo(LinalgOp linalgOp, TransposeMapFilter transposeMapFilter) { +computeTransposeInfo(LinalgOp linalgOp, + LinalgOpInfo::TransposeMapFilter transposeMapFilter) { SmallVector transposeOperands; // Reductions are not supported. @@ -124,4 +125,26 @@ void LinalgOpInfo::computeInfo(LinalgOp linalgOp) { dynamicTrait = computeDynamicInfo(linalgOp); } +bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { + // (Batch) matmul should be a reduction op with 2/3 parallel dimensions. + if (!linalg::isaContractionOpInterface(linalgOp) || + !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) + return false; + + // Also exclude the case of matvec, which has only one non-unit parallel dim. + // They should go down different pipelines. 
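+  // For instance (hypothetical shapes): a contraction with static loop ranges
+  // [128, 1, 64] and contraction dims m = {0}, n = {1} has only one non-unit
+  // parallel dimension (128) and is rejected as matvec-like, whereas ranges
+  // [128, 256, 64] yield two non-unit parallel dimensions and pass the check.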
+ int nonUnitParallelDimCount = 0; + SmallVector bounds = linalgOp.getStaticLoopRanges(); + FailureOr contractionDims = + mlir::linalg::inferContractionDims(linalgOp); + assert(succeeded(contractionDims) && "Could not infer contraction dims"); + for (auto mDim : contractionDims->m) { + nonUnitParallelDimCount += bounds[mDim] != 1; + } + for (auto nDim : contractionDims->n) { + nonUnitParallelDimCount += bounds[nDim] != 1; + } + return nonUnitParallelDimCount > 1; +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.h b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.h index 1bfb661c5a00..26de493f2466 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.h +++ b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.h @@ -6,22 +6,17 @@ #ifndef IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ #define IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -namespace mlir::linalg { -class LinalgOp; -} // namespace mlir::linalg +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/IR/AffineMap.h" namespace mlir::iree_compiler { -/// Returns true if a map represents the appropriate transpose. Pass this into -/// the LinalgOpInfo for additional transpose granularity. -using TransposeMapFilter = std::function; - class LinalgOpInfo { public: + /// Returns true if a map represents a chosen transpose granularity. + using TransposeMapFilter = std::function; + LinalgOpInfo(linalg::LinalgOp linalgOp); LinalgOpInfo(linalg::LinalgOp linalgOp, TransposeMapFilter transposeMapFilter); @@ -43,6 +38,10 @@ class LinalgOpInfo { SmallVector transposeOperands; }; +// Returns true if the given |linalgOp| is a matmul or batch matmul. +// This also looks into the shape to filter out cases like matvec. +bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_CODEGEN_COMMON_LINALGOPINFO_H_ diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp new file mode 100644 index 000000000000..826ba1d93b7e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp @@ -0,0 +1,68 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Utils/VectorOpUtils.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" + +namespace mlir::iree_compiler { + +std::optional> +VectorContractOpInfo::getOperandMNIndex() const { + switch (opKind) { + case OpKind::MK_KN_MN: + return std::make_pair(0, 1); + case OpKind::MK_NK_MN: + return std::make_pair(0, 0); + case OpKind::UNKNOWN: + break; + } + return std::nullopt; +} + +// Returns the (LHS K, RHS K) dimension index pair. +std::optional> +VectorContractOpInfo::getOperandKIndex() const { + switch (opKind) { + case OpKind::MK_KN_MN: + return std::make_pair(1, 0); + case OpKind::MK_NK_MN: + return std::make_pair(1, 1); + case OpKind::UNKNOWN: + break; + } + return std::nullopt; +} + +// Returns the result (M, N) dimension index pair. 
+std::optional> +VectorContractOpInfo::getResultMNIndex() const { + switch (opKind) { + case OpKind::MK_KN_MN: + case OpKind::MK_NK_MN: + return std::make_pair(0, 1); + default: + break; + } + return std::nullopt; +} + +VectorContractOpInfo::OpKind +VectorContractOpInfo::inferOpKind(MLIRContext *ctx, + SmallVector maps) const { + using MapList = ArrayRef>; + auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); }; + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + if (maps == infer({{m, k}, {k, n}, {m, n}})) + return OpKind::MK_KN_MN; + if (maps == infer({{m, k}, {n, k}, {m, n}})) + return OpKind::MK_NK_MN; + return OpKind::UNKNOWN; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h new file mode 100644 index 000000000000..be1ba22a721d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h @@ -0,0 +1,38 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace mlir::iree_compiler { + +/// A class for querying information about a contract op. +class VectorContractOpInfo { +public: + enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN }; + + explicit VectorContractOpInfo(vector::ContractionOp op) { + opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray()); + } + + OpKind getOpKind() const { return opKind; } + + // Returns the (LHS M, RHS N) dimension index pair. + std::optional> getOperandMNIndex() const; + + // Returns the (LHS K, RHS K) dimension index pair. + std::optional> getOperandKIndex() const; + + // Returns the result (M, N) dimension index pair. + std::optional> getResultMNIndex() const; + +private: + // Gets the kind of a contract op with the given indexing |maps|. + OpKind inferOpKind(MLIRContext *ctx, SmallVector maps) const; + + OpKind opKind = OpKind::UNKNOWN; +}; + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 463dc5409043..ad398f953e49 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -366,6 +366,13 @@ struct JitGlobalsPass : public JitGlobalsBase { const SupportedFeatures getSupportedFeatures(MLIRContext *context) { SupportedFeatures s; Builder b(context); + + // Exclude vmvx backend since there is no i4 support there causing + // the `eval_i4_tensor` test in `jit_globals.mlir` to fail. + // TODO(#16321): Enable on other backends once this has been tested + // outside llvm-cpu. + if (requestedTargetBackend == "llvm-cpu" && hasRequestedTargetBackend) + s.addScalarType(b.getIntegerType(4)); s.addScalarType(b.getIntegerType(8)); s.addScalarType(b.getIntegerType(16)); s.addScalarType(b.getIntegerType(32)); @@ -373,6 +380,11 @@ struct JitGlobalsPass : public JitGlobalsBase { s.addScalarType(b.getF32Type()); s.addElementType(b.getIntegerType(1)); + + // TODO(#16321): Enable on other backends once this has been tested outside + // llvm-cpu. 
+ if (requestedTargetBackend == "llvm-cpu" && hasRequestedTargetBackend) + s.addElementType(b.getIntegerType(4)); s.addElementType(b.getIntegerType(8)); s.addElementType(b.getIntegerType(16)); s.addElementType(b.getIntegerType(32)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index e67142fcf26f..4ec86a35052d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -22,6 +22,7 @@ iree_td_library( name = "td_files", srcs = enforce_glob( [ + "HALAttrs.td", "HALBase.td", "HALDialect.td", "HALInterfaces.td", @@ -42,6 +43,7 @@ iree_td_library( iree_compiler_cc_library( name = "IR", srcs = [ + "HALAttrs.cpp", "HALOpFolders.cpp", "HALOps.cpp", "HALTypes.cpp", @@ -66,9 +68,9 @@ iree_compiler_cc_library( "HALTypeInterfaces.h.inc", ], deps = [ + ":HALAttrsGen", ":HALInterfacesGen", ":HALOpsGen", - ":HALTypesGen", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", @@ -111,6 +113,37 @@ iree_compiler_cc_library( ], ) +iree_gentbl_cc_library( + name = "HALAttrsGen", + tbl_outs = [ + ( + [ + "--gen-attrdef-decls", + "--attrdefs-dialect=hal", + ], + "HALAttrs.h.inc", + ), + ( + [ + "--gen-attrdef-defs", + "--attrdefs-dialect=hal", + ], + "HALAttrs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "HALEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "HALEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HALAttrs.td", + deps = [":td_files"], +) + iree_gentbl_cc_library( name = "HALInterfacesGen", tbl_outs = [ @@ -161,37 +194,6 @@ iree_gentbl_cc_library( deps = [":td_files"], ) -iree_gentbl_cc_library( - name = "HALTypesGen", - tbl_outs = [ - ( - [ - "--gen-attrdef-decls", - "--attrdefs-dialect=hal", - ], - "HALAttrs.h.inc", - ), - ( - [ - "--gen-attrdef-defs", - "--attrdefs-dialect=hal", - ], - "HALAttrs.cpp.inc", - ), - ( - ["--gen-enum-decls"], - "HALEnums.h.inc", - ), - ( - ["--gen-enum-defs"], - "HALEnums.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HALBase.td", - deps = [":td_files"], -) - iree_tablegen_doc( name = "HALDialectDocGen", tbl_outs = [ diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 701d8f00c8be..16b490ea3fc4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -31,13 +31,14 @@ iree_cc_library( "HALTypeInterfaces.cpp.inc" "HALTypeInterfaces.h.inc" SRCS + "HALAttrs.cpp" "HALOpFolders.cpp" "HALOps.cpp" "HALTypes.cpp" DEPS + ::HALAttrsGen ::HALInterfacesGen ::HALOpsGen - ::HALTypesGen LLVMSupport MLIRArithDialect MLIRControlFlowDialect @@ -83,6 +84,18 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + HALAttrsGen + TD_FILE + "HALAttrs.td" + OUTS + --gen-attrdef-decls --attrdefs-dialect=hal HALAttrs.h.inc + --gen-attrdef-defs --attrdefs-dialect=hal HALAttrs.cpp.inc + --gen-enum-decls HALEnums.h.inc + --gen-enum-defs HALEnums.cpp.inc +) + iree_tablegen_library( NAME HALInterfacesGen @@ -107,18 +120,6 @@ iree_tablegen_library( --gen-op-defs HALOps.cpp.inc ) -iree_tablegen_library( - NAME - HALTypesGen - TD_FILE - "HALBase.td" - OUTS - --gen-attrdef-decls --attrdefs-dialect=hal HALAttrs.h.inc - --gen-attrdef-defs --attrdefs-dialect=hal HALAttrs.cpp.inc - --gen-enum-decls HALEnums.h.inc - --gen-enum-defs HALEnums.cpp.inc -) - 
iree_tablegen_doc( NAME HALDialectDocGen diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp new file mode 100644 index 000000000000..8ad4f14016e3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -0,0 +1,777 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Utils/StringUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "mlir/Parser/Parser.h" + +// clang-format off: must be included after all LLVM/MLIR headers. +#define GET_ATTRDEF_CLASSES +#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep +#include "iree/compiler/Dialect/HAL/IR/HALEnums.cpp.inc" // IWYU pragma: keep +// clang-format on + +namespace mlir::iree_compiler::IREE::HAL { + +//===----------------------------------------------------------------------===// +// Enum utilities +//===----------------------------------------------------------------------===// + +template +static LogicalResult parseEnumAttr(AsmParser &parser, StringRef attrName, + AttrType &attr) { + Attribute genericAttr; + auto loc = parser.getCurrentLocation(); + if (failed(parser.parseAttribute(genericAttr, + parser.getBuilder().getNoneType()))) { + return parser.emitError(loc) + << "failed to parse '" << attrName << "' enum string value"; + } + auto stringAttr = llvm::dyn_cast(genericAttr); + if (!stringAttr) { + return parser.emitError(loc) + << "expected " << attrName << " attribute specified as string"; + } + auto symbolized = + symbolizeEnum(stringAttr.getValue()); + if (!symbolized.hasValue()) { + return parser.emitError(loc) + << "failed to parse '" << attrName << "' enum value"; + } + attr = AttrType::get(parser.getBuilder().getContext(), symbolized.getValue()); + return success(); +} + +template +static LogicalResult parseOptionalEnumAttr(AsmParser &parser, + StringRef attrName, AttrType &attr) { + if (succeeded(parser.parseOptionalQuestion())) { + // Special case `?` to indicate any/none/undefined/etc. + attr = AttrType::get(parser.getBuilder().getContext(), 0); + return success(); + } + return parseEnumAttr(parser, attrName, attr); +} + +//===----------------------------------------------------------------------===// +// #hal.collective<*> +//===----------------------------------------------------------------------===// + +// See the iree/hal/command_buffer.h iree_hal_collective_op_t for details. 
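+// The union below packs the kind into byte 0, the reduction op into byte 1,
+// the element type into byte 2, and leaves byte 3 reserved. For example
+// (hypothetical values), kind = 2, reduction = 1, elementType = 3 reads back
+// as 0x00030102 through the packed uint32_t on a little-endian host.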
+uint32_t CollectiveAttr::getEncodedValue() const { + union { + uint32_t packed; // packed value + struct { + uint8_t kind; + uint8_t reduction; + uint8_t elementType; + uint8_t reserved; + }; + } value = {0}; + value.kind = static_cast(getKind()); + value.reduction = static_cast( + getReduction().value_or(CollectiveReductionOp::None)); + value.elementType = static_cast(getElementType()); + return value.packed; +} + +//===----------------------------------------------------------------------===// +// #hal.device.target<*> +//===----------------------------------------------------------------------===// + +// static +DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context, + StringRef deviceID) { + // TODO(benvanik): query default configuration from the target backend. + return get(context, StringAttr::get(context, deviceID), + DictionaryAttr::get(context)); +} + +// static +Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) { + StringAttr deviceIDAttr; + DictionaryAttr configAttr; + // `<"device-id"` + if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) { + return {}; + } + // `, {config}` + if (succeeded(p.parseOptionalComma()) && + failed(p.parseAttribute(configAttr))) { + return {}; + } + // `>` + if (failed(p.parseGreater())) { + return {}; + } + return get(p.getContext(), deviceIDAttr, configAttr); +} + +void DeviceTargetAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<"; + p.printAttribute(getDeviceID()); + auto configAttr = getConfiguration(); + if (configAttr && !configAttr.empty()) { + os << ", "; + p.printAttribute(configAttr); + } + os << ">"; +} + +std::string DeviceTargetAttr::getSymbolNameFragment() { + return sanitizeSymbolName(getDeviceID().getValue().lower()); +} + +bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { + auto configAttr = getConfiguration(); + return configAttr && configAttr.get(name); +} + +SmallVector DeviceTargetAttr::getExecutableTargets() { + SmallVector resultAttrs; + auto configAttr = getConfiguration(); + if (configAttr) { + auto targetsAttr = configAttr.getAs("executable_targets"); + if (targetsAttr) { + for (auto attr : targetsAttr.getValue()) { + resultAttrs.push_back(llvm::dyn_cast(attr)); + } + } + } + return resultAttrs; +} + +// static +SmallVector +DeviceTargetAttr::lookup(Operation *op) { + auto attrId = mlir::StringAttr::get(op->getContext(), "hal.device.targets"); + while (op) { + auto targetsAttr = op->getAttrOfType(attrId); + if (targetsAttr) { + SmallVector result; + for (auto targetAttr : targetsAttr) { + result.push_back(llvm::cast(targetAttr)); + } + return result; + } + op = op->getParentOp(); + } + return {}; // No devices found; let caller decide what to do. +} + +// Returns a set of all configuration attributes from all device targets with +// a configuration set. Targets with no configuration set are ignored. +static SmallVector lookupOptionalConfigAttrs(Operation *op) { + auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); + if (targetAttrs.empty()) + return {}; + SmallVector configAttrs; + for (auto targetAttr : targetAttrs) { + auto configAttr = targetAttr.getConfiguration(); + if (configAttr) + configAttrs.push_back(configAttr); + } + return configAttrs; +} + +// Returns a set of all configuration attributes from all device targets. +// Returns nullopt if any target is missing a configuration attribute. 
+static std::optional> +lookupRequiredConfigAttrs(Operation *op) { + auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); + if (targetAttrs.empty()) + return std::nullopt; + SmallVector configAttrs; + for (auto targetAttr : targetAttrs) { + auto configAttr = targetAttr.getConfiguration(); + if (!configAttr) + return std::nullopt; + configAttrs.push_back(configAttr); + } + return configAttrs; +} + +template +static std::optional joinConfigAttrs( + ArrayRef configAttrs, StringRef name, + std::function + join) { + if (configAttrs.empty()) + return std::nullopt; + auto firstValue = configAttrs.front().getAs(name); + if (!firstValue) + return std::nullopt; + auto result = firstValue.getValue(); + for (auto configAttr : configAttrs.drop_front(1)) { + auto value = configAttr.getAs(name); + if (!value) + return std::nullopt; + result = join(result, value.getValue()); + } + return result; +} + +template +static std::optional> +joinConfigStaticRanges(ArrayRef configAttrs, StringRef name, + std::function( + StaticRange, + StaticRange)> + join) { + if (configAttrs.empty()) + return std::nullopt; + auto firstValue = configAttrs.front().getAs(name); + if (!firstValue) + return std::nullopt; + StaticRange result{firstValue.getValue()}; + for (auto configAttr : configAttrs.drop_front(1)) { + auto value = configAttr.getAs(name); + if (!value) + return std::nullopt; + result = + join(result, StaticRange{value.getValue()}); + } + return result; +} + +// static +bool DeviceTargetAttr::lookupConfigAttrAny(Operation *op, StringRef name) { + auto configAttrs = lookupOptionalConfigAttrs(op); + if (configAttrs.empty()) + return false; + for (auto configAttr : configAttrs) { + if (configAttr.get(name)) + return true; + } + return false; +} + +// static +bool DeviceTargetAttr::lookupConfigAttrAll(Operation *op, StringRef name) { + auto configAttrs = lookupRequiredConfigAttrs(op); + if (!configAttrs) + return false; + for (auto configAttr : *configAttrs) { + if (!configAttr.get(name)) + return false; + } + return true; +} + +// static +std::optional DeviceTargetAttr::lookupConfigAttrAnd(Operation *op, + StringRef name) { + auto configAttrs = lookupRequiredConfigAttrs(op); + if (!configAttrs) + return std::nullopt; + return joinConfigAttrs( + configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs && rhs; }); +} + +// static +std::optional DeviceTargetAttr::lookupConfigAttrOr(Operation *op, + StringRef name) { + auto configAttrs = lookupRequiredConfigAttrs(op); + if (!configAttrs) + return std::nullopt; + return joinConfigAttrs( + configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs || rhs; }); +} + +// static +std::optional> +DeviceTargetAttr::lookupConfigAttrRange(Operation *op, StringRef name) { + auto configAttrs = lookupRequiredConfigAttrs(op); + if (!configAttrs) + return std::nullopt; + return joinConfigStaticRanges( + configAttrs.value(), name, + [](StaticRange lhs, StaticRange rhs) { + return StaticRange{ + llvm::APIntOps::smin(lhs.min, rhs.min), + llvm::APIntOps::smax(lhs.max, rhs.max), + }; + }); +} + +// static +SmallVector +DeviceTargetAttr::lookupExecutableTargets(Operation *op) { + SmallVector resultAttrs; + for (auto deviceTargetAttr : lookup(op)) { + for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) { + if (!llvm::is_contained(resultAttrs, executableTargetAttr)) { + resultAttrs.push_back(executableTargetAttr); + } + } + } + return resultAttrs; +} + +//===----------------------------------------------------------------------===// +// 
#hal.executable.target<*> +//===----------------------------------------------------------------------===// + +// static +ExecutableTargetAttr ExecutableTargetAttr::get(MLIRContext *context, + StringRef backend, + StringRef format) { + return get(context, StringAttr::get(context, backend), + StringAttr::get(context, format), DictionaryAttr::get(context)); +} + +// static +Attribute ExecutableTargetAttr::parse(AsmParser &p, Type type) { + StringAttr backendAttr; + StringAttr formatAttr; + DictionaryAttr configurationAttr; + // `<"backend", "format"` + if (failed(p.parseLess()) || failed(p.parseAttribute(backendAttr)) || + failed(p.parseComma()) || failed(p.parseAttribute(formatAttr))) { + return {}; + } + // `, {config}` + if (succeeded(p.parseOptionalComma()) && + failed(p.parseAttribute(configurationAttr))) { + return {}; + } + // `>` + if (failed(p.parseGreater())) { + return {}; + } + return get(p.getContext(), backendAttr, formatAttr, configurationAttr); +} + +void ExecutableTargetAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<"; + p.printAttribute(getBackend()); + os << ", "; + p.printAttribute(getFormat()); + auto config = getConfiguration(); + if (config && !config.empty()) { + os << ", "; + p.printAttribute(config); + } + os << ">"; +} + +std::string ExecutableTargetAttr::getSymbolNameFragment() const { + return sanitizeSymbolName(getFormat().getValue().lower()); +} + +bool ExecutableTargetAttr::hasConfigurationAttr(StringRef name) { + auto configAttr = getConfiguration(); + return configAttr && configAttr.get(name); +} + +// For now this is very simple: if there are any specified fields that are +// present in this attribute they must match. We could allow target backends +// to customize this via attribute interfaces in the future if we needed. +bool ExecutableTargetAttr::isGenericOf( + IREE::HAL::ExecutableTargetAttr specificAttr) { + if (getBackend() != specificAttr.getBackend() || + getFormat() != specificAttr.getFormat()) { + // Totally different backends and binary formats. + // There may be cases where we want to share things - such as when targeting + // both DLLs and dylibs or something - but today almost all of these are + // unique situations. + return false; + } + + // If the config is empty on either we can quickly match. + // This is the most common case for users manually specifying targets. + auto genericConfigAttr = getConfiguration(); + auto specificConfigAttr = specificAttr.getConfiguration(); + if (!genericConfigAttr || !specificConfigAttr) + return true; + + // Ensure all fields in specificConfigAttr either don't exist or match. + for (auto expectedAttr : specificConfigAttr.getValue()) { + auto actualValue = genericConfigAttr.getNamed(expectedAttr.getName()); + if (!actualValue) { + continue; // ignore, not present in generic + } + if (actualValue->getValue() != expectedAttr.getValue()) { + return false; // mismatch, both have values but they differ + } + } + + // Ensure all fields in genericConfigAttr exist in the specific one. + // If missing then the generic is _more_ specific and can't match. + for (auto actualAttr : genericConfigAttr.getValue()) { + if (!specificConfigAttr.getNamed(actualAttr.getName())) { + return false; // mismatch, present in generic but not specific + } + } + + // All fields match or are omitted in the generic version. 
+  return true;
+}
+
+// static
+ExecutableTargetAttr ExecutableTargetAttr::lookup(Operation *op) {
+  auto *context = op->getContext();
+  auto attrId = StringAttr::get(context, "hal.executable.target");
+  while (op) {
+    // Take directly from the enclosing variant.
+    if (auto variantOp = llvm::dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
+      return variantOp.getTarget();
+    }
+    // Use an override if specified.
+    auto attr = op->getAttrOfType<ExecutableTargetAttr>(attrId);
+    if (attr)
+      return attr;
+    // Continue walk.
+    op = op->getParentOp();
+  }
+  // No target found during walk. No default to provide so fail and let the
+  // caller decide what to do (assert/fallback/etc).
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.executable.object<*>
+//===----------------------------------------------------------------------===//
+
+// static
+Attribute ExecutableObjectAttr::parse(AsmParser &p, Type type) {
+  NamedAttrList dict;
+  // `<{` dict `}>`
+  if (failed(p.parseLess()) || failed(p.parseOptionalAttrDict(dict)) ||
+      failed(p.parseGreater())) {
+    return {};
+  }
+  auto pathAttr = llvm::dyn_cast_if_present<StringAttr>(dict.get("path"));
+  auto dataAttr =
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(dict.get("data"));
+  return get(p.getContext(), pathAttr, dataAttr);
+}
+
+void ExecutableObjectAttr::print(AsmPrinter &p) const {
+  auto &os = p.getStream();
+  os << "<{";
+  if (auto pathAttr = getPath()) {
+    os << "path = ";
+    p.printAttribute(getPath());
+  }
+  if (auto dataAttr = getData()) {
+    os << ", data = ";
+    p.printAttribute(getData());
+  }
+  os << "}>";
+}
+
+// static
+void ExecutableObjectAttr::filterObjects(
+    ArrayAttr objectAttrs, ArrayRef<StringRef> extensions,
+    SmallVectorImpl<IREE::HAL::ExecutableObjectAttr> &filteredAttrs) {
+  if (!objectAttrs)
+    return;
+  for (auto objectAttr :
+       objectAttrs.getAsRange<IREE::HAL::ExecutableObjectAttr>()) {
+    auto path = objectAttr.getPath();
+    auto ext = llvm::sys::path::extension(path);
+    if (llvm::is_contained(extensions, ext)) {
+      filteredAttrs.push_back(objectAttr);
+    }
+  }
+}
+
+// Tries to find |filePath| on disk either at its absolute path or joined with
+// any of the specified |searchPaths| in order.
+// Returns the absolute file path when found or a failure if there are no hits.
+static FailureOr<std::string>
+findFileInPaths(StringRef filePath, ArrayRef<std::string> searchPaths) {
+  // First try to see if it's an absolute path - we don't want to perform any
+  // additional processing on top of that.
+  if (llvm::sys::path::is_absolute(filePath)) {
+    if (llvm::sys::fs::exists(filePath))
+      return filePath.str();
+    return failure();
+  }
+
+  // Try a relative lookup from the current working directory.
+  if (llvm::sys::fs::exists(filePath))
+    return filePath.str();
+
+  // Search each path in turn for a file that exists.
+  // It doesn't mean we can open it but we'll get a better error out of the
+  // actual open attempt than what we could produce here.
+  for (auto searchPath : searchPaths) {
+    SmallVector<char> tryPath{searchPath.begin(), searchPath.end()};
+    llvm::sys::path::append(tryPath, filePath);
+    if (llvm::sys::fs::exists(Twine(tryPath)))
+      return Twine(tryPath).str();
+  }
+
+  // Not found in either the user-specified absolute path, cwd, or the search
+  // paths.
+ return failure(); +} + +static llvm::cl::list clExecutableObjectSearchPath( + "iree-hal-executable-object-search-path", + llvm::cl::desc("Additional search paths for resolving " + "#hal.executable.object file references."), + llvm::cl::ZeroOrMore); + +FailureOr ExecutableObjectAttr::getAbsolutePath() { + auto pathAttr = getPath(); + if (!pathAttr) + return failure(); // not a file reference + return findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath); +} + +std::optional ExecutableObjectAttr::loadData() { + if (auto dataAttr = getData()) { + // This is shady but so is using this feature. + // TODO(benvanik): figure out a way to limit the attribute to signless int8. + // We could share the attribute -> byte array code with the VM constant + // serialization if we wanted. + auto rawData = dataAttr.getRawData(); + return std::string(rawData.data(), rawData.size()); + } else if (auto pathAttr = getPath()) { + // Search for file and try to load it if found. + auto filePath = + findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath); + if (failed(filePath)) { + llvm::errs() + << "ERROR: referenced object file not found on any path; use " + "--iree-hal-executable-object-search-path= to add search paths: " + << *this << "\n"; + return std::nullopt; + } + auto file = llvm::MemoryBuffer::getFile(*filePath); + if (!file) + return std::nullopt; + return std::string((*file)->getBuffer()); + } + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// #hal.executable.objects<*> +//===----------------------------------------------------------------------===// + +// static +LogicalResult ExecutableObjectsAttr::verify( + function_ref emitError, ArrayAttr targetsAttr, + ArrayAttr targetObjectsAttr) { + if (targetsAttr.size() != targetObjectsAttr.size()) { + return emitError() << "targets and objects must be 1:1"; + } + for (auto targetAttr : targetsAttr) { + if (!llvm::isa(targetAttr)) { + return emitError() + << "target keys must be #hal.executable.target attributes"; + } + } + for (auto objectsAttr : targetObjectsAttr) { + auto objectsArrayAttr = llvm::dyn_cast(objectsAttr); + if (!objectsArrayAttr) { + return emitError() << "target objects must be an array of " + "#hal.executable.object attributes"; + } + } + return success(); +} + +// static +Attribute ExecutableObjectsAttr::parse(AsmParser &p, Type type) { + // `<{` target = [objects, ...], ... 
`}>` + SmallVector targetAttrs; + SmallVector objectsAttrs; + if (failed(p.parseLess())) + return {}; + if (succeeded(p.parseLBrace()) && !succeeded(p.parseOptionalRBrace())) { + do { + Attribute targetAttr; + ArrayAttr objectsAttr; + if (failed(p.parseAttribute(targetAttr)) || failed(p.parseEqual()) || + failed(p.parseAttribute(objectsAttr))) { + return {}; + } + targetAttrs.push_back(targetAttr); + objectsAttrs.push_back(objectsAttr); + } while (succeeded(p.parseOptionalComma())); + if (failed(p.parseRBrace())) + return {}; + } + if (failed(p.parseGreater())) + return {}; + return get(p.getContext(), ArrayAttr::get(p.getContext(), targetAttrs), + ArrayAttr::get(p.getContext(), objectsAttrs)); +} + +void ExecutableObjectsAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<{"; + llvm::interleaveComma(llvm::zip_equal(getTargets(), getTargetObjects()), os, + [&](std::tuple keyValue) { + p.printAttribute(std::get<0>(keyValue)); + os << " = "; + p.printAttributeWithoutType(std::get<1>(keyValue)); + }); + os << "}>"; +} + +std::optional ExecutableObjectsAttr::getApplicableObjects( + IREE::HAL::ExecutableTargetAttr specificTargetAttr) { + SmallVector allObjectAttrs; + for (auto [targetAttr, objectsAttr] : + llvm::zip_equal(getTargets(), getTargetObjects())) { + auto genericTargetAttr = + llvm::cast(targetAttr); + if (genericTargetAttr.isGenericOf(specificTargetAttr)) { + auto objectsArrayAttr = llvm::cast(objectsAttr); + allObjectAttrs.append(objectsArrayAttr.begin(), objectsArrayAttr.end()); + } + } + if (allObjectAttrs.empty()) + return std::nullopt; + return ArrayAttr::get(specificTargetAttr.getContext(), allObjectAttrs); +} + +//===----------------------------------------------------------------------===// +// #hal.affinity.queue<*> +//===----------------------------------------------------------------------===// + +// static +Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) { + int64_t mask = 0; + // `<` + if (failed(p.parseLess())) + return {}; + // `*` (any) + if (succeeded(p.parseOptionalStar())) { + mask = -1; + } else { + // `[`queue_bit[, ...] `]` + if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t i = 0; + if (failed(p.parseInteger(i))) + return failure(); + mask |= 1ll << i; + return success(); + }))) { + return {}; + } + } + // `>` + if (failed(p.parseGreater())) + return {}; + return get(p.getContext(), mask); +} + +void AffinityQueueAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<"; + int64_t mask = getMask(); + if (mask == -1) { + os << "*"; + } else { + os << "["; + for (int i = 0, j = 0; i < sizeof(mask) * 8; ++i) { + if (mask & (1ll << i)) { + if (j++ > 0) + os << ", "; + os << i; + } + } + os << "]"; + } + os << ">"; +} + +bool AffinityQueueAttr::isExecutableWith( + IREE::Stream::AffinityAttr other) const { + if (!other) + return true; + // Only compatible with other queue affinities today. When we extend the + // attributes to specify device targets we'd want to check here. + auto otherQueueAttr = llvm::dyn_cast_if_present(other); + if (!otherQueueAttr) + return false; + // If this affinity is a subset of the target affinity then it can execute + // with it. + if ((getMask() & otherQueueAttr.getMask()) == getMask()) + return true; + // Otherwise not compatible. 
+ return false; +} + +IREE::Stream::AffinityAttr +AffinityQueueAttr::joinOR(IREE::Stream::AffinityAttr other) const { + if (!other) + return *this; + if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { + return nullptr; + } + auto otherQueueAttr = llvm::dyn_cast_if_present(other); + return AffinityQueueAttr::get(getContext(), + getMask() | otherQueueAttr.getMask()); +} + +IREE::Stream::AffinityAttr +AffinityQueueAttr::joinAND(IREE::Stream::AffinityAttr other) const { + if (!other) + return *this; + if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { + return nullptr; + } + auto otherQueueAttr = llvm::dyn_cast_if_present(other); + return AffinityQueueAttr::get(getContext(), + getMask() & otherQueueAttr.getMask()); +} + +//===----------------------------------------------------------------------===// +// IREE::HAL::HALDialect +//===----------------------------------------------------------------------===// + +// At the end so it can use functions above: +#include "iree/compiler/Dialect/HAL/IR/HALAttrInterfaces.cpp.inc" + +void HALDialect::registerAttributes() { + // Register command line flags: + (void)clExecutableObjectSearchPath; + + addAttributes< +#define GET_ATTRDEF_LIST +#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep + >(); +} + +Attribute HALDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + StringRef mnemonic; + Attribute genAttr; + OptionalParseResult parseResult = + generatedAttributeParser(parser, &mnemonic, type, genAttr); + if (parseResult.has_value()) + return genAttr; + parser.emitError(parser.getNameLoc()) + << "unknown HAL attribute: " << mnemonic; + return {}; +} + +void HALDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { + TypeSwitch(attr).Default([&](Attribute) { + if (failed(generatedAttributePrinter(attr, p))) { + assert(false && "unhandled HAL attribute kind"); + } + }); +} + +} // namespace mlir::iree_compiler::IREE::HAL diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td new file mode 100644 index 000000000000..a03879b6bc14 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -0,0 +1,778 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECT_HAL_IR_HAL_ATTRS +#define IREE_DIALECT_HAL_IR_HAL_ATTRS + +include "iree/compiler/Dialect/HAL/IR/HALBase.td" +include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td" +include "iree/compiler/Dialect/Stream/IR/StreamInterfaces.td" +include "iree/compiler/Dialect/Util/IR/UtilTypes.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/EnumAttr.td" + +//===----------------------------------------------------------------------===// +// General enums +//===----------------------------------------------------------------------===// + +// Wrapper over base I32EnumAttr to set common fields for HAL enums. 
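+// (Illustrative note: HAL_DescriptorTypeAttr below is declared through this
+// wrapper and is expected to round-trip as `#hal.descriptor_type<storage_buffer>`.)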
+class HAL_I32Enum cases> + : I32EnumAttr { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} +class HAL_I32EnumAttr cases> + : EnumAttr, mnemonic> { + let assemblyFormat = "`<` $value `>`"; +} + +def HAL_MemoryModel_Unified : I32EnumAttrCase<"Unified", 0>; +def HAL_MemoryModel_Discrete : I32EnumAttrCase<"Discrete", 1>; +def HAL_MemoryModelAttr : + I32EnumAttr<"MemoryModel", "IREE HAL MemoryModel", [ + HAL_MemoryModel_Unified, + HAL_MemoryModel_Discrete, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +def HAL_MemoryType_None : I32BitEnumAttrCase<"None", 0x0000>; // ? +def HAL_MemoryType_Optimal : I32BitEnumAttrCase<"Optimal", 0x0001>; // ! +def HAL_MemoryType_HostVisible : I32BitEnumAttrCase<"HostVisible", 0x0002>; // h +def HAL_MemoryType_HostCoherent : I32BitEnumAttrCase<"HostCoherent", 0x0004>; // c +def HAL_MemoryType_HostCached : I32BitEnumAttrCase<"HostCached", 0x0008>; // C +def HAL_MemoryType_HostLocal : I32BitEnumAttrCase<"HostLocal", 0x0046>; // H +def HAL_MemoryType_DeviceVisible : I32BitEnumAttrCase<"DeviceVisible", 0x0010>; // d +def HAL_MemoryType_DeviceLocal : I32BitEnumAttrCase<"DeviceLocal", 0x0030>; // D +def HAL_MemoryTypeBitfieldAttr : + I32BitEnumAttr<"MemoryTypeBitfield", "valid MemoryType", [ + HAL_MemoryType_None, + HAL_MemoryType_Optimal, + HAL_MemoryType_HostVisible, + HAL_MemoryType_HostCoherent, + HAL_MemoryType_HostCached, + HAL_MemoryType_HostLocal, + HAL_MemoryType_DeviceVisible, + HAL_MemoryType_DeviceLocal, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_MemoryAccess_None : I32BitEnumAttrCase<"None", 0x00000000>; +def HAL_MemoryAccess_Read : I32BitEnumAttrCase<"Read", 0x00000001>; +def HAL_MemoryAccess_Write : I32BitEnumAttrCase<"Write", 0x00000002>; +def HAL_MemoryAccess_Discard : I32BitEnumAttrCase<"Discard", 0x00000004>; +def HAL_MemoryAccess_MayAlias : I32BitEnumAttrCase<"MayAlias", 0x00000008>; +def HAL_MemoryAccess_Unaligned : I32BitEnumAttrCase<"Unaligned", 0x00000010>; +def HAL_MemoryAccess_Any : I32BitEnumAttrCase<"Any", 0x00000020>; +def HAL_MemoryAccessBitfieldAttr : + I32BitEnumAttr<"MemoryAccessBitfield", "valid MemoryAccess", [ + HAL_MemoryAccess_None, + HAL_MemoryAccess_Read, + HAL_MemoryAccess_Write, + HAL_MemoryAccess_Discard, + HAL_MemoryAccess_MayAlias, + HAL_MemoryAccess_Unaligned, + HAL_MemoryAccess_Any, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_BufferUsage_None : I32BitEnumAttrCase<"None", 0x00000000>; +def HAL_BufferUsage_TransferSource : I32BitEnumAttrCase<"TransferSource", 0x00000001>; +def HAL_BufferUsage_TransferTarget : I32BitEnumAttrCase<"TransferTarget", 0x00000002>; +def HAL_BufferUsage_Transfer : I32BitEnumAttrCase<"Transfer", 0x00000003>; +def HAL_BufferUsage_DispatchIndirectParams : I32BitEnumAttrCase<"DispatchIndirectParams", 0x00000100>; +def HAL_BufferUsage_DispatchUniformRead : I32BitEnumAttrCase<"DispatchUniformRead", 0x00000200>; +def HAL_BufferUsage_DispatchStorageRead : I32BitEnumAttrCase<"DispatchStorageRead", 0x00000400>; +def HAL_BufferUsage_DispatchStorageWrite : I32BitEnumAttrCase<"DispatchStorageWrite", 0x00000800>; +def HAL_BufferUsage_DispatchStorage : I32BitEnumAttrCase<"DispatchStorage", 0x00000C00>; +def HAL_BufferUsage_DispatchImageRead : I32BitEnumAttrCase<"DispatchImageRead", 0x00001000>; +def HAL_BufferUsage_DispatchImageWrite : I32BitEnumAttrCase<"DispatchImageWrite", 0x00002000>; +def HAL_BufferUsage_DispatchImage : I32BitEnumAttrCase<"DispatchImage", 0x00003000>; +def 
HAL_BufferUsage_SharingExport : I32BitEnumAttrCase<"SharingExport", 0x00010000>; +def HAL_BufferUsage_SharingReplicate : I32BitEnumAttrCase<"SharingReplicate", 0x00020000>; +def HAL_BufferUsage_SharingConcurrent : I32BitEnumAttrCase<"SharingConcurrent", 0x00040000>; +def HAL_BufferUsage_SharingImmutable : I32BitEnumAttrCase<"SharingImmutable", 0x00080000>; +def HAL_BufferUsage_MappingScoped : I32BitEnumAttrCase<"MappingScoped", 0x01000000>; +def HAL_BufferUsage_MappingPersistent : I32BitEnumAttrCase<"MappingPersistent", 0x02000000>; +def HAL_BufferUsage_MappingOptional : I32BitEnumAttrCase<"MappingOptional", 0x04000000>; +def HAL_BufferUsage_MappingAccessRandom : I32BitEnumAttrCase<"MappingAccessRandom", 0x08000000>; +def HAL_BufferUsage_MappingAccessSequentialWrite : I32BitEnumAttrCase<"MappingAccessSequentialWrite", 0x10000000>; +def HAL_BufferUsage_Mapping : I32BitEnumAttrCase<"Mapping", 0x09000000>; +def HAL_BufferUsageBitfieldAttr : + I32BitEnumAttr<"BufferUsageBitfield", "valid BufferUsage", [ + HAL_BufferUsage_None, + HAL_BufferUsage_TransferSource, + HAL_BufferUsage_TransferTarget, + HAL_BufferUsage_Transfer, + HAL_BufferUsage_DispatchIndirectParams, + HAL_BufferUsage_DispatchUniformRead, + HAL_BufferUsage_DispatchStorageRead, + HAL_BufferUsage_DispatchStorageWrite, + HAL_BufferUsage_DispatchStorage, + HAL_BufferUsage_DispatchImageRead, + HAL_BufferUsage_DispatchImageWrite, + HAL_BufferUsage_DispatchImage, + HAL_BufferUsage_SharingExport, + HAL_BufferUsage_SharingReplicate, + HAL_BufferUsage_SharingConcurrent, + HAL_BufferUsage_SharingImmutable, + HAL_BufferUsage_MappingScoped, + HAL_BufferUsage_MappingPersistent, + HAL_BufferUsage_MappingOptional, + HAL_BufferUsage_MappingAccessRandom, + HAL_BufferUsage_MappingAccessSequentialWrite, + HAL_BufferUsage_Mapping, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CommandBufferMode_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_CommandBufferMode_OneShot : I32BitEnumAttrCase<"OneShot", 0x0001>; +def HAL_CommandBufferMode_Nested : I32BitEnumAttrCase<"Nested", 0x0002>; +def HAL_CommandBufferMode_AllowInlineExecution : I32BitEnumAttrCase<"AllowInlineExecution", 0x0010>; +def HAL_CommandBufferModeBitfieldAttr : + I32BitEnumAttr<"CommandBufferModeBitfield", "valid CommandBufferMode", [ + HAL_CommandBufferMode_None, + HAL_CommandBufferMode_OneShot, + HAL_CommandBufferMode_Nested, + HAL_CommandBufferMode_AllowInlineExecution, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CommandCategory_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_CommandCategory_Transfer : I32BitEnumAttrCase<"Transfer", 0x0001>; +def HAL_CommandCategory_Dispatch : I32BitEnumAttrCase<"Dispatch", 0x0002>; +def HAL_CommandCategoryBitfieldAttr : + I32BitEnumAttr<"CommandCategoryBitfield", "valid CommandCategory", [ + HAL_CommandCategory_None, + HAL_CommandCategory_Transfer, + HAL_CommandCategory_Dispatch, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_DescriptorType_UniformBuffer : I32EnumAttrCase<"UniformBuffer", 6, "uniform_buffer">; +def HAL_DescriptorType_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 7, "storage_buffer">; +def HAL_DescriptorTypeAttr : + HAL_I32EnumAttr<"DescriptorType", "valid DescriptorType", "descriptor_type", [ + HAL_DescriptorType_UniformBuffer, + HAL_DescriptorType_StorageBuffer, + ]>; + +def HAL_DescriptorFlags_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_DescriptorFlags_ReadOnly : I32BitEnumAttrCase<"ReadOnly", 0x0001>; +def HAL_DescriptorFlagsAttr 
: + I32BitEnumAttr<"DescriptorFlags", "valid Descriptor flags", [ + HAL_DescriptorFlags_None, + HAL_DescriptorFlags_ReadOnly, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +def HAL_DescriptorSetLayoutFlags_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_DescriptorSetLayoutFlags_Indirect : I32BitEnumAttrCase<"Indirect", 0x0001>; +def HAL_DescriptorSetLayoutFlagsAttr : + I32BitEnumAttr<"DescriptorSetLayoutFlags", "valid DescriptorSetLayout flags", [ + HAL_DescriptorSetLayoutFlags_None, + HAL_DescriptorSetLayoutFlags_Indirect, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>; +def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>; +def HAL_ExecutionStage_Dispatch : I32BitEnumAttrCase<"Dispatch", 0x0004>; +def HAL_ExecutionStage_Transfer : I32BitEnumAttrCase<"Transfer", 0x0008>; +def HAL_ExecutionStage_CommandRetire : I32BitEnumAttrCase<"CommandRetire", 0x0010>; +def HAL_ExecutionStage_Host : I32BitEnumAttrCase<"Host", 0x0020>; +def HAL_ExecutionStageBitfieldAttr : + I32BitEnumAttr<"ExecutionStageBitfield", "valid ExecutionStage", [ + HAL_ExecutionStage_None, + HAL_ExecutionStage_CommandIssue, + HAL_ExecutionStage_CommandProcess, + HAL_ExecutionStage_Dispatch, + HAL_ExecutionStage_Transfer, + HAL_ExecutionStage_CommandRetire, + HAL_ExecutionStage_Host + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_ExecutionBarrierFlag_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_ExecutionBarrierFlag_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>; +def HAL_ExecutionBarrierFlagBitfieldAttr : + I32BitEnumAttr<"ExecutionBarrierFlagBitfield", "valid ExecutionBarrierFlag", [ + HAL_ExecutionBarrierFlag_None, + HAL_ExecutionBarrierFlag_Reserved, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_FenceFlag_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_FenceFlag_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>; +def HAL_FenceFlagBitfieldAttr : + I32BitEnumAttr<"FenceFlagBitfield", "valid FenceFlag", [ + HAL_FenceFlag_None, + HAL_FenceFlag_Reserved, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_AccessScope_None : I32BitEnumAttrCase<"None", 0x0000>; +def HAL_AccessScope_IndirectCommandRead : I32BitEnumAttrCase<"IndirectCommandRead", 0x0001>; +def HAL_AccessScope_ConstantRead : I32BitEnumAttrCase<"ConstantRead", 0x0002>; +def HAL_AccessScope_DispatchRead : I32BitEnumAttrCase<"DispatchRead", 0x0004>; +def HAL_AccessScope_DispatchWrite : I32BitEnumAttrCase<"DispatchWrite", 0x0008>; +def HAL_AccessScope_TransferRead : I32BitEnumAttrCase<"TransferRead", 0x0010>; +def HAL_AccessScope_TransferWrite : I32BitEnumAttrCase<"TransferWrite", 0x0020>; +def HAL_AccessScope_HostRead : I32BitEnumAttrCase<"HostRead", 0x0040>; +def HAL_AccessScope_HostWrite : I32BitEnumAttrCase<"HostWrite", 0x0080>; +def HAL_AccessScope_MemoryRead : I32BitEnumAttrCase<"MemoryRead", 0x0100>; +def HAL_AccessScope_MemoryWrite : I32BitEnumAttrCase<"MemoryWrite", 0x0200>; +def HAL_AccessScopeBitfieldAttr : + I32BitEnumAttr<"AccessScopeBitfield", "valid AccessScope", [ + HAL_AccessScope_None, + HAL_AccessScope_IndirectCommandRead, + HAL_AccessScope_ConstantRead, + HAL_AccessScope_DispatchRead, + HAL_AccessScope_DispatchWrite, + HAL_AccessScope_TransferRead, + HAL_AccessScope_TransferWrite, + HAL_AccessScope_HostRead, + 
HAL_AccessScope_HostWrite, + HAL_AccessScope_MemoryRead, + HAL_AccessScope_MemoryWrite + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CallingConvention_Default : I32EnumAttrCase<"Default", 0>; +def HAL_CallingConvention_ParameterStruct : I32EnumAttrCase<"ParameterStruct", 1>; +def HAL_CallingConventionAttr : + I32EnumAttr< + "CallingConvention", + "Calling conversions for linked functions",[ + HAL_CallingConvention_Default, + HAL_CallingConvention_ParameterStruct, + ]>{ + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +//===----------------------------------------------------------------------===// +// #hal.collective<*> +//===----------------------------------------------------------------------===// + +def HAL_CollectiveKind_AllGather : I32EnumAttrCase<"AllGather", 0, "all_gather">; +def HAL_CollectiveKind_AllReduce : I32EnumAttrCase<"AllReduce", 1, "all_reduce">; +def HAL_CollectiveKind_AllToAll : I32EnumAttrCase<"AllToAll", 2, "all_to_all">; +def HAL_CollectiveKind_Broadcast : I32EnumAttrCase<"Broadcast", 3, "broadcast">; +def HAL_CollectiveKind_Reduce : I32EnumAttrCase<"Reduce", 4, "reduce">; +def HAL_CollectiveKind_ReduceScatter : I32EnumAttrCase<"ReduceScatter", 5, "reduce_scatter">; +def HAL_CollectiveKind_Send : I32EnumAttrCase<"Send", 6, "send">; +def HAL_CollectiveKind_Recv : I32EnumAttrCase<"Recv", 7, "recv">; +def HAL_CollectiveKind_SendRecv: I32EnumAttrCase<"SendRecv", 8, "send_recv">; +def HAL_CollectiveKindAttr : + I32EnumAttr<"CollectiveKind", "valid CollectiveKind", [ + HAL_CollectiveKind_AllGather, + HAL_CollectiveKind_AllReduce, + HAL_CollectiveKind_AllToAll, + HAL_CollectiveKind_Broadcast, + HAL_CollectiveKind_Reduce, + HAL_CollectiveKind_ReduceScatter, + HAL_CollectiveKind_Send, + HAL_CollectiveKind_Recv, + HAL_CollectiveKind_SendRecv, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CollectiveReductionOp_None : I32EnumAttrCase<"None", 0, "none">; +def HAL_CollectiveReductionOp_ReductionSum : I32EnumAttrCase<"ReductionSum", 1, "sum">; +def HAL_CollectiveReductionOp_ReductionProduct : I32EnumAttrCase<"ReductionProduct", 2, "product">; +def HAL_CollectiveReductionOp_ReductionMinimum : I32EnumAttrCase<"ReductionMinimum", 3, "minimum">; +def HAL_CollectiveReductionOp_ReductionMaximum : I32EnumAttrCase<"ReductionMaximum", 4, "maximum">; +def HAL_CollectiveReductionOp_ReductionAverage : I32EnumAttrCase<"ReductionAverage", 5, "average">; +def HAL_CollectiveReductionOpAttr : + I32EnumAttr<"CollectiveReductionOp", "valid CollectiveReductionOp", [ + HAL_CollectiveReductionOp_None, + HAL_CollectiveReductionOp_ReductionSum, + HAL_CollectiveReductionOp_ReductionProduct, + HAL_CollectiveReductionOp_ReductionMinimum, + HAL_CollectiveReductionOp_ReductionMaximum, + HAL_CollectiveReductionOp_ReductionAverage, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CollectiveElementType_Sint8 : I32EnumAttrCase<"Sint8", 0, "si8">; +def HAL_CollectiveElementType_Uint8 : I32EnumAttrCase<"Uint8", 1, "ui8">; +def HAL_CollectiveElementType_Sint16 : I32EnumAttrCase<"Sint16", 2, "si16">; +def HAL_CollectiveElementType_Uint16 : I32EnumAttrCase<"Uint16", 3, "ui16">; +def HAL_CollectiveElementType_Sint32 : I32EnumAttrCase<"Sint32", 4, "si32">; +def HAL_CollectiveElementType_Uint32 : I32EnumAttrCase<"Uint32", 5, "ui32">; +def HAL_CollectiveElementType_Sint64 : I32EnumAttrCase<"Sint64", 6, "si64">; +def HAL_CollectiveElementType_Uint64 : I32EnumAttrCase<"Uint64", 7, "ui64">; +def 
HAL_CollectiveElementType_Float16 : I32EnumAttrCase<"Float16", 8, "f16">; +def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">; +def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">; +def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">; +def HAL_CollectiveElementTypeAttr : + I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [ + HAL_CollectiveElementType_Sint8, + HAL_CollectiveElementType_Uint8, + HAL_CollectiveElementType_Sint16, + HAL_CollectiveElementType_Uint16, + HAL_CollectiveElementType_Sint32, + HAL_CollectiveElementType_Uint32, + HAL_CollectiveElementType_Sint64, + HAL_CollectiveElementType_Uint64, + HAL_CollectiveElementType_Float16, + HAL_CollectiveElementType_Float32, + HAL_CollectiveElementType_Float64, + HAL_CollectiveElementType_BFloat16, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + +def HAL_CollectiveAttr : + AttrDef { + let mnemonic = "collective"; + let summary = [{collective operation and specification}]; + let description = [{ + Specifies the collective operation to perform and any mode bits required. + }]; + let parameters = (ins + AttrParameter<"CollectiveKind", "">:$kind, + OptionalParameter<"std::optional">:$reduction, + AttrParameter<"CollectiveElementType", "">:$element_type + ); + let assemblyFormat = [{ + `<` $kind (`with` $reduction^)? `:` $element_type `>` + }]; + let extraClassDeclaration = [{ + // Returns the runtime encoding of the collective attribute. + uint32_t getEncodedValue() const; + }]; +} + +//===----------------------------------------------------------------------===// +// hal.descriptor_set.binding<*> +//===----------------------------------------------------------------------===// + +def HAL_DescriptorSetBindingAttr : + AttrDef { + let mnemonic = "descriptor_set.binding"; + let summary = [{descriptor set binding specification}]; + let description = [{ + Specifies a single binding within a descriptor set layout. + }]; + let parameters = (ins + AttrParameter<"int64_t", "">:$ordinal, + AttrParameter<"DescriptorType", "">:$type, + OptionalParameter<"std::optional">:$flags + ); + let assemblyFormat = [{ + `<` $ordinal `,` $type (`,` $flags^)? `>` + }]; +} + +def HAL_DescriptorSetLayoutBindingArrayAttr : + TypedArrayAttrBase; + +//===----------------------------------------------------------------------===// +// hal.descriptor_set.layout<*> +//===----------------------------------------------------------------------===// + +def HAL_DescriptorSetLayoutAttr : + AttrDef { + let mnemonic = "descriptor_set.layout"; + let summary = [{descriptor set layout specification}]; + let description = [{ + Specifies the layout information of a single set of descriptors used within + an pipeline layout. Multiple of these sets may be used by a single entry + point to allow for bindings with similar update frequencies to be grouped. + }]; + let parameters = (ins + AttrParameter<"int64_t", "">:$ordinal, + ArrayRefParameter<"DescriptorSetBindingAttr", "">:$bindings, + OptionalParameter<"std::optional">:$flags + ); + let assemblyFormat = [{ + `<` + $ordinal `,` + `bindings` `=` `[` $bindings `]` + (`,` `flags` `=` $flags^)? 
+ `>` + }]; +} + +//===----------------------------------------------------------------------===// +// hal.pipeline.layout<*> +//===----------------------------------------------------------------------===// + +def HAL_PipelineLayoutAttr : + AttrDef { + let mnemonic = "pipeline.layout"; + let summary = [{executable entry point layout specification}]; + let description = [{ + Specifies the layout information used for interacting with executable + functions. This allows host code to correctly map parameters to the + lower-level target-specific argument passing behavior. + }]; + let parameters = (ins + AttrParameter<"int64_t", "">:$pushConstants, + ArrayRefParameter<"DescriptorSetLayoutAttr", "">:$setLayouts + ); + let assemblyFormat = [{ + `<` + `push_constants` `=` $pushConstants `,` + `sets` `=` `[` $setLayouts `]` + `>` + }]; +} + +//===----------------------------------------------------------------------===// +// hal.interface.binding<*> +//===----------------------------------------------------------------------===// + +def HAL_InterfaceBindingAttr : + AttrDef { + let mnemonic = "interface.binding"; + let summary = [{interface binding specification}]; + let description = [{ + Specifies the descriptor set and binding ordinal of a particular layout + binding. + + Example: + ```mlir + #hal.interface.binding<0, 1> + ``` + }]; + let parameters = (ins + AttrParameter<"int64_t", "">:$set, + AttrParameter<"int64_t", "">:$binding + ); + let assemblyFormat = [{ + `<` $set `,` $binding `>` + }]; +} + +def HAL_InterfaceBindingArrayAttr : + TypedArrayAttrBase; + +//===----------------------------------------------------------------------===// +// #hal.device.target<*> +//===----------------------------------------------------------------------===// + +def HAL_DeviceTargetAttr : + AttrDef { + let mnemonic = "device.target"; + let summary = [{generic device target specification}]; + let description = [{ + Specifies the properties of a target runtime device. + Target devices are specified with a canonical identifier matching those used + by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support + several target executable formats specified with `#hal.executable.target`. + An optional configuration dictionary allows for overriding backend defaults. + + Example: + ```mlir + #hal.device.target<"llvm-cpu", { + executable_targets = [ + #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">, + #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">, + ] + }> + ``` + }]; + let parameters = (ins + AttrParameter<"StringAttr", "">:$deviceID, + AttrParameter<"DictionaryAttr", "">:$configuration + ); + let builders = [ + AttrBuilder<(ins "StringRef":$deviceID)>, + ]; + + let extraClassDeclaration = [{ + // Returns a symbol-compatible name that pseudo-uniquely identifies this + // target. Callers must perform deduplication when required. + std::string getSymbolNameFragment(); + + // Returns true if there's an attribute with the given name in the + // configuration dictionary. + bool hasConfigurationAttr(StringRef name); + + // Returns zero or more executable targets that this device supports. + SmallVector getExecutableTargets(); + + // Returns a list of target devices that may be active for the given + // operation. This will recursively walk parent operations until one with + // the `hal.device.targets` attribute is found. + static SmallVector lookup(Operation *op); + + // Returns true if there is any UnitAttr with |name| in any device + // configuration for the given |op|. 
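+    // (Illustrative note: when one target's config defines the unit attr and
+    // another's does not, the "Any" query below returns true while the "All"
+    // query returns false.)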
+ static bool lookupConfigAttrAny(Operation *op, StringRef name); + + // Returns true if all device configurations found for the given |op| have + // a UnitAttr with |name|. + static bool lookupConfigAttrAll(Operation *op, StringRef name); + + // Returns the AND of boolean attributes of |name| in all device + // configurations found for the given |op|. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + static std::optional + lookupConfigAttrAnd(Operation *op, StringRef name); + + // Returns the OR of boolean attributes of |name| in all device + // configurations found for the given |op|. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + static std::optional + lookupConfigAttrOr(Operation *op, StringRef name); + + // Returns the range of integer attributes of |name| in all device + // configurations found for the given |op|. + // Returns nullopt if any config does not have the key defined indicating + // that it's not statically known/runtime dynamic. + static std::optional> + lookupConfigAttrRange(Operation *op, StringRef name); + + // Returns a list of all target executable configurations that may be + // required for the given operation. + static SmallVector + lookupExecutableTargets(Operation *op); + }]; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// #hal.executable.target<*> +//===----------------------------------------------------------------------===// + +def HAL_ExecutableTargetAttr : + AttrDef { + let mnemonic = "executable.target"; + let summary = [{generic executable target specification}]; + let description = [{ + Specifies how to compile an executable for a specific target backend. + A backend is used to translate and serialize the executable into the final + form passed to the runtime. The format of the executable is a + target-specific value indicating the required runtime support to load the + deployed artifact. An optionally provided configuration dictionary overrides + backend-specific defaults. + + Example: + ```mlir + // Produce a system-native ELF for x86-64 systems using the LLVM backend: + #hal.executable.target<"llvm-cpu", "system-elf-x86_64", { + triple = "x86_64-unknown-linux-elf", + cpu = "host", + cpu_features = "host", + abi = "lp32", + ... + }> + ``` + + The same compilation backend may be used to translate executables for + several different runtime devices. Likewise the same runtime device may use + one of many different executable targets. Assume an N:M mapping between the + two in all cases. + }]; + + let parameters = (ins + AttrParameter<"StringAttr", "">:$backend, + AttrParameter<"StringAttr", "">:$format, + AttrParameter<"DictionaryAttr", "">:$configuration + ); + + let builders = [ + AttrBuilder<(ins "StringRef":$backend, "StringRef":$format)>, + ]; + + let extraClassDeclaration = [{ + // Returns a symbol-compatible name that pseudo-uniquely identifies this + // target. Callers must perform deduplication when required. + std::string getSymbolNameFragment() const; + + // Returns true if there's an attribute with the given name in the + // configuration dictionary. + bool hasConfigurationAttr(StringRef name); + + // Returns true if this attribute is a generic version of |specificAttr|. + // A more generic version will match with many specific versions. 
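+    // (Illustrative note: a target with an empty or absent configuration is
+    // generic of any target that shares the same backend and format strings.)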
+ bool isGenericOf(IREE::HAL::ExecutableTargetAttr specificAttr); + + // Returns the executable target configuration for the given operation. + // This will recursively walk parent operations until one with the + // `hal.executable.target` attribute is found or a `hal.executable.variant` + // specifies a value. Returns nullptr if no target specification can be found. + static IREE::HAL::ExecutableTargetAttr lookup(Operation *op); + }]; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// #hal.executable.object<*> +//===----------------------------------------------------------------------===// + +def HAL_ExecutableObjectAttr : AttrDef { + let mnemonic = "executable.object"; + let summary = [{object file reference}]; + let description = [{ + Defines an object file that can be linked into executables. + Today this is only supported for external file references with paths the + compiler can successfully resolve from its current working directory. + Inlined data can optionally be provided to avoid the need for file system + access and ensure the data source is attached to the IR as it makes its way + through multiple compiler stages or reproducers. + + Future revisions may change this to an interface that allows both internal + and external resources to define the object contents. Linking needs to be + updated to support various object compositions and certain backends may + require additional infrastructure support. + + In the long term the goal is to allow combinations of declared objects and + generated code in order to give control of linking behavior to frontends. + Instead of needing global command line flags to link in additional blobs + the frontend can emit executables with the dependencies already defined per + variant without needing to reach into the IREE compiler code. + + Example: + ```mlir + #hal.executable.object<{path = "some/file.obj"}> + #hal.executable.object<{ + path = "some/embedded/file.obj", + data = dense<[...]> : vector<2048xi8> + }> + ``` + }]; + + let parameters = (ins + AttrParameter<"StringAttr", "">:$path, + OptionalParameter<"DenseIntElementsAttr", "">:$data + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + // Returns a list of all objects with a path matching one of the provided + // file extensions. + static void filterObjects( + ArrayAttr objectAttrs, ArrayRef extensions, + SmallVectorImpl &filteredAttrs); + + // Returns the absolute path of the referenced object file if it exists. + FailureOr getAbsolutePath(); + + // Returns the contents of the object file or None if loading failed. + // TODO(benvanik): better return type to support mapping/etc? eh + std::optional loadData(); + }]; +} + +def HAL_ExecutableObjectArrayAttr : + TypedArrayAttrBase; + +//===----------------------------------------------------------------------===// +// #hal.executable.objects<*> +//===----------------------------------------------------------------------===// + +def HAL_ExecutableObjectsAttr : AttrDef { + let mnemonic = "executable.objects"; + let summary = [{target-specific object file references}]; + let description = [{ + A dictionary mapping executable target specifications to a list of objects. + This is used to allow layers of the stack that support multi-targeting to + specify information used during lowering into each particular target. 
+ + The key attributes are matched against each target variant based on the + backend and format as well as any configuration data provided. When + comparing the configuration only fields present in both the key and + target variant will be checked and must match. This allows specification of + generic sets ("all x86_64 targets get these objects") as well as specific + ones ("only x86_64 targets with vector_size = 64 get these objects"). + + Example: + ```mlir + #hal.executable.objects<{ + #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64"> = [ + #hal.executable.object<{path = "some/file_arm_64.obj"}> + ], + #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> = [ + #hal.executable.object<{path = "some/file_x86_64.obj"}> + ] + }> + ``` + }]; + + let parameters = (ins + AttrParameter<"ArrayAttr", "">:$targets, + AttrParameter<"ArrayAttr", "">:$targetObjects + ); + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + // Returns the objects specified for the given generic target. + std::optional getApplicableObjects( + IREE::HAL::ExecutableTargetAttr specificTargetAttr); + }]; +} + +//===----------------------------------------------------------------------===// +// #hal.affinity.queue<*> +//===----------------------------------------------------------------------===// + +def HAL_AffinityQueueAttr : AttrDef, +]> { + let mnemonic = "affinity.queue"; + let summary = [{specifies a set of allowed queues for an operation}]; + let description = [{ + WIP; see [#10765](https://github.com/openxla/iree/issues/10765). + This may change in the future to either be a nested attribute on a larger + affinity struct or be defined by an implementation of the affinity attr + interface. For now this allows higher levels of the stack to specify + queues such that the stream dialect can understand them and they can be + lowered into the HAL dialect. + + Specifies that an annotated operation or scope is only allowed to execute on + the set of queues (0-64) provided. Operations will not run on other queues. + + Example: + ```mlir + // any queue + #hal.affinity.queue<*> + // queues 4 and 5 + #hal.affinity.queue<[4, 5]> + ``` + }]; + + let parameters = (ins + AttrParameter<"int64_t", "">:$mask + ); + + let hasCustomAssemblyFormat = 1; +} + +#endif // IREE_DIALECT_HAL_IR_HAL_ATTRS diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td index 0aa23c4790f5..a705528fae24 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -14,330 +14,6 @@ include "iree/compiler/Dialect/Util/IR/UtilTypes.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" -//===----------------------------------------------------------------------===// -// HAL enums -//===----------------------------------------------------------------------===// - -// Wrapper over base I32EnumAttr to set common fields for HAL enums. 
-class HAL_I32Enum cases> - : I32EnumAttr { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} -class HAL_I32EnumAttr cases> - : EnumAttr, mnemonic> { - let assemblyFormat = "`<` $value `>`"; -} - -def HAL_MemoryModel_Unified : I32EnumAttrCase<"Unified", 0>; -def HAL_MemoryModel_Discrete : I32EnumAttrCase<"Discrete", 1>; -def HAL_MemoryModelAttr : - I32EnumAttr<"MemoryModel", "IREE HAL MemoryModel", [ - HAL_MemoryModel_Unified, - HAL_MemoryModel_Discrete, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - -def HAL_MemoryType_None : I32BitEnumAttrCase<"None", 0x0000>; // ? -def HAL_MemoryType_Optimal : I32BitEnumAttrCase<"Optimal", 0x0001>; // ! -def HAL_MemoryType_HostVisible : I32BitEnumAttrCase<"HostVisible", 0x0002>; // h -def HAL_MemoryType_HostCoherent : I32BitEnumAttrCase<"HostCoherent", 0x0004>; // c -def HAL_MemoryType_HostCached : I32BitEnumAttrCase<"HostCached", 0x0008>; // C -def HAL_MemoryType_HostLocal : I32BitEnumAttrCase<"HostLocal", 0x0046>; // H -def HAL_MemoryType_DeviceVisible : I32BitEnumAttrCase<"DeviceVisible", 0x0010>; // d -def HAL_MemoryType_DeviceLocal : I32BitEnumAttrCase<"DeviceLocal", 0x0030>; // D -def HAL_MemoryTypeBitfieldAttr : - I32BitEnumAttr<"MemoryTypeBitfield", "valid MemoryType", [ - HAL_MemoryType_None, - HAL_MemoryType_Optimal, - HAL_MemoryType_HostVisible, - HAL_MemoryType_HostCoherent, - HAL_MemoryType_HostCached, - HAL_MemoryType_HostLocal, - HAL_MemoryType_DeviceVisible, - HAL_MemoryType_DeviceLocal, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_MemoryAccess_None : I32BitEnumAttrCase<"None", 0x00000000>; -def HAL_MemoryAccess_Read : I32BitEnumAttrCase<"Read", 0x00000001>; -def HAL_MemoryAccess_Write : I32BitEnumAttrCase<"Write", 0x00000002>; -def HAL_MemoryAccess_Discard : I32BitEnumAttrCase<"Discard", 0x00000004>; -def HAL_MemoryAccess_MayAlias : I32BitEnumAttrCase<"MayAlias", 0x00000008>; -def HAL_MemoryAccess_Unaligned : I32BitEnumAttrCase<"Unaligned", 0x00000010>; -def HAL_MemoryAccess_Any : I32BitEnumAttrCase<"Any", 0x00000020>; -def HAL_MemoryAccessBitfieldAttr : - I32BitEnumAttr<"MemoryAccessBitfield", "valid MemoryAccess", [ - HAL_MemoryAccess_None, - HAL_MemoryAccess_Read, - HAL_MemoryAccess_Write, - HAL_MemoryAccess_Discard, - HAL_MemoryAccess_MayAlias, - HAL_MemoryAccess_Unaligned, - HAL_MemoryAccess_Any, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_BufferUsage_None : I32BitEnumAttrCase<"None", 0x00000000>; -def HAL_BufferUsage_TransferSource : I32BitEnumAttrCase<"TransferSource", 0x00000001>; -def HAL_BufferUsage_TransferTarget : I32BitEnumAttrCase<"TransferTarget", 0x00000002>; -def HAL_BufferUsage_Transfer : I32BitEnumAttrCase<"Transfer", 0x00000003>; -def HAL_BufferUsage_DispatchIndirectParams : I32BitEnumAttrCase<"DispatchIndirectParams", 0x00000100>; -def HAL_BufferUsage_DispatchUniformRead : I32BitEnumAttrCase<"DispatchUniformRead", 0x00000200>; -def HAL_BufferUsage_DispatchStorageRead : I32BitEnumAttrCase<"DispatchStorageRead", 0x00000400>; -def HAL_BufferUsage_DispatchStorageWrite : I32BitEnumAttrCase<"DispatchStorageWrite", 0x00000800>; -def HAL_BufferUsage_DispatchStorage : I32BitEnumAttrCase<"DispatchStorage", 0x00000C00>; -def HAL_BufferUsage_DispatchImageRead : I32BitEnumAttrCase<"DispatchImageRead", 0x00001000>; -def HAL_BufferUsage_DispatchImageWrite : I32BitEnumAttrCase<"DispatchImageWrite", 0x00002000>; -def HAL_BufferUsage_DispatchImage : I32BitEnumAttrCase<"DispatchImage", 0x00003000>; -def 
HAL_BufferUsage_SharingExport : I32BitEnumAttrCase<"SharingExport", 0x00010000>; -def HAL_BufferUsage_SharingReplicate : I32BitEnumAttrCase<"SharingReplicate", 0x00020000>; -def HAL_BufferUsage_SharingConcurrent : I32BitEnumAttrCase<"SharingConcurrent", 0x00040000>; -def HAL_BufferUsage_SharingImmutable : I32BitEnumAttrCase<"SharingImmutable", 0x00080000>; -def HAL_BufferUsage_MappingScoped : I32BitEnumAttrCase<"MappingScoped", 0x01000000>; -def HAL_BufferUsage_MappingPersistent : I32BitEnumAttrCase<"MappingPersistent", 0x02000000>; -def HAL_BufferUsage_MappingOptional : I32BitEnumAttrCase<"MappingOptional", 0x04000000>; -def HAL_BufferUsage_MappingAccessRandom : I32BitEnumAttrCase<"MappingAccessRandom", 0x08000000>; -def HAL_BufferUsage_MappingAccessSequentialWrite : I32BitEnumAttrCase<"MappingAccessSequentialWrite", 0x10000000>; -def HAL_BufferUsage_Mapping : I32BitEnumAttrCase<"Mapping", 0x09000000>; -def HAL_BufferUsageBitfieldAttr : - I32BitEnumAttr<"BufferUsageBitfield", "valid BufferUsage", [ - HAL_BufferUsage_None, - HAL_BufferUsage_TransferSource, - HAL_BufferUsage_TransferTarget, - HAL_BufferUsage_Transfer, - HAL_BufferUsage_DispatchIndirectParams, - HAL_BufferUsage_DispatchUniformRead, - HAL_BufferUsage_DispatchStorageRead, - HAL_BufferUsage_DispatchStorageWrite, - HAL_BufferUsage_DispatchStorage, - HAL_BufferUsage_DispatchImageRead, - HAL_BufferUsage_DispatchImageWrite, - HAL_BufferUsage_DispatchImage, - HAL_BufferUsage_SharingExport, - HAL_BufferUsage_SharingReplicate, - HAL_BufferUsage_SharingConcurrent, - HAL_BufferUsage_SharingImmutable, - HAL_BufferUsage_MappingScoped, - HAL_BufferUsage_MappingPersistent, - HAL_BufferUsage_MappingOptional, - HAL_BufferUsage_MappingAccessRandom, - HAL_BufferUsage_MappingAccessSequentialWrite, - HAL_BufferUsage_Mapping, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CommandBufferMode_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_CommandBufferMode_OneShot : I32BitEnumAttrCase<"OneShot", 0x0001>; -def HAL_CommandBufferMode_Nested : I32BitEnumAttrCase<"Nested", 0x0002>; -def HAL_CommandBufferMode_AllowInlineExecution : I32BitEnumAttrCase<"AllowInlineExecution", 0x0010>; -def HAL_CommandBufferModeBitfieldAttr : - I32BitEnumAttr<"CommandBufferModeBitfield", "valid CommandBufferMode", [ - HAL_CommandBufferMode_None, - HAL_CommandBufferMode_OneShot, - HAL_CommandBufferMode_Nested, - HAL_CommandBufferMode_AllowInlineExecution, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CommandCategory_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_CommandCategory_Transfer : I32BitEnumAttrCase<"Transfer", 0x0001>; -def HAL_CommandCategory_Dispatch : I32BitEnumAttrCase<"Dispatch", 0x0002>; -def HAL_CommandCategoryBitfieldAttr : - I32BitEnumAttr<"CommandCategoryBitfield", "valid CommandCategory", [ - HAL_CommandCategory_None, - HAL_CommandCategory_Transfer, - HAL_CommandCategory_Dispatch, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_DescriptorType_UniformBuffer : I32EnumAttrCase<"UniformBuffer", 6, "uniform_buffer">; -def HAL_DescriptorType_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 7, "storage_buffer">; -def HAL_DescriptorTypeAttr : - HAL_I32EnumAttr<"DescriptorType", "valid DescriptorType", "descriptor_type", [ - HAL_DescriptorType_UniformBuffer, - HAL_DescriptorType_StorageBuffer, - ]>; - -def HAL_DescriptorFlags_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_DescriptorFlags_ReadOnly : I32BitEnumAttrCase<"ReadOnly", 0x0001>; -def HAL_DescriptorFlagsAttr 
: - I32BitEnumAttr<"DescriptorFlags", "valid Descriptor flags", [ - HAL_DescriptorFlags_None, - HAL_DescriptorFlags_ReadOnly, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - -def HAL_DescriptorSetLayoutFlags_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_DescriptorSetLayoutFlags_Indirect : I32BitEnumAttrCase<"Indirect", 0x0001>; -def HAL_DescriptorSetLayoutFlagsAttr : - I32BitEnumAttr<"DescriptorSetLayoutFlags", "valid DescriptorSetLayout flags", [ - HAL_DescriptorSetLayoutFlags_None, - HAL_DescriptorSetLayoutFlags_Indirect, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - -def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>; -def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>; -def HAL_ExecutionStage_Dispatch : I32BitEnumAttrCase<"Dispatch", 0x0004>; -def HAL_ExecutionStage_Transfer : I32BitEnumAttrCase<"Transfer", 0x0008>; -def HAL_ExecutionStage_CommandRetire : I32BitEnumAttrCase<"CommandRetire", 0x0010>; -def HAL_ExecutionStage_Host : I32BitEnumAttrCase<"Host", 0x0020>; -def HAL_ExecutionStageBitfieldAttr : - I32BitEnumAttr<"ExecutionStageBitfield", "valid ExecutionStage", [ - HAL_ExecutionStage_None, - HAL_ExecutionStage_CommandIssue, - HAL_ExecutionStage_CommandProcess, - HAL_ExecutionStage_Dispatch, - HAL_ExecutionStage_Transfer, - HAL_ExecutionStage_CommandRetire, - HAL_ExecutionStage_Host - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_ExecutionBarrierFlag_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_ExecutionBarrierFlag_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>; -def HAL_ExecutionBarrierFlagBitfieldAttr : - I32BitEnumAttr<"ExecutionBarrierFlagBitfield", "valid ExecutionBarrierFlag", [ - HAL_ExecutionBarrierFlag_None, - HAL_ExecutionBarrierFlag_Reserved, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_FenceFlag_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_FenceFlag_Reserved : I32BitEnumAttrCase<"Reserved", 0x0001>; -def HAL_FenceFlagBitfieldAttr : - I32BitEnumAttr<"FenceFlagBitfield", "valid FenceFlag", [ - HAL_FenceFlag_None, - HAL_FenceFlag_Reserved, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_AccessScope_None : I32BitEnumAttrCase<"None", 0x0000>; -def HAL_AccessScope_IndirectCommandRead : I32BitEnumAttrCase<"IndirectCommandRead", 0x0001>; -def HAL_AccessScope_ConstantRead : I32BitEnumAttrCase<"ConstantRead", 0x0002>; -def HAL_AccessScope_DispatchRead : I32BitEnumAttrCase<"DispatchRead", 0x0004>; -def HAL_AccessScope_DispatchWrite : I32BitEnumAttrCase<"DispatchWrite", 0x0008>; -def HAL_AccessScope_TransferRead : I32BitEnumAttrCase<"TransferRead", 0x0010>; -def HAL_AccessScope_TransferWrite : I32BitEnumAttrCase<"TransferWrite", 0x0020>; -def HAL_AccessScope_HostRead : I32BitEnumAttrCase<"HostRead", 0x0040>; -def HAL_AccessScope_HostWrite : I32BitEnumAttrCase<"HostWrite", 0x0080>; -def HAL_AccessScope_MemoryRead : I32BitEnumAttrCase<"MemoryRead", 0x0100>; -def HAL_AccessScope_MemoryWrite : I32BitEnumAttrCase<"MemoryWrite", 0x0200>; -def HAL_AccessScopeBitfieldAttr : - I32BitEnumAttr<"AccessScopeBitfield", "valid AccessScope", [ - HAL_AccessScope_None, - HAL_AccessScope_IndirectCommandRead, - HAL_AccessScope_ConstantRead, - HAL_AccessScope_DispatchRead, - HAL_AccessScope_DispatchWrite, - HAL_AccessScope_TransferRead, - HAL_AccessScope_TransferWrite, - HAL_AccessScope_HostRead, - 
HAL_AccessScope_HostWrite, - HAL_AccessScope_MemoryRead, - HAL_AccessScope_MemoryWrite - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CollectiveKind_AllGather : I32EnumAttrCase<"AllGather", 0, "all_gather">; -def HAL_CollectiveKind_AllReduce : I32EnumAttrCase<"AllReduce", 1, "all_reduce">; -def HAL_CollectiveKind_AllToAll : I32EnumAttrCase<"AllToAll", 2, "all_to_all">; -def HAL_CollectiveKind_Broadcast : I32EnumAttrCase<"Broadcast", 3, "broadcast">; -def HAL_CollectiveKind_Reduce : I32EnumAttrCase<"Reduce", 4, "reduce">; -def HAL_CollectiveKind_ReduceScatter : I32EnumAttrCase<"ReduceScatter", 5, "reduce_scatter">; -def HAL_CollectiveKind_Send : I32EnumAttrCase<"Send", 6, "send">; -def HAL_CollectiveKind_Recv : I32EnumAttrCase<"Recv", 7, "recv">; -def HAL_CollectiveKind_SendRecv: I32EnumAttrCase<"SendRecv", 8, "send_recv">; -def HAL_CollectiveKindAttr : - I32EnumAttr<"CollectiveKind", "valid CollectiveKind", [ - HAL_CollectiveKind_AllGather, - HAL_CollectiveKind_AllReduce, - HAL_CollectiveKind_AllToAll, - HAL_CollectiveKind_Broadcast, - HAL_CollectiveKind_Reduce, - HAL_CollectiveKind_ReduceScatter, - HAL_CollectiveKind_Send, - HAL_CollectiveKind_Recv, - HAL_CollectiveKind_SendRecv, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CollectiveReductionOp_None : I32EnumAttrCase<"None", 0, "none">; -def HAL_CollectiveReductionOp_ReductionSum : I32EnumAttrCase<"ReductionSum", 1, "sum">; -def HAL_CollectiveReductionOp_ReductionProduct : I32EnumAttrCase<"ReductionProduct", 2, "product">; -def HAL_CollectiveReductionOp_ReductionMinimum : I32EnumAttrCase<"ReductionMinimum", 3, "minimum">; -def HAL_CollectiveReductionOp_ReductionMaximum : I32EnumAttrCase<"ReductionMaximum", 4, "maximum">; -def HAL_CollectiveReductionOp_ReductionAverage : I32EnumAttrCase<"ReductionAverage", 5, "average">; -def HAL_CollectiveReductionOpAttr : - I32EnumAttr<"CollectiveReductionOp", "valid CollectiveReductionOp", [ - HAL_CollectiveReductionOp_None, - HAL_CollectiveReductionOp_ReductionSum, - HAL_CollectiveReductionOp_ReductionProduct, - HAL_CollectiveReductionOp_ReductionMinimum, - HAL_CollectiveReductionOp_ReductionMaximum, - HAL_CollectiveReductionOp_ReductionAverage, - ]> { - let cppNamespace = "mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CollectiveElementType_Sint8 : I32EnumAttrCase<"Sint8", 0, "si8">; -def HAL_CollectiveElementType_Uint8 : I32EnumAttrCase<"Uint8", 1, "ui8">; -def HAL_CollectiveElementType_Sint16 : I32EnumAttrCase<"Sint16", 2, "si16">; -def HAL_CollectiveElementType_Uint16 : I32EnumAttrCase<"Uint16", 3, "ui16">; -def HAL_CollectiveElementType_Sint32 : I32EnumAttrCase<"Sint32", 4, "si32">; -def HAL_CollectiveElementType_Uint32 : I32EnumAttrCase<"Uint32", 5, "ui32">; -def HAL_CollectiveElementType_Sint64 : I32EnumAttrCase<"Sint64", 6, "si64">; -def HAL_CollectiveElementType_Uint64 : I32EnumAttrCase<"Uint64", 7, "ui64">; -def HAL_CollectiveElementType_Float16 : I32EnumAttrCase<"Float16", 8, "f16">; -def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">; -def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">; -def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">; -def HAL_CollectiveElementTypeAttr : - I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [ - HAL_CollectiveElementType_Sint8, - HAL_CollectiveElementType_Uint8, - HAL_CollectiveElementType_Sint16, - HAL_CollectiveElementType_Uint16, - HAL_CollectiveElementType_Sint32, - 
HAL_CollectiveElementType_Uint32, - HAL_CollectiveElementType_Sint64, - HAL_CollectiveElementType_Uint64, - HAL_CollectiveElementType_Float16, - HAL_CollectiveElementType_Float32, - HAL_CollectiveElementType_Float64, - HAL_CollectiveElementType_BFloat16, - ]> { - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - -def HAL_CallingConvention_Default : I32EnumAttrCase<"Default", 0>; -def HAL_CallingConvention_ParameterStruct : I32EnumAttrCase<"ParameterStruct", 1>; -def HAL_CallingConventionAttr : - I32EnumAttr< - "CallingConvention", - "Calling conversions for linked functions",[ - HAL_CallingConvention_Default, - HAL_CallingConvention_ParameterStruct, - ]>{ - let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; -} - //===----------------------------------------------------------------------===// // HAL types //===----------------------------------------------------------------------===// @@ -558,418 +234,6 @@ def HAL_DurationMillisAttr : SignlessIntElementsAttr<32> { // }]; } -//===----------------------------------------------------------------------===// -// HAL structs -//===----------------------------------------------------------------------===// - -def HAL_CollectiveAttr : - AttrDef { - let mnemonic = "collective"; - let summary = [{collective operation and specification}]; - let description = [{ - Specifies the collective operation to perform and any mode bits required. - }]; - let parameters = (ins - AttrParameter<"CollectiveKind", "">:$kind, - OptionalParameter<"std::optional">:$reduction, - AttrParameter<"CollectiveElementType", "">:$element_type - ); - let assemblyFormat = [{ - `<` $kind (`with` $reduction^)? `:` $element_type `>` - }]; - let extraClassDeclaration = [{ - // Returns the runtime encoding of the collective attribute. - uint32_t getEncodedValue() const; - }]; -} - -def HAL_DescriptorSetBindingAttr : - AttrDef { - let mnemonic = "descriptor_set.binding"; - let summary = [{descriptor set binding specification}]; - let description = [{ - Specifies a single binding within a descriptor set layout. - }]; - let parameters = (ins - AttrParameter<"int64_t", "">:$ordinal, - AttrParameter<"DescriptorType", "">:$type, - OptionalParameter<"std::optional">:$flags - ); - let assemblyFormat = [{ - `<` $ordinal `,` $type (`,` $flags^)? `>` - }]; -} - -def HAL_DescriptorSetLayoutBindingArrayAttr : - TypedArrayAttrBase; - -def HAL_DescriptorSetLayoutAttr : - AttrDef { - let mnemonic = "descriptor_set.layout"; - let summary = [{descriptor set layout specification}]; - let description = [{ - Specifies the layout information of a single set of descriptors used within - an pipeline layout. Multiple of these sets may be used by a single entry - point to allow for bindings with similar update frequencies to be grouped. - }]; - let parameters = (ins - AttrParameter<"int64_t", "">:$ordinal, - ArrayRefParameter<"DescriptorSetBindingAttr", "">:$bindings, - OptionalParameter<"std::optional">:$flags - ); - let assemblyFormat = [{ - `<` - $ordinal `,` - `bindings` `=` `[` $bindings `]` - (`,` `flags` `=` $flags^)? - `>` - }]; -} - -def HAL_PipelineLayoutAttr : - AttrDef { - let mnemonic = "pipeline.layout"; - let summary = [{executable entry point layout specification}]; - let description = [{ - Specifies the layout information used for interacting with executable - functions. This allows host code to correctly map parameters to the - lower-level target-specific argument passing behavior. 
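The `#hal.collective` attribute above declares `getEncodedValue()`, which packs the kind, optional reduction, and element type into the single `uint32_t` the runtime consumes as `iree_hal_collective_op_t` (one byte per field, reserved byte last). A minimal standalone sketch of that layout, using case values taken from the enum definitions above; the helper name is illustrative and, like the union in the removed HALTypes.cpp code, assumes a little-endian host:

```cpp
#include <cstdint>

// Byte layout mirrored from CollectiveAttr::getEncodedValue():
// bits 0-7 kind, 8-15 reduction, 16-23 element type, 24-31 reserved.
static uint32_t encodeCollectiveOp(uint8_t kind, uint8_t reduction,
                                   uint8_t elementType) {
  return static_cast<uint32_t>(kind) |
         (static_cast<uint32_t>(reduction) << 8) |
         (static_cast<uint32_t>(elementType) << 16);
}

// Example: all_reduce (1) with sum (1) over f32 (9) encodes as 0x00090101.
static const uint32_t kAllReduceSumF32 = encodeCollectiveOp(1, 1, 9);
```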
- }]; - let parameters = (ins - AttrParameter<"int64_t", "">:$pushConstants, - ArrayRefParameter<"DescriptorSetLayoutAttr", "">:$setLayouts - ); - let assemblyFormat = [{ - `<` - `push_constants` `=` $pushConstants `,` - `sets` `=` `[` $setLayouts `]` - `>` - }]; -} - -def HAL_InterfaceBindingAttr : - AttrDef { - let mnemonic = "interface.binding"; - let summary = [{interface binding specification}]; - let description = [{ - Specifies the descriptor set and binding ordinal of a particular layout - binding. - - Example: - ```mlir - #hal.interface.binding<0, 1> - ``` - }]; - let parameters = (ins - AttrParameter<"int64_t", "">:$set, - AttrParameter<"int64_t", "">:$binding - ); - let assemblyFormat = [{ - `<` $set `,` $binding `>` - }]; -} - -def HAL_InterfaceBindingArrayAttr : - TypedArrayAttrBase; - -//===----------------------------------------------------------------------===// -// Device and executable target specification -//===----------------------------------------------------------------------===// - -def HAL_DeviceTargetAttr : - AttrDef { - let mnemonic = "device.target"; - let summary = [{generic device target specification}]; - let description = [{ - Specifies the properties of a target runtime device. - Target devices are specified with a canonical identifier matching those used - by the runtime (such as `cpu`, `vulkan`, etc). Target devices may support - several target executable formats specified with `#hal.executable.target`. - An optional configuration dictionary allows for overriding backend defaults. - - Example: - ```mlir - #hal.device.target<"llvm-cpu", { - executable_targets = [ - #hal.executable.target<"llvm-cpu", "embedded-elf-arm_32">, - #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64">, - ] - }> - ``` - }]; - let parameters = (ins - AttrParameter<"StringAttr", "">:$deviceID, - AttrParameter<"DictionaryAttr", "">:$configuration - ); - let builders = [ - AttrBuilder<(ins "StringRef":$deviceID)>, - ]; - - let extraClassDeclaration = [{ - // Returns a symbol-compatible name that pseudo-uniquely identifies this - // target. Callers must perform deduplication when required. - std::string getSymbolNameFragment(); - - // Returns true if there's an attribute with the given name in the - // configuration dictionary. - bool hasConfigurationAttr(StringRef name); - - // Returns zero or more executable targets that this device supports. - SmallVector getExecutableTargets(); - - // Returns a list of target devices that may be active for the given - // operation. This will recursively walk parent operations until one with - // the `hal.device.targets` attribute is found. - static SmallVector lookup(Operation *op); - - // Returns true if there is any UnitAttr with |name| in any device - // configuration for the given |op|. - static bool lookupConfigAttrAny(Operation *op, StringRef name); - - // Returns true if all device configurations found for the given |op| have - // a UnitAttr with |name|. - static bool lookupConfigAttrAll(Operation *op, StringRef name); - - // Returns the AND of boolean attributes of |name| in all device - // configurations found for the given |op|. - // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional - lookupConfigAttrAnd(Operation *op, StringRef name); - - // Returns the OR of boolean attributes of |name| in all device - // configurations found for the given |op|. 
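The `lookupConfigAttr*` helpers declared on `#hal.device.target` above fold a boolean key across every device target attached to an ancestor op (AND/OR joins) and return `std::nullopt` whenever any device omits the key, i.e. the answer is not statically known. A sketch of how a pass might consult them; the `"my_feature"` key and the include path are assumptions for illustration:

```cpp
#include <optional>

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// True only when every targeted device sets the boolean "my_feature" config
// key to true; std::nullopt (treated as false here) means at least one device
// leaves the key unspecified and the property must be resolved at runtime.
static bool allDevicesHaveFeature(Operation *op) {
  std::optional<bool> result =
      IREE::HAL::DeviceTargetAttr::lookupConfigAttrAnd(op, "my_feature");
  return result.value_or(false);
}

// Unions the executable targets required by all devices, e.g. to decide which
// hal.executable.variant ops must be emitted.
static SmallVector<IREE::HAL::ExecutableTargetAttr>
gatherRequiredTargets(Operation *op) {
  return IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(op);
}
```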
- // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional - lookupConfigAttrOr(Operation *op, StringRef name); - - // Returns the range of integer attributes of |name| in all device - // configurations found for the given |op|. - // Returns nullopt if any config does not have the key defined indicating - // that it's not statically known/runtime dynamic. - static std::optional> - lookupConfigAttrRange(Operation *op, StringRef name); - - // Returns a list of all target executable configurations that may be - // required for the given operation. - static SmallVector - lookupExecutableTargets(Operation *op); - }]; - let hasCustomAssemblyFormat = 1; -} - -def HAL_ExecutableTargetAttr : - AttrDef { - let mnemonic = "executable.target"; - let summary = [{generic executable target specification}]; - let description = [{ - Specifies how to compile an executable for a specific target backend. - A backend is used to translate and serialize the executable into the final - form passed to the runtime. The format of the executable is a - target-specific value indicating the required runtime support to load the - deployed artifact. An optionally provided configuration dictionary overrides - backend-specific defaults. - - Example: - ```mlir - // Produce a system-native ELF for x86-64 systems using the LLVM backend: - #hal.executable.target<"llvm-cpu", "system-elf-x86_64", { - triple = "x86_64-unknown-linux-elf", - cpu = "host", - cpu_features = "host", - abi = "lp32", - ... - }> - ``` - - The same compilation backend may be used to translate executables for - several different runtime devices. Likewise the same runtime device may use - one of many different executable targets. Assume an N:M mapping between the - two in all cases. - }]; - - let parameters = (ins - AttrParameter<"StringAttr", "">:$backend, - AttrParameter<"StringAttr", "">:$format, - AttrParameter<"DictionaryAttr", "">:$configuration - ); - - let builders = [ - AttrBuilder<(ins "StringRef":$backend, "StringRef":$format)>, - ]; - - let extraClassDeclaration = [{ - // Returns a symbol-compatible name that pseudo-uniquely identifies this - // target. Callers must perform deduplication when required. - std::string getSymbolNameFragment() const; - - // Returns true if there's an attribute with the given name in the - // configuration dictionary. - bool hasConfigurationAttr(StringRef name); - - // Returns true if this attribute is a generic version of |specificAttr|. - // A more generic version will match with many specific versions. - bool isGenericOf(IREE::HAL::ExecutableTargetAttr specificAttr); - - // Returns the executable target configuration for the given operation. - // This will recursively walk parent operations until one with the - // `hal.executable.target` attribute is found or a `hal.executable.variant` - // specifies a value. Returns nullptr if no target specification can be found. - static IREE::HAL::ExecutableTargetAttr lookup(Operation *op); - }]; - - let hasCustomAssemblyFormat = 1; -} - -//===----------------------------------------------------------------------===// -// #hal.executable.object<*> -//===----------------------------------------------------------------------===// - -def HAL_ExecutableObjectAttr : AttrDef { - let mnemonic = "executable.object"; - let summary = [{object file reference}]; - let description = [{ - Defines an object file that can be linked into executables. 
- Today this is only supported for external file references with paths the - compiler can successfully resolve from its current working directory. - Inlined data can optionally be provided to avoid the need for file system - access and ensure the data source is attached to the IR as it makes its way - through multiple compiler stages or reproducers. - - Future revisions may change this to an interface that allows both internal - and external resources to define the object contents. Linking needs to be - updated to support various object compositions and certain backends may - require additional infrastructure support. - - In the long term the goal is to allow combinations of declared objects and - generated code in order to give control of linking behavior to frontends. - Instead of needing global command line flags to link in additional blobs - the frontend can emit executables with the dependencies already defined per - variant without needing to reach into the IREE compiler code. - - Example: - ```mlir - #hal.executable.object<{path = "some/file.obj"}> - #hal.executable.object<{ - path = "some/embedded/file.obj", - data = dense<[...]> : vector<2048xi8> - }> - ``` - }]; - - let parameters = (ins - AttrParameter<"StringAttr", "">:$path, - OptionalParameter<"DenseIntElementsAttr", "">:$data - ); - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - // Returns a list of all objects with a path matching one of the provided - // file extensions. - static void filterObjects( - ArrayAttr objectAttrs, ArrayRef extensions, - SmallVectorImpl &filteredAttrs); - - // Returns the absolute path of the referenced object file if it exists. - FailureOr getAbsolutePath(); - - // Returns the contents of the object file or None if loading failed. - // TODO(benvanik): better return type to support mapping/etc? eh - std::optional loadData(); - }]; -} - -def HAL_ExecutableObjectArrayAttr : - TypedArrayAttrBase; - -def HAL_ExecutableObjectsAttr : AttrDef { - let mnemonic = "executable.objects"; - let summary = [{target-specific object file references}]; - let description = [{ - A dictionary mapping executable target specifications to a list of objects. - This is used to allow layers of the stack that support multi-targeting to - specify information used during lowering into each particular target. - - The key attributes are matched against each target variant based on the - backend and format as well as any configuration data provided. When - comparing the configuration only fields present in both the key and - target variant will be checked and must match. This allows specification of - generic sets ("all x86_64 targets get these objects") as well as specific - ones ("only x86_64 targets with vector_size = 64 get these objects"). - - Example: - ```mlir - #hal.executable.objects<{ - #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64"> = [ - #hal.executable.object<{path = "some/file_arm_64.obj"}> - ], - #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> = [ - #hal.executable.object<{path = "some/file_x86_64.obj"}> - ] - }> - ``` - }]; - - let parameters = (ins - AttrParameter<"ArrayAttr", "">:$targets, - AttrParameter<"ArrayAttr", "">:$targetObjects - ); - - let genVerifyDecl = 1; - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - // Returns the objects specified for the given generic target. 
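The `filterObjects`/`loadData` helpers declared on `#hal.executable.object` above are what target backends use to pull linkable files out of a `#hal.executable.objects` dictionary: filter by file extension, then resolve each entry from inlined `data` or from disk (honoring `--iree-hal-executable-object-search-path`). A sketch that assumes the signatures keep their current shape after the move; the `.o` extension is only an example:

```cpp
#include <optional>
#include <string>

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"

using namespace mlir;

// Collects the contents of all ".o" objects attached to a variant, skipping
// entries whose path cannot be resolved and that carry no inlined data.
static SmallVector<std::string> loadLinkableObjects(ArrayAttr objectAttrs) {
  SmallVector<IREE::HAL::ExecutableObjectAttr> objects;
  IREE::HAL::ExecutableObjectAttr::filterObjects(objectAttrs, {".o"}, objects);
  SmallVector<std::string> contents;
  for (auto objectAttr : objects) {
    if (std::optional<std::string> data = objectAttr.loadData())
      contents.push_back(std::move(*data));
  }
  return contents;
}
```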
- std::optional getApplicableObjects( - IREE::HAL::ExecutableTargetAttr specificTargetAttr); - }]; -} - -//===----------------------------------------------------------------------===// -// #hal.affinity.queue<*> -//===----------------------------------------------------------------------===// - -def HAL_AffinityQueueAttr : AttrDef, -]> { - let mnemonic = "affinity.queue"; - let summary = [{specifies a set of allowed queues for an operation}]; - let description = [{ - WIP; see [#10765](https://github.com/openxla/iree/issues/10765). - This may change in the future to either be a nested attribute on a larger - affinity struct or be defined by an implementation of the affinity attr - interface. For now this allows higher levels of the stack to specify - queues such that the stream dialect can understand them and they can be - lowered into the HAL dialect. - - Specifies that an annotated operation or scope is only allowed to execute on - the set of queues (0-64) provided. Operations will not run on other queues. - - Example: - ```mlir - // any queue - #hal.affinity.queue<*> - // queues 4 and 5 - #hal.affinity.queue<[4, 5]> - ``` - }]; - - let parameters = (ins - AttrParameter<"int64_t", "">:$mask - ); - - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // Base HAL op classes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index d37007683b31..f61a913580a6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -949,15 +949,6 @@ void DescriptorSetLayoutCreateOp::getAsmResultNames( setNameFn(getResult(), "descriptor_set_layout"); } -//===----------------------------------------------------------------------===// -// hal.descriptor_set_layout.lookup -//===----------------------------------------------------------------------===// - -void DescriptorSetLayoutLookupOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(getResult(), "descriptor_set_layout"); -} - //===----------------------------------------------------------------------===// // hal.device.allocator //===----------------------------------------------------------------------===// @@ -982,6 +973,17 @@ LogicalResult DeviceQueryOp::verify() { return success(); } +// static +Value DeviceQueryOp::createI1(Location loc, Value device, StringRef category, + StringRef key, OpBuilder &builder) { + auto i1Type = builder.getI1Type(); + return builder + .create( + loc, i1Type, i1Type, device, builder.getStringAttr(category), + builder.getStringAttr(key), builder.getIntegerAttr(i1Type, 0)) + .getValue(); +} + //===----------------------------------------------------------------------===// // hal.device.queue.* //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index fcce21b07f8a..0dd54ea23cc7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -7,6 +7,7 @@ #ifndef IREE_DIALECT_HAL_OPS #define IREE_DIALECT_HAL_OPS +include "iree/compiler/Dialect/HAL/IR/HALAttrs.td" include "iree/compiler/Dialect/HAL/IR/HALBase.td" include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td" include "iree/compiler/Dialect/Util/IR/UtilAttrs.td" @@ -1538,33 
+1539,6 @@ def HAL_DescriptorSetLayoutCreateOp : }]; } -def HAL_DescriptorSetLayoutLookupOp : HAL_PureOp<"descriptor_set_layout.lookup", [ - DeclareOpInterfaceMethods, - ]> { - let summary = [{descriptor set layout cache lookup pseudo-op}]; - let description = [{ - Used during conversion to provide a placeholder for a globally cached and - possibly lazy-initialized descriptor set layout. - }]; - - let arguments = (ins - HAL_Device:$device, - HAL_DescriptorSetLayoutFlagsAttr:$flags, - HAL_DescriptorSetLayoutBindingArrayAttr:$bindings - ); - let results = (outs - HAL_DescriptorSetLayout:$result - ); - - let assemblyFormat = [{ - `device` `(` $device `:` type($device) `)` - `flags` `(` $flags `)` - `bindings` `(` $bindings `)` - `:` type($result) - attr-dict-with-keyword - }]; -} - } // OpGroupDescriptorSetLayoutOps //===----------------------------------------------------------------------===// @@ -1655,7 +1629,10 @@ def HAL_DeviceQueryOp : Well-known keys: - * hal.executable.format :: {some format} + * hal.device.id :: {some id pattern} + Returns 1 if the device identifier matches the given pattern string. + + * hal.executable.format :: {some format pattern} Returns 1 if the given format is supported by the device loader. * hal.device :: concurrency @@ -1688,6 +1665,14 @@ def HAL_DeviceQueryOp : attr-dict-with-keyword }]; + let extraClassDeclaration = [{ + // Returns a true i1 if the given query returns a non-zero value. + // Returns false if the query fails or returns a zero value. + static Value createI1(Location loc, Value device, + StringRef category, StringRef key, + OpBuilder &builder); + }]; + let hasVerifier = 1; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index ab74bc9f798e..292c05a6a707 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -10,86 +10,13 @@ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Utils/StringUtils.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" -// clang-format off: must be included after all LLVM/MLIR headers. 
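The new `DeviceQueryOp::createI1` helper added above wraps the common pattern of emitting an `i1` `hal.device.query` with a `false` default, which pairs naturally with the newly documented `hal.device.id` key. A small sketch of how conversion code might use it; the `"vulkan*"` pattern is only an example value:

```cpp
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Emits `hal.device.query<%device> key("hal.device.id" :: "vulkan*")` and
// returns the i1 value, which is false when the query fails or the device
// identifier does not match the pattern.
static Value emitIsVulkanDevice(Location loc, Value device,
                                OpBuilder &builder) {
  return IREE::HAL::DeviceQueryOp::createI1(loc, device, "hal.device.id",
                                            "vulkan*", builder);
}
```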
-#define GET_ATTRDEF_CLASSES -#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep -#include "iree/compiler/Dialect/HAL/IR/HALEnums.cpp.inc" // IWYU pragma: keep -// clang-format on - namespace mlir::iree_compiler::IREE::HAL { -//===----------------------------------------------------------------------===// -// Enum utilities -//===----------------------------------------------------------------------===// - -template -static LogicalResult parseEnumAttr(AsmParser &parser, StringRef attrName, - AttrType &attr) { - Attribute genericAttr; - auto loc = parser.getCurrentLocation(); - if (failed(parser.parseAttribute(genericAttr, - parser.getBuilder().getNoneType()))) { - return parser.emitError(loc) - << "failed to parse '" << attrName << "' enum string value"; - } - auto stringAttr = llvm::dyn_cast(genericAttr); - if (!stringAttr) { - return parser.emitError(loc) - << "expected " << attrName << " attribute specified as string"; - } - auto symbolized = - symbolizeEnum(stringAttr.getValue()); - if (!symbolized.hasValue()) { - return parser.emitError(loc) - << "failed to parse '" << attrName << "' enum value"; - } - attr = AttrType::get(parser.getBuilder().getContext(), symbolized.getValue()); - return success(); -} - -template -static LogicalResult parseOptionalEnumAttr(AsmParser &parser, - StringRef attrName, AttrType &attr) { - if (succeeded(parser.parseOptionalQuestion())) { - // Special case `?` to indicate any/none/undefined/etc. - attr = AttrType::get(parser.getBuilder().getContext(), 0); - return success(); - } - return parseEnumAttr(parser, attrName, attr); -} - -//===----------------------------------------------------------------------===// -// Element types -//===----------------------------------------------------------------------===// - -// See the iree/hal/command_buffer.h iree_hal_collective_op_t for details. -uint32_t CollectiveAttr::getEncodedValue() const { - union { - uint32_t packed; // packed value - struct { - uint8_t kind; - uint8_t reduction; - uint8_t elementType; - uint8_t reserved; - }; - } value = {0}; - value.kind = static_cast(getKind()); - value.reduction = static_cast( - getReduction().value_or(CollectiveReductionOp::None)); - value.elementType = static_cast(getElementType()); - return value.packed; -} - //===----------------------------------------------------------------------===// // Alignment //===----------------------------------------------------------------------===// @@ -159,676 +86,13 @@ Value DeviceType::resolveAny(Location loc, OpBuilder &builder) { loc, builder.getType(), deviceIndex); } -//===----------------------------------------------------------------------===// -// #hal.device.target -//===----------------------------------------------------------------------===// - -// static -DeviceTargetAttr DeviceTargetAttr::get(MLIRContext *context, - StringRef deviceID) { - // TODO(benvanik): query default configuration from the target backend. 
- return get(context, StringAttr::get(context, deviceID), - DictionaryAttr::get(context)); -} - -// static -Attribute DeviceTargetAttr::parse(AsmParser &p, Type type) { - StringAttr deviceIDAttr; - DictionaryAttr configAttr; - // `<"device-id"` - if (failed(p.parseLess()) || failed(p.parseAttribute(deviceIDAttr))) { - return {}; - } - // `, {config}` - if (succeeded(p.parseOptionalComma()) && - failed(p.parseAttribute(configAttr))) { - return {}; - } - // `>` - if (failed(p.parseGreater())) { - return {}; - } - return get(p.getContext(), deviceIDAttr, configAttr); -} - -void DeviceTargetAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - p.printAttribute(getDeviceID()); - auto configAttr = getConfiguration(); - if (configAttr && !configAttr.empty()) { - os << ", "; - p.printAttribute(configAttr); - } - os << ">"; -} - -std::string DeviceTargetAttr::getSymbolNameFragment() { - return sanitizeSymbolName(getDeviceID().getValue().lower()); -} - -bool DeviceTargetAttr::hasConfigurationAttr(StringRef name) { - auto configAttr = getConfiguration(); - return configAttr && configAttr.get(name); -} - -SmallVector DeviceTargetAttr::getExecutableTargets() { - SmallVector resultAttrs; - auto configAttr = getConfiguration(); - if (configAttr) { - auto targetsAttr = configAttr.getAs("executable_targets"); - if (targetsAttr) { - for (auto attr : targetsAttr.getValue()) { - resultAttrs.push_back(llvm::dyn_cast(attr)); - } - } - } - return resultAttrs; -} - -// static -SmallVector -DeviceTargetAttr::lookup(Operation *op) { - auto attrId = mlir::StringAttr::get(op->getContext(), "hal.device.targets"); - while (op) { - auto targetsAttr = op->getAttrOfType(attrId); - if (targetsAttr) { - SmallVector result; - for (auto targetAttr : targetsAttr) { - result.push_back(llvm::cast(targetAttr)); - } - return result; - } - op = op->getParentOp(); - } - return {}; // No devices found; let caller decide what to do. -} - -// Returns a set of all configuration attributes from all device targets with -// a configuration set. Targets with no configuration set are ignored. -static SmallVector lookupOptionalConfigAttrs(Operation *op) { - auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); - if (targetAttrs.empty()) - return {}; - SmallVector configAttrs; - for (auto targetAttr : targetAttrs) { - auto configAttr = targetAttr.getConfiguration(); - if (configAttr) - configAttrs.push_back(configAttr); - } - return configAttrs; -} - -// Returns a set of all configuration attributes from all device targets. -// Returns nullopt if any target is missing a configuration attribute. 
-static std::optional> -lookupRequiredConfigAttrs(Operation *op) { - auto targetAttrs = IREE::HAL::DeviceTargetAttr::lookup(op); - if (targetAttrs.empty()) - return std::nullopt; - SmallVector configAttrs; - for (auto targetAttr : targetAttrs) { - auto configAttr = targetAttr.getConfiguration(); - if (!configAttr) - return std::nullopt; - configAttrs.push_back(configAttr); - } - return configAttrs; -} - -template -static std::optional joinConfigAttrs( - ArrayRef configAttrs, StringRef name, - std::function - join) { - if (configAttrs.empty()) - return std::nullopt; - auto firstValue = configAttrs.front().getAs(name); - if (!firstValue) - return std::nullopt; - auto result = firstValue.getValue(); - for (auto configAttr : configAttrs.drop_front(1)) { - auto value = configAttr.getAs(name); - if (!value) - return std::nullopt; - result = join(result, value.getValue()); - } - return result; -} - -template -static std::optional> -joinConfigStaticRanges(ArrayRef configAttrs, StringRef name, - std::function( - StaticRange, - StaticRange)> - join) { - if (configAttrs.empty()) - return std::nullopt; - auto firstValue = configAttrs.front().getAs(name); - if (!firstValue) - return std::nullopt; - StaticRange result{firstValue.getValue()}; - for (auto configAttr : configAttrs.drop_front(1)) { - auto value = configAttr.getAs(name); - if (!value) - return std::nullopt; - result = - join(result, StaticRange{value.getValue()}); - } - return result; -} - -// static -bool DeviceTargetAttr::lookupConfigAttrAny(Operation *op, StringRef name) { - auto configAttrs = lookupOptionalConfigAttrs(op); - if (configAttrs.empty()) - return false; - for (auto configAttr : configAttrs) { - if (configAttr.get(name)) - return true; - } - return false; -} - -// static -bool DeviceTargetAttr::lookupConfigAttrAll(Operation *op, StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return false; - for (auto configAttr : *configAttrs) { - if (!configAttr.get(name)) - return false; - } - return true; -} - -// static -std::optional DeviceTargetAttr::lookupConfigAttrAnd(Operation *op, - StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigAttrs( - configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs && rhs; }); -} - -// static -std::optional DeviceTargetAttr::lookupConfigAttrOr(Operation *op, - StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigAttrs( - configAttrs.value(), name, [](bool lhs, bool rhs) { return lhs || rhs; }); -} - -// static -std::optional> -DeviceTargetAttr::lookupConfigAttrRange(Operation *op, StringRef name) { - auto configAttrs = lookupRequiredConfigAttrs(op); - if (!configAttrs) - return std::nullopt; - return joinConfigStaticRanges( - configAttrs.value(), name, - [](StaticRange lhs, StaticRange rhs) { - return StaticRange{ - llvm::APIntOps::smin(lhs.min, rhs.min), - llvm::APIntOps::smax(lhs.max, rhs.max), - }; - }); -} - -// static -SmallVector -DeviceTargetAttr::lookupExecutableTargets(Operation *op) { - SmallVector resultAttrs; - for (auto deviceTargetAttr : lookup(op)) { - for (auto executableTargetAttr : deviceTargetAttr.getExecutableTargets()) { - if (!llvm::is_contained(resultAttrs, executableTargetAttr)) { - resultAttrs.push_back(executableTargetAttr); - } - } - } - return resultAttrs; -} - -//===----------------------------------------------------------------------===// -// 
#hal.executable.target -//===----------------------------------------------------------------------===// - -// static -ExecutableTargetAttr ExecutableTargetAttr::get(MLIRContext *context, - StringRef backend, - StringRef format) { - return get(context, StringAttr::get(context, backend), - StringAttr::get(context, format), DictionaryAttr::get(context)); -} - -// static -Attribute ExecutableTargetAttr::parse(AsmParser &p, Type type) { - StringAttr backendAttr; - StringAttr formatAttr; - DictionaryAttr configurationAttr; - // `<"backend", "format"` - if (failed(p.parseLess()) || failed(p.parseAttribute(backendAttr)) || - failed(p.parseComma()) || failed(p.parseAttribute(formatAttr))) { - return {}; - } - // `, {config}` - if (succeeded(p.parseOptionalComma()) && - failed(p.parseAttribute(configurationAttr))) { - return {}; - } - // `>` - if (failed(p.parseGreater())) { - return {}; - } - return get(p.getContext(), backendAttr, formatAttr, configurationAttr); -} - -void ExecutableTargetAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - p.printAttribute(getBackend()); - os << ", "; - p.printAttribute(getFormat()); - auto config = getConfiguration(); - if (config && !config.empty()) { - os << ", "; - p.printAttribute(config); - } - os << ">"; -} - -std::string ExecutableTargetAttr::getSymbolNameFragment() const { - return sanitizeSymbolName(getFormat().getValue().lower()); -} - -bool ExecutableTargetAttr::hasConfigurationAttr(StringRef name) { - auto configAttr = getConfiguration(); - return configAttr && configAttr.get(name); -} - -// For now this is very simple: if there are any specified fields that are -// present in this attribute they must match. We could allow target backends -// to customize this via attribute interfaces in the future if we needed. -bool ExecutableTargetAttr::isGenericOf( - IREE::HAL::ExecutableTargetAttr specificAttr) { - if (getBackend() != specificAttr.getBackend() || - getFormat() != specificAttr.getFormat()) { - // Totally different backends and binary formats. - // There may be cases where we want to share things - such as when targeting - // both DLLs and dylibs or something - but today almost all of these are - // unique situations. - return false; - } - - // If the config is empty on either we can quickly match. - // This is the most common case for users manually specifying targets. - auto genericConfigAttr = getConfiguration(); - auto specificConfigAttr = specificAttr.getConfiguration(); - if (!genericConfigAttr || !specificConfigAttr) - return true; - - // Ensure all fields in specificConfigAttr either don't exist or match. - for (auto expectedAttr : specificConfigAttr.getValue()) { - auto actualValue = genericConfigAttr.getNamed(expectedAttr.getName()); - if (!actualValue) { - continue; // ignore, not present in generic - } - if (actualValue->getValue() != expectedAttr.getValue()) { - return false; // mismatch, both have values but they differ - } - } - - // Ensure all fields in genericConfigAttr exist in the specific one. - // If missing then the generic is _more_ specific and can't match. - for (auto actualAttr : genericConfigAttr.getValue()) { - if (!specificConfigAttr.getNamed(actualAttr.getName())) { - return false; // mismatch, present in generic but not specific - } - } - - // All fields match or are omitted in the generic version. 
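`isGenericOf` above implements a one-way dictionary match: any config field the generic `#hal.executable.target` carries must be present with an equal value in the specific target, while extra fields on the specific side are ignored (and a missing config dictionary on either side matches trivially). A standalone sketch of the same rule over plain maps, illustrative only:

```cpp
#include <map>
#include <string>

// Returns true if `generic` can stand in for `specific`: the generic config
// may omit fields, but every field it does carry must agree with `specific`.
static bool isGenericConfigOf(
    const std::map<std::string, std::string> &generic,
    const std::map<std::string, std::string> &specific) {
  for (const auto &[key, value] : generic) {
    auto it = specific.find(key);
    if (it == specific.end() || it->second != value)
      return false;
  }
  return true;
}
```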
- return true; -} - -// static -ExecutableTargetAttr ExecutableTargetAttr::lookup(Operation *op) { - auto *context = op->getContext(); - auto attrId = StringAttr::get(context, "hal.executable.target"); - while (op) { - // Take directly from the enclosing variant. - if (auto variantOp = llvm::dyn_cast(op)) { - return variantOp.getTarget(); - } - // Use an override if specified. - auto attr = op->getAttrOfType(attrId); - if (attr) - return attr; - // Continue walk. - op = op->getParentOp(); - } - // No target found during walk. No default to provide so fail and let the - // caller decide what to do (assert/fallback/etc). - return nullptr; -} - -//===----------------------------------------------------------------------===// -// #hal.executable.object -//===----------------------------------------------------------------------===// - -// static -Attribute ExecutableObjectAttr::parse(AsmParser &p, Type type) { - NamedAttrList dict; - // `<{` dict `}>` - if (failed(p.parseLess()) || failed(p.parseOptionalAttrDict(dict)) || - failed(p.parseGreater())) { - return {}; - } - auto pathAttr = llvm::dyn_cast_if_present(dict.get("path")); - auto dataAttr = - llvm::dyn_cast_if_present(dict.get("data")); - return get(p.getContext(), pathAttr, dataAttr); -} - -void ExecutableObjectAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<{"; - if (auto pathAttr = getPath()) { - os << "path = "; - p.printAttribute(getPath()); - } - if (auto dataAttr = getData()) { - os << ", data = "; - p.printAttribute(getData()); - } - os << "}>"; -} - -// static -void ExecutableObjectAttr::filterObjects( - ArrayAttr objectAttrs, ArrayRef extensions, - SmallVectorImpl &filteredAttrs) { - if (!objectAttrs) - return; - for (auto objectAttr : - objectAttrs.getAsRange()) { - auto path = objectAttr.getPath(); - auto ext = llvm::sys::path::extension(path); - if (llvm::is_contained(extensions, ext)) { - filteredAttrs.push_back(objectAttr); - } - } -} - -// Tries to find |filePath| on disk either at its absolute path or joined with -// any of the specified |searchPaths| in order. -// Returns the absolute file path when found or a failure if there are no hits. -static FailureOr -findFileInPaths(StringRef filePath, ArrayRef searchPaths) { - // First try to see if it's an absolute path - we don't want to perform any - // additional processing on top of that. - if (llvm::sys::path::is_absolute(filePath)) { - if (llvm::sys::fs::exists(filePath)) - return filePath.str(); - return failure(); - } - - // Try a relative lookup from the current working directory. - if (llvm::sys::fs::exists(filePath)) - return filePath.str(); - - // Search each path in turn for a file that exists. - // It doesn't mean we can open it but we'll get a better error out of the - // actual open attempt than what we could produce here. - for (auto searchPath : searchPaths) { - SmallVector tryPath{searchPath.begin(), searchPath.end()}; - llvm::sys::path::append(tryPath, filePath); - if (llvm::sys::fs::exists(Twine(tryPath))) - return Twine(tryPath).str(); - } - - // Not found in either the user-specified absolute path, cwd, or the search - // paths. 
- return failure(); -} - -static llvm::cl::list clExecutableObjectSearchPath( - "iree-hal-executable-object-search-path", - llvm::cl::desc("Additional search paths for resolving " - "#hal.executable.object file references."), - llvm::cl::ZeroOrMore); - -FailureOr ExecutableObjectAttr::getAbsolutePath() { - auto pathAttr = getPath(); - if (!pathAttr) - return failure(); // not a file reference - return findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath); -} - -std::optional ExecutableObjectAttr::loadData() { - if (auto dataAttr = getData()) { - // This is shady but so is using this feature. - // TODO(benvanik): figure out a way to limit the attribute to signless int8. - // We could share the attribute -> byte array code with the VM constant - // serialization if we wanted. - auto rawData = dataAttr.getRawData(); - return std::string(rawData.data(), rawData.size()); - } else if (auto pathAttr = getPath()) { - // Search for file and try to load it if found. - auto filePath = - findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath); - if (failed(filePath)) { - llvm::errs() - << "ERROR: referenced object file not found on any path; use " - "--iree-hal-executable-object-search-path= to add search paths: " - << *this << "\n"; - return std::nullopt; - } - auto file = llvm::MemoryBuffer::getFile(*filePath); - if (!file) - return std::nullopt; - return std::string((*file)->getBuffer()); - } - return std::nullopt; -} - -//===----------------------------------------------------------------------===// -// #hal.executable.objects -//===----------------------------------------------------------------------===// - -// static -LogicalResult ExecutableObjectsAttr::verify( - function_ref emitError, ArrayAttr targetsAttr, - ArrayAttr targetObjectsAttr) { - if (targetsAttr.size() != targetObjectsAttr.size()) { - return emitError() << "targets and objects must be 1:1"; - } - for (auto targetAttr : targetsAttr) { - if (!llvm::isa(targetAttr)) { - return emitError() - << "target keys must be #hal.executable.target attributes"; - } - } - for (auto objectsAttr : targetObjectsAttr) { - auto objectsArrayAttr = llvm::dyn_cast(objectsAttr); - if (!objectsArrayAttr) { - return emitError() << "target objects must be an array of " - "#hal.executable.object attributes"; - } - } - return success(); -} - -// static -Attribute ExecutableObjectsAttr::parse(AsmParser &p, Type type) { - // `<{` target = [objects, ...], ... 
`}>` - SmallVector targetAttrs; - SmallVector objectsAttrs; - if (failed(p.parseLess())) - return {}; - if (succeeded(p.parseLBrace()) && !succeeded(p.parseOptionalRBrace())) { - do { - Attribute targetAttr; - ArrayAttr objectsAttr; - if (failed(p.parseAttribute(targetAttr)) || failed(p.parseEqual()) || - failed(p.parseAttribute(objectsAttr))) { - return {}; - } - targetAttrs.push_back(targetAttr); - objectsAttrs.push_back(objectsAttr); - } while (succeeded(p.parseOptionalComma())); - if (failed(p.parseRBrace())) - return {}; - } - if (failed(p.parseGreater())) - return {}; - return get(p.getContext(), ArrayAttr::get(p.getContext(), targetAttrs), - ArrayAttr::get(p.getContext(), objectsAttrs)); -} - -void ExecutableObjectsAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<{"; - llvm::interleaveComma(llvm::zip_equal(getTargets(), getTargetObjects()), os, - [&](std::tuple keyValue) { - p.printAttribute(std::get<0>(keyValue)); - os << " = "; - p.printAttributeWithoutType(std::get<1>(keyValue)); - }); - os << "}>"; -} - -std::optional ExecutableObjectsAttr::getApplicableObjects( - IREE::HAL::ExecutableTargetAttr specificTargetAttr) { - SmallVector allObjectAttrs; - for (auto [targetAttr, objectsAttr] : - llvm::zip_equal(getTargets(), getTargetObjects())) { - auto genericTargetAttr = - llvm::cast(targetAttr); - if (genericTargetAttr.isGenericOf(specificTargetAttr)) { - auto objectsArrayAttr = llvm::cast(objectsAttr); - allObjectAttrs.append(objectsArrayAttr.begin(), objectsArrayAttr.end()); - } - } - if (allObjectAttrs.empty()) - return std::nullopt; - return ArrayAttr::get(specificTargetAttr.getContext(), allObjectAttrs); -} - -//===----------------------------------------------------------------------===// -// #hal.affinity.queue -//===----------------------------------------------------------------------===// - -// static -Attribute AffinityQueueAttr::parse(AsmParser &p, Type type) { - int64_t mask = 0; - // `<` - if (failed(p.parseLess())) - return {}; - // `*` (any) - if (succeeded(p.parseOptionalStar())) { - mask = -1; - } else { - // `[`queue_bit[, ...] `]` - if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { - int64_t i = 0; - if (failed(p.parseInteger(i))) - return failure(); - mask |= 1ll << i; - return success(); - }))) { - return {}; - } - } - // `>` - if (failed(p.parseGreater())) - return {}; - return get(p.getContext(), mask); -} - -void AffinityQueueAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - int64_t mask = getMask(); - if (mask == -1) { - os << "*"; - } else { - os << "["; - for (int i = 0, j = 0; i < sizeof(mask) * 8; ++i) { - if (mask & (1ll << i)) { - if (j++ > 0) - os << ", "; - os << i; - } - } - os << "]"; - } - os << ">"; -} - -bool AffinityQueueAttr::isExecutableWith( - IREE::Stream::AffinityAttr other) const { - if (!other) - return true; - // Only compatible with other queue affinities today. When we extend the - // attributes to specify device targets we'd want to check here. - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - if (!otherQueueAttr) - return false; - // If this affinity is a subset of the target affinity then it can execute - // with it. - if ((getMask() & otherQueueAttr.getMask()) == getMask()) - return true; - // Otherwise not compatible. 
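`#hal.affinity.queue` stores its queue set as a 64-bit mask: `<*>` parses to -1 (all queues), `<[4, 5]>` sets bits 4 and 5, and the `isExecutableWith` check above is a plain subset test on that mask. A standalone sketch of the mask arithmetic; the queue indices are arbitrary examples:

```cpp
#include <cstdint>

// `#hal.affinity.queue<[4, 5]>` lowers to a mask with bits 4 and 5 set;
// `<*>` uses -1 to allow any queue.
static const int64_t kQueues4And5 = (int64_t(1) << 4) | (int64_t(1) << 5);
static const int64_t kAnyQueue = -1;

// Subset test behind isExecutableWith: this affinity can execute with `other`
// only if every queue it allows is also allowed by `other`.
static bool compatibleWith(int64_t mask, int64_t other) {
  return (mask & other) == mask;
}

// For compatible affinities, joinOR and joinAND reduce to bitwise ops.
static int64_t joinOR(int64_t a, int64_t b) { return a | b; }
static int64_t joinAND(int64_t a, int64_t b) { return a & b; }
```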
- return false; -} - -IREE::Stream::AffinityAttr -AffinityQueueAttr::joinOR(IREE::Stream::AffinityAttr other) const { - if (!other) - return *this; - if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { - return nullptr; - } - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - return AffinityQueueAttr::get(getContext(), - getMask() | otherQueueAttr.getMask()); -} - -IREE::Stream::AffinityAttr -AffinityQueueAttr::joinAND(IREE::Stream::AffinityAttr other) const { - if (!other) - return *this; - if (!IREE::Stream::AffinityAttr::canExecuteTogether(*this, other)) { - return nullptr; - } - auto otherQueueAttr = llvm::dyn_cast_if_present(other); - return AffinityQueueAttr::get(getContext(), - getMask() & otherQueueAttr.getMask()); -} - //===----------------------------------------------------------------------===// // Dialect registration //===----------------------------------------------------------------------===// -#include "iree/compiler/Dialect/HAL/IR/HALAttrInterfaces.cpp.inc" #include "iree/compiler/Dialect/HAL/IR/HALOpInterfaces.cpp.inc" #include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.cpp.inc" -void HALDialect::registerAttributes() { - // Register command line flags: - (void)clExecutableObjectSearchPath; - - addAttributes< -#define GET_ATTRDEF_LIST -#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep - >(); -} - void HALDialect::registerTypes() { addTypes(); } -//===----------------------------------------------------------------------===// -// Attribute printing and parsing -//===----------------------------------------------------------------------===// - -Attribute HALDialect::parseAttribute(DialectAsmParser &parser, - Type type) const { - StringRef mnemonic; - Attribute genAttr; - OptionalParseResult parseResult = - generatedAttributeParser(parser, &mnemonic, type, genAttr); - if (parseResult.has_value()) - return genAttr; - parser.emitError(parser.getNameLoc()) - << "unknown HAL attribute: " << mnemonic; - return {}; -} - -void HALDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { - TypeSwitch(attr).Default([&](Attribute) { - if (failed(generatedAttributePrinter(attr, p))) { - assert(false && "unhandled HAL attribute kind"); - } - }); -} - -//===----------------------------------------------------------------------===// -// Type printing and parsing -//===----------------------------------------------------------------------===// - Type HALDialect::parseType(DialectAsmParser &parser) const { StringRef typeKind; if (parser.parseKeyword(&typeKind)) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index f47fd809107e..4b540967551d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -128,10 +128,6 @@ class VulkanSPIRVTargetBackend : public TargetBackend { Builder b(context); SmallVector configItems; - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. 
- configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr()); - configItems.emplace_back(b.getStringAttr("executable_targets"), getExecutableTargets(context)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index e60747a5a2dc..9761580ac94a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -41,16 +41,12 @@ struct MaterializeResourceCachesPass // likely it's already been run. We could fix the pass to better support // partial materialization but there's no use cases for that today. auto executableOps = llvm::to_vector<8>(moduleOp.getOps()); - SmallVector - descriptorSetLayoutLookupOps; SmallVector pipelineLayoutLookupOps; SmallVector executableLookupOps; for (auto funcOp : moduleOp.getOps()) { for (auto &block : funcOp.getFunctionBody()) { block.walk([&](Operation *op) { - if (auto lookupOp = dyn_cast(op)) { - descriptorSetLayoutLookupOps.push_back(lookupOp); - } else if (auto lookupOp = dyn_cast(op)) { + if (auto lookupOp = dyn_cast(op)) { pipelineLayoutLookupOps.push_back(lookupOp); } else if (auto lookupOp = dyn_cast(op)) { executableLookupOps.push_back(lookupOp); @@ -58,8 +54,7 @@ struct MaterializeResourceCachesPass }); } } - if (descriptorSetLayoutLookupOps.empty() && - pipelineLayoutLookupOps.empty() && executableLookupOps.empty()) { + if (pipelineLayoutLookupOps.empty() && executableLookupOps.empty()) { return; } @@ -85,9 +80,6 @@ struct MaterializeResourceCachesPass // Generate cached resource singletons and replace lookup ops with direct // loads from variables. - for (auto lookupOp : descriptorSetLayoutLookupOps) { - replaceDescriptorSetLayoutLookupOp(lookupOp); - } for (auto lookupOp : pipelineLayoutLookupOps) { replacePipelineLayoutLookupOp(lookupOp); } @@ -312,17 +304,6 @@ struct MaterializeResourceCachesPass [](OpResult result) -> Value { return result; }); } - void - replaceDescriptorSetLayoutLookupOp(DescriptorSetLayoutLookupOp &lookupOp) { - OpBuilder builder(lookupOp); - auto globalOp = defineDescriptorSetLayoutOp( - lookupOp.getLoc(), lookupOp.getBindings(), lookupOp.getFlags()); - auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder) - .getLoadedGlobalValue(); - lookupOp.replaceAllUsesWith(loadedValue); - lookupOp.erase(); - } - void replacePipelineLayoutLookupOp(PipelineLayoutLookupOp &lookupOp) { OpBuilder builder(lookupOp); auto globalOp = diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 590649ef8103..d66c4009fc09 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -4,9 +4,6 @@ #executable_target_embedded_elf_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> #executable_target_embedded_elf_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> -#device_target_cpu = #hal.device.target<"llvm-cpu", { - executable_targets = [#executable_target_embedded_elf_x86_64] -}> #pipeline_layout = #hal.pipeline.layout, @@ -15,149 +12,144 @@ ]> ]> -// CHECK: module -module attributes {hal.device.targets = [#device_target_cpu]} { - - // CHECK: hal.executable private @ex - hal.executable private @ex { - hal.executable.variant public 
@embedded_elf_aarch64 target(#executable_target_embedded_elf_aarch64) { - hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors - %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] - hal.return %0, %c1, %c1 : index, index, index - } - builtin.module { - // Opaque at this point (in some target-specific dialects). - } +// CHECK: hal.executable private @ex +hal.executable private @ex { + hal.executable.variant public @embedded_elf_aarch64 target(#executable_target_embedded_elf_aarch64) { + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index } - hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { - hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { - translation_info = #iree_codegen.translation_info - } { - ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors - %c1 = arith.constant 1 : index - %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] - hal.return %0, %c1, %c1 : index, index, index - } - builtin.module { - // Opaque at this point (in some target-specific dialects). - } + builtin.module { + // Opaque at this point (in some target-specific dialects). } } + hal.executable.variant public @embedded_elf_x86_64 target(#executable_target_embedded_elf_x86_64) { + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + // Opaque at this point (in some target-specific dialects). 
+ } + } +} - // CHECK-LABEL: util.func public @simpleDispatch - // CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view - util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - - // CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer - - // (annoyingly out of order) - // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} - // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator - - // CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer> - // CHECK-SAME: message("tensor") - // CHECK-SAME: allocator(%[[ALLOCATOR]] : !hal.allocator) - // CHECK-SAME: minimum_length(%c16) - // CHECK-SAME: type(DeviceVisible) - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - %arg0_resource = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} - - // CHECK: %[[ARG1_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG1]] : !hal.buffer_view> : !hal.buffer - // CHECK: hal.buffer.assert<%[[ARG1_BUFFER]] : !hal.buffer> - // CHECK-SAME: message("tensor") - // CHECK-SAME: allocator(%[[ALLOCATOR]] : !hal.allocator) - // CHECK-SAME: minimum_length(%c16) - // CHECK-SAME: type(DeviceVisible) - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - %arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} - - // CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> - // CHECK-SAME: type("DeviceVisible|DeviceLocal") - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") - // CHECK-SAME: : !hal.buffer{%c16} - %result_resource = stream.resource.alloc uninitialized : !stream.resource{%c16} - - // CHECK: %[[CMD:.+]] = hal.command_buffer.create +// CHECK-LABEL: util.func public @simpleDispatch +// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view +util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c0 = arith.constant 0 : index + + // CHECK: %[[ARG0_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG0]] : !hal.buffer_view> : !hal.buffer + + // (annoyingly out of order) + // CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} + // CHECK-DAG: %[[ALLOCATOR:.+]] = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + + // CHECK: hal.buffer.assert<%[[ARG0_BUFFER]] : !hal.buffer> + // CHECK-SAME: message("tensor") + // CHECK-SAME: allocator(%[[ALLOCATOR]] : !hal.allocator) + // CHECK-SAME: minimum_length(%c16) + // CHECK-SAME: type(DeviceVisible) + // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") + %arg0_resource = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} + + // CHECK: %[[ARG1_BUFFER:.+]] = hal.buffer_view.buffer<%[[ARG1]] : !hal.buffer_view> : !hal.buffer + // CHECK: hal.buffer.assert<%[[ARG1_BUFFER]] : !hal.buffer> + // CHECK-SAME: message("tensor") + // CHECK-SAME: allocator(%[[ALLOCATOR]] : !hal.allocator) + // CHECK-SAME: minimum_length(%c16) + // CHECK-SAME: type(DeviceVisible) + // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") + %arg1_resource = 
stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} + + // CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") + // CHECK-SAME: : !hal.buffer{%c16} + %result_resource = stream.resource.alloc uninitialized : !stream.resource{%c16} + + // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: mode("OneShot|AllowInlineExecution") + // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer + %timepoint = stream.cmd.execute + with(%arg0_resource as %arg0_capture: !stream.resource{%c16}, + %arg1_resource as %arg1_capture: !stream.resource{%c16}, + %result_resource as %result_capture: !stream.resource{%c16}) { + + // CHECK-DAG: %{{.+}}, %[[FORMAT_AARCH64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") + // CHECK-DAG: %{{.+}}, %[[FORMAT_X86_64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64") + // CHECK-DAG: %[[SWITCH1:.+]] = arith.select %[[FORMAT_X86_64]], %c1, %c-1 + // CHECK-DAG: %[[SWITCH0:.+]] = arith.select %[[FORMAT_AARCH64]], %c0, %[[SWITCH1]] + // CHECK: scf.index_switch %[[SWITCH0]] + // CHECK: case 0 { + // CHECK: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.lookup // CHECK-SAME: device(%[[DEVICE]] : !hal.device) - // CHECK-SAME: mode("OneShot|AllowInlineExecution") - // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer - %timepoint = stream.cmd.execute - with(%arg0_resource as %arg0_capture: !stream.resource{%c16}, - %arg1_resource as %arg1_capture: !stream.resource{%c16}, - %result_resource as %result_capture: !stream.resource{%c16}) { - - // CHECK-DAG: %{{.+}}, %[[FORMAT_AARCH64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") - // CHECK-DAG: %{{.+}}, %[[FORMAT_X86_64:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64") - // CHECK-DAG: %[[SWITCH1:.+]] = arith.select %[[FORMAT_X86_64]], %c1, %c-1 - // CHECK-DAG: %[[SWITCH0:.+]] = arith.select %[[FORMAT_AARCH64]], %c0, %[[SWITCH1]] - // CHECK: scf.index_switch %[[SWITCH0]] - // CHECK: case 0 { - // CHECK: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.lookup - // CHECK-SAME: device(%[[DEVICE]] : !hal.device) - // CHECK-SAME: layout(#pipeline_layout) : !hal.pipeline_layout - // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: layout(%[[PIPELINE_LAYOUT]] : !hal.pipeline_layout)[%c0] - // CHECK-SAME: bindings([ - // CHECK: %c0 = (%[[ARG0_BUFFER]] : !hal.buffer)[%c0, %c16], - // CHECK: %c1 = (%[[ARG1_BUFFER]] : !hal.buffer)[%c0, %c16], - // CHECK: %c2 = (%[[RESULT_BUFFER]] : !hal.buffer)[%c0, %c16] - // CHECK: ]) - // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: target(@ex::@embedded_elf_aarch64::@dispatch) - // CHECK-SAME: workgroups([%c1, %c1, %c1]) - // CHECK: scf.yield - // CHECK: } - // CHECK: case 1 { - // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch) - // CHECK: scf.yield - // CHECK: } - stream.cmd.dispatch {@ex::@embedded_elf_aarch64::@dispatch, @ex::@embedded_elf_x86_64::@dispatch}[%c4, %c1, %c1] { - ro %arg0_capture[%c0 for %c16] : !stream.resource{%c16}, - ro 
%arg1_capture[%c0 for %c16] : !stream.resource{%c16}, - wo %result_capture[%c0 for %c16] : !stream.resource{%c16} - } attributes { - hal.interface.bindings = [ - #hal.interface.binding<0, 0>, - #hal.interface.binding<0, 1>, - #hal.interface.binding<0, 2> - ] - } - - // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: source("Dispatch|Transfer|CommandRetire") - // CHECK-SAME: target("CommandIssue|Dispatch|Transfer") - // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> - } => !stream.timepoint - - // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence - // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create - // CHECK: hal.device.queue.execute<%[[DEVICE]] - // CHECK-SAME: wait(%[[WAIT_FENCE]]) - // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) - // CHECK-SAME: commands([%[[CMD]]]) - - // CHECK: hal.fence.await until([%[[SIGNAL_FENCE]]]) - %result_ready = stream.timepoint.await %timepoint => %result_resource : !stream.resource{%c16} - - // CHECK-DAG: %[[ELEMENT_TYPE:.+]] = hal.element_type - // CHECK-DAG: %[[ENCODING_TYPE:.+]] = hal.encoding_type - // CHECK: %[[RESULT_VIEW:.+]] = hal.buffer_view.create - // CHECK-SAME: buffer(%[[RESULT_BUFFER]] : !hal.buffer) - // CHECK-SAME: shape([%c4]) - // CHECK-SAME: type(%[[ELEMENT_TYPE]]) - // CHECK-SAME: encoding(%[[ENCODING_TYPE]]) - %result_view = stream.tensor.export %result_ready : tensor<4xf32> in !stream.resource{%c16} -> !hal.buffer_view - // CHECK: util.return - util.return %result_view : !hal.buffer_view - } + // CHECK-SAME: layout(#pipeline_layout) : !hal.pipeline_layout + // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: layout(%[[PIPELINE_LAYOUT]] : !hal.pipeline_layout)[%c0] + // CHECK-SAME: bindings([ + // CHECK: %c0 = (%[[ARG0_BUFFER]] : !hal.buffer)[%c0, %c16], + // CHECK: %c1 = (%[[ARG1_BUFFER]] : !hal.buffer)[%c0, %c16], + // CHECK: %c2 = (%[[RESULT_BUFFER]] : !hal.buffer)[%c0, %c16] + // CHECK: ]) + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(@ex::@embedded_elf_aarch64::@dispatch) + // CHECK-SAME: workgroups([%c1, %c1, %c1]) + // CHECK: scf.yield + // CHECK: } + // CHECK: case 1 { + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(@ex::@embedded_elf_x86_64::@dispatch) + // CHECK: scf.yield + // CHECK: } + stream.cmd.dispatch {@ex::@embedded_elf_aarch64::@dispatch, @ex::@embedded_elf_x86_64::@dispatch}[%c4, %c1, %c1] { + ro %arg0_capture[%c0 for %c16] : !stream.resource{%c16}, + ro %arg1_capture[%c0 for %c16] : !stream.resource{%c16}, + wo %result_capture[%c0 for %c16] : !stream.resource{%c16} + } attributes { + hal.interface.bindings = [ + #hal.interface.binding<0, 0>, + #hal.interface.binding<0, 1>, + #hal.interface.binding<0, 2> + ] + } + // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source("Dispatch|Transfer|CommandRetire") + // CHECK-SAME: target("CommandIssue|Dispatch|Transfer") + // CHECK: hal.command_buffer.finalize<%[[CMD]] : !hal.command_buffer> + } => !stream.timepoint + + // CHECK: %[[WAIT_FENCE:.+]] = util.null : !hal.fence + // CHECK: %[[SIGNAL_FENCE:.+]] = hal.fence.create + // CHECK: hal.device.queue.execute<%[[DEVICE]] + // CHECK-SAME: wait(%[[WAIT_FENCE]]) + // CHECK-SAME: signal(%[[SIGNAL_FENCE]]) + // CHECK-SAME: commands([%[[CMD]]]) + + // CHECK: hal.fence.await until([%[[SIGNAL_FENCE]]]) + %result_ready = stream.timepoint.await %timepoint => %result_resource : 
!stream.resource{%c16} + + // CHECK-DAG: %[[ELEMENT_TYPE:.+]] = hal.element_type + // CHECK-DAG: %[[ENCODING_TYPE:.+]] = hal.encoding_type + // CHECK: %[[RESULT_VIEW:.+]] = hal.buffer_view.create + // CHECK-SAME: buffer(%[[RESULT_BUFFER]] : !hal.buffer) + // CHECK-SAME: shape([%c4]) + // CHECK-SAME: type(%[[ELEMENT_TYPE]]) + // CHECK-SAME: encoding(%[[ENCODING_TYPE]]) + %result_view = stream.tensor.export %result_ready : tensor<4xf32> in !stream.resource{%c16} -> !hal.buffer_view + // CHECK: util.return + util.return %result_view : !hal.buffer_view } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index d706534b3c61..287822afe3ae 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir @@ -1,32 +1,5 @@ // RUN: iree-opt --split-input-file --iree-hal-materialize-resource-caches %s | FileCheck %s -// CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-NEXT: util.initializer { -// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[LAYOUT:.+]] = hal.descriptor_set_layout.create -// CHECK-SAME: device(%[[DEVICE]] : !hal.device) -// CHECK-SAME: flags("None") -// CHECK-SAME: bindings([ -// CHECK-SAME: #hal.descriptor_set.binding<0, storage_buffer>, -// CHECK-SAME: #hal.descriptor_set.binding<1, storage_buffer> -// CHECK-SAME: ]) : !hal.descriptor_set_layout -// CHECK-NEXT: util.global.store %[[LAYOUT]], @_descriptor_set_layout_0 : !hal.descriptor_set_layout - -// CHECK-LABEL: @descriptorSetLayoutLookup -util.func public @descriptorSetLayoutLookup(%device : !hal.device) -> !hal.descriptor_set_layout { - // CHECK-NEXT: %[[LAYOUT:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout - %0 = hal.descriptor_set_layout.lookup device(%device : !hal.device) - flags("None") - bindings([ - #hal.descriptor_set.binding<0, storage_buffer>, - #hal.descriptor_set.binding<1, storage_buffer> - ]) : !hal.descriptor_set_layout - // CHECK-NEXT: util.return %[[LAYOUT]] - util.return %0 : !hal.descriptor_set_layout -} - -// ----- - // CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout // CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout @@ -55,53 +28,6 @@ util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout // ----- -// CHECK: util.global private @_descriptor_set_layout_0 -// CHECK: util.global private @_descriptor_set_layout_1 - -// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout -// CHECK-NEXT: util.initializer { -// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-DAG: %[[SET1:.+]] = util.global.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout -// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create -// CHECK-SAME: device(%[[DEVICE]] : !hal.device) -// CHECK-SAME: push_constants(1) -// CHECK-SAME: layouts([%[[SET0]], %[[SET1]]]) : !hal.pipeline_layout -// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout - -// CHECK-LABEL: @sharedLayoutLookup -util.func public @sharedLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout { - // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : 
!hal.pipeline_layout - %0 = hal.pipeline_layout.lookup device(%device : !hal.device) - layout(#hal.pipeline.layout, - #hal.descriptor_set.binding<1, storage_buffer> - ]>, - #hal.descriptor_set.layout<1, bindings = [ - #hal.descriptor_set.binding<0, uniform_buffer>, - #hal.descriptor_set.binding<1, uniform_buffer> - ]> - ]>) : !hal.pipeline_layout - // CHECK-NEXT: util.return %[[LAYOUT]] - util.return %0 : !hal.pipeline_layout -} - -// CHECK: @otherDescriptorSetLayoutLookup -util.func public @otherDescriptorSetLayoutLookup(%device : !hal.device) -> !hal.descriptor_set_layout { - // CHECK: %[[LAYOUT:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout - %0 = hal.descriptor_set_layout.lookup device(%device : !hal.device) - flags(None) - bindings([ - #hal.descriptor_set.binding<0, storage_buffer>, - #hal.descriptor_set.binding<1, storage_buffer> - ]) : !hal.descriptor_set_layout - // CHECK-NEXT: util.return %[[LAYOUT]] - util.return %0 : !hal.descriptor_set_layout -} - -// ----- - #pipeline_layout_0 = #hal.pipeline.layout, @@ -116,8 +42,6 @@ util.func public @otherDescriptorSetLayoutLookup(%device : !hal.device) -> !hal. ]> ]> -module attributes {hal.device.targets = [#hal.device.target<"llvm-cpu">]} { - // TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/? // - If there is no matching hal.executable.variant then the executable will not be cached hal.executable @exe { @@ -224,8 +148,6 @@ util.func public @exeLookup(%device : !hal.device) -> !hal.executable { util.return %0 : !hal.executable } -} - // ----- // Tests that materialization no-ops when resource caches have already been @@ -242,8 +164,6 @@ util.func public @exeLookup(%device : !hal.device) -> !hal.executable { ]> ]> -module attributes {hal.device.targets = [#hal.device.target<"llvm-cpu">]} { - util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout util.initializer { %c0 = arith.constant 0 : index @@ -299,5 +219,3 @@ util.func public @exeLookup(%device : !hal.device) -> !hal.executable { // CHECK-NEXT: util.return %[[EXE]] util.return %0 : !hal.executable } - -} diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp index 2b928657d23b..353693af7d7a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp @@ -53,8 +53,7 @@ bool isLegalConstExprType(Type t) { // support, but for now the consteval JIT has interop limitations. Lift // this restriction when the JIT interops for all types. 
auto bitWidth = t.getIntOrFloatBitWidth(); - return bitWidth == 1 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || - bitWidth == 64; + return llvm::isPowerOf2_64(bitWidth) && bitWidth != 2 && bitWidth <= 64; } if (llvm::isa(t)) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/flow_hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/flow_hoist_into_globals.mlir index 8cf38af3b682..7817c2974c43 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/flow_hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/flow_hoist_into_globals.mlir @@ -81,6 +81,32 @@ module @hoist_sub_byte_tensor_transitive { // ----- +// CHECK-LABEL: @hoist_sub_byte_aligned_scalar_transitive +// CHECK-NOT: util.global +module @hoist_sub_byte_aligned_scalar_transitive { + func.func @main() -> i4 { + %c1_i4 = arith.constant 1 : i4 + %0 = "iree_unregistered.const_expr"(%c1_i4) : (i4) -> i4 + return %0 : i4 + } +} + +// ----- + +// CHECK-LABEL: @hoist_constant_pack_computation +// CHECK: util.global +module @hoist_constant_pack_computation { + func.func @main() -> tensor<4x1x16x2xi4> { + %pad = arith.constant 5 : i4 + %val1 = stablehlo.constant dense<3> : tensor<7x15xi4> + %val2 = tensor.empty() : tensor<4x1x16x2xi4> + %ret = tensor.pack %val1 padding_value(%pad : i4) inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %val2 : tensor<7x15xi4> -> tensor<4x1x16x2xi4> + return %ret : tensor<4x1x16x2xi4> + } +} + +// ----- + // We should not hoist metadata ops alone. // CHECK-LABEL: @do_not_hoist_metadata_leaf // CHECK-NOT: util.global @@ -91,3 +117,4 @@ module @do_not_hoist_metadata_leaf { util.return %1 : tensor<1xi32> } } + diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir index 61b257225be7..9184e99ddcb4 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir @@ -36,7 +36,7 @@ module attributes {hal.device.targets = [#device_target_llvm_cpu]} { #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d2, d1)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> -#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb], legacy_sync}> +#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb]}> module attributes {hal.device.targets = [#device_target_vulkan]} { util.func public @lhs_encoding(%arg0: tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 @@ -71,7 +71,7 @@ module attributes {hal.device.targets = [#device_target_vulkan]} { #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}> #device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}> #executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan", "vulkan-spirv-fb"> -#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb], legacy_sync}> +#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb]}> module attributes {hal.device.targets = [#device_target_vulkan, #device_target_llvm_cpu]} { 
util.func public @lhs_encoding(%arg0: tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td index 1f03b8f9dbed..d1e3981b683b 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/IOParametersOps.td @@ -7,7 +7,7 @@ #ifndef IREE_DIALECT_MODULES_IO_PARAMETERS_OPS #define IREE_DIALECT_MODULES_IO_PARAMETERS_OPS -include "iree/compiler/Dialect/HAL/IR/HALBase.td" +include "iree/compiler/Dialect/HAL/IR/HALAttrs.td" include "iree/compiler/Dialect/Util/IR/UtilAttrs.td" include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td" include "iree/compiler/Modules/IO/Parameters/IR/IOParametersBase.td" diff --git a/experimental/hip/hip_device.c b/experimental/hip/hip_device.c index 488313a01317..c0f203afc274 100644 --- a/experimental/hip/hip_device.c +++ b/experimental/hip/hip_device.c @@ -382,8 +382,15 @@ static iree_status_t iree_hal_hip_device_query_attribute( static iree_status_t iree_hal_hip_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { *out_value = iree_string_view_equal(key, IREE_SV("rocm-hsaco-fb")) ? 1 : 0; return iree_ok_status(); diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index 24da6bd7da66..167415bb5084 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -231,9 +231,15 @@ static void iree_hal_rocm_replace_channel_provider( static iree_status_t iree_hal_rocm_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { - // iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); + iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, iree_make_cstring_view("hal.executable.format"))) { *out_value = diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c index 066261a595ce..8af38c0287d7 100644 --- a/experimental/webgpu/webgpu_device.c +++ b/experimental/webgpu/webgpu_device.c @@ -212,10 +212,15 @@ static iree_status_t iree_hal_webgpu_device_trim( static iree_status_t iree_hal_webgpu_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { - // iree_hal_webgpu_device_t* device = - // iree_hal_webgpu_device_cast(base_device); + iree_hal_webgpu_device_t* device = iree_hal_webgpu_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 
1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, iree_make_cstring_view("hal.executable.format"))) { *out_value = diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 9189b4cfe55d..f7d550c7aac0 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -408,6 +408,12 @@ static iree_status_t iree_hal_cuda_device_query_i64( iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0; return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c index 7a85962ec682..711704c8a678 100644 --- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c +++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c @@ -185,6 +185,12 @@ static iree_status_t iree_hal_sync_device_query_i64( iree_hal_sync_device_t* device = iree_hal_sync_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { *out_value = iree_hal_query_any_executable_loader_support( @@ -192,7 +198,9 @@ static iree_status_t iree_hal_sync_device_query_i64( ? 1 : 0; return iree_ok_status(); - } else if (iree_string_view_equal(category, IREE_SV("hal.device"))) { + } + + if (iree_string_view_equal(category, IREE_SV("hal.device"))) { if (iree_string_view_equal(key, IREE_SV("concurrency"))) { *out_value = 1; return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c index 19a56396321f..7601af55b9ac 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_device.c +++ b/runtime/src/iree/hal/drivers/local_task/task_device.c @@ -226,6 +226,12 @@ static iree_status_t iree_hal_task_device_query_i64( iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { *out_value = iree_hal_query_any_executable_loader_support( @@ -233,7 +239,9 @@ static iree_status_t iree_hal_task_device_query_i64( ? 
1 : 0; return iree_ok_status(); - } else if (iree_string_view_equal(category, IREE_SV("hal.device"))) { + } + + if (iree_string_view_equal(category, IREE_SV("hal.device"))) { if (iree_string_view_equal(key, IREE_SV("concurrency"))) { *out_value = (int64_t)device->queue_count; return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m index f88747ca6d63..05878b08b1f8 100644 --- a/runtime/src/iree/hal/drivers/metal/metal_device.m +++ b/runtime/src/iree/hal/drivers/metal/metal_device.m @@ -216,8 +216,14 @@ static iree_status_t iree_hal_metal_device_trim(iree_hal_device_t* base_device) static iree_status_t iree_hal_metal_device_query_i64(iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { + iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, iree_make_cstring_view("hal.executable.format"))) { *out_value = iree_string_view_equal(key, iree_make_cstring_view("metal-msl-fb")) ? 1 : 0; return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index aeebe69e035f..094118450159 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -1420,6 +1420,12 @@ static iree_status_t iree_hal_vulkan_device_query_i64( iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); *out_value = 0; + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { if (iree_string_view_equal(key, IREE_SV("vulkan-spirv-fb"))) { // Base SPIR-V always supported. diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index d00c39a1c10d..824766b9d64a 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -1155,13 +1155,11 @@ IREE_VM_ABI_EXPORT(iree_hal_module_devices_count, // IREE_VM_ABI_EXPORT(iree_hal_module_devices_get, // iree_hal_module_state_t, // i, r) { - if (args->i0 >= state->device_count) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "device index %d out of bounds (%" PRIhsz - " devices available)", - args->i0, state->device_count); + if (args->i0 < state->device_count) { + rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]); + } else { + rets->r0 = iree_vm_ref_null(); } - rets->r0 = iree_hal_device_retain_ref(state->devices[args->i0]); return iree_ok_status(); } diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir index 22355a8d548d..1438b20fc697 100644 --- a/samples/custom_dispatch/cuda/kernels/example.mlir +++ b/samples/custom_dispatch/cuda/kernels/example.mlir @@ -28,9 +28,7 @@ executable_targets = [ #nvptx_sm_52_target, #nvptx_sm_80_target - ], - // HACK: CUDA target currently uses the legacy synchronous execution model. 
- legacy_sync + ] }> module @example attributes {hal.device.targets = [#cuda_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir index daf42ffc2a68..0aa2b31a1994 100644 --- a/samples/custom_dispatch/vulkan/shaders/example.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example.mlir @@ -25,9 +25,7 @@ // It's possible, for example, to support targeting multiple devices in the same // compiled binary. #vulkan_target = #hal.device.target<"vulkan", { - executable_targets = [#spirv_target], - // HACK: Vulkan target currently uses the legacy synchronous execution model. - legacy_sync + executable_targets = [#spirv_target] }> module @example attributes {hal.device.targets = [#vulkan_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir index 5979d7698191..ea816803f7f5 100644 --- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir @@ -25,9 +25,7 @@ // It's possible, for example, to support targeting multiple devices in the same // compiled binary. #vulkan_target = #hal.device.target<"vulkan", { - executable_targets = [#spirv_target], - // HACK: Vulkan target currently uses the legacy synchronous execution model. - legacy_sync + executable_targets = [#spirv_target] }> module @example attributes {hal.device.targets = [#vulkan_target]} { diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir index 82662ac6b56e..d48c2f97aed3 100644 --- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir +++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir @@ -32,9 +32,7 @@ // kernel that supports multiple targets by specifying an object per-target, but // that requires authoring the kernel for multiple targets. #vulkan_target = #hal.device.target<"vulkan", { - executable_targets = [#spirv_target], - // HACK: Vulkan target currently uses the legacy synchronous execution model. 
- legacy_sync + executable_targets = [#spirv_target] }> #map = affine_map<(d0, d1) -> (d0, d1)> diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir index 1e4ac4e8ed4d..7b4743bbb497 100644 --- a/samples/transform_dialect/example_module.mlir +++ b/samples/transform_dialect/example_module.mlir @@ -5,7 +5,7 @@ // !B_size = tensor<5x16xf32> // !C_size = tensor<16x16xf32> // !O_size = tensor<16xf32> -// +// // module { // func.func @example_module(%A : !A_size, %B : !B_size, %C : !C_size) -> !O_size { // %0 = linalg.add ins(%A, %A : !A_size, !A_size) @@ -16,10 +16,10 @@ // %2 = linalg.reduce // ins(%1 : !C_size) // outs(%empty : !O_size) -// dimensions = [1] +// dimensions = [1] // (%in: f32, %out: f32) { -// %3 = arith.addf %out, %in: f32 -// linalg.yield %3: f32 +// %3 = arith.addf %out, %in: f32 +// linalg.yield %3: f32 // } // return %2 : !O_size // } @@ -27,13 +27,13 @@ #target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits> -module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>], legacy_sync}>]} { +module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>]}>]} { hal.executable private @example_module_dispatch_0 { hal.executable.variant public @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) { hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout( #hal.pipeline.layout, <1, storage_buffer>]>]>) { ^bb0(%arg0: !hal.device): - %x, %y, %z = flow.dispatch.workgroup_count_from_slice + %x, %y, %z = flow.dispatch.workgroup_count_from_slice hal.return %x, %y, %z : index, index, index } builtin.module { @@ -59,7 +59,7 @@ module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executabl hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout( #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) { ^bb0(%arg0: !hal.device): - %x, %y, %z = flow.dispatch.workgroup_count_from_slice + %x, %y, %z = flow.dispatch.workgroup_count_from_slice hal.return %x, %y, %z : index, index, index } builtin.module { @@ -83,7 +83,7 @@ module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executabl hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout( #hal.pipeline.layout, <1, storage_buffer>]>]>) { ^bb0(%arg0: !hal.device): - %x, %y, %z = flow.dispatch.workgroup_count_from_slice + %x, %y, %z = flow.dispatch.workgroup_count_from_slice hal.return %x, %y, %z : index, index, index } builtin.module { diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel b/tests/e2e/stablehlo_ops/BUILD.bazel index 2e9bc1c0725e..9b7326373c53 100644 --- a/tests/e2e/stablehlo_ops/BUILD.bazel +++ b/tests/e2e/stablehlo_ops/BUILD.bazel @@ -414,7 +414,6 @@ iree_check_single_backend_test_suite( compiler_flags = [ # TODO(#13984): memset emulation required for graphs. 
"--iree-stream-emulate-memset", - "--iree-hal-cuda-enable-legacy-sync=false", ], driver = "cuda", input_type = "stablehlo", @@ -433,9 +432,6 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_cuda_stream", srcs = CUDA_SRCS, - compiler_flags = [ - "--iree-hal-cuda-enable-legacy-sync=false", - ], driver = "cuda", input_type = "stablehlo", runner_args = ["--cuda_use_streams=true"], diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt index b322b8ebe01f..65cb2260da5a 100644 --- a/tests/e2e/stablehlo_ops/CMakeLists.txt +++ b/tests/e2e/stablehlo_ops/CMakeLists.txt @@ -373,7 +373,6 @@ iree_check_single_backend_test_suite( "cuda" COMPILER_FLAGS "--iree-stream-emulate-memset" - "--iree-hal-cuda-enable-legacy-sync=false" INPUT_TYPE "stablehlo" RUNNER_ARGS @@ -455,8 +454,6 @@ iree_check_single_backend_test_suite( "cuda" DRIVER "cuda" - COMPILER_FLAGS - "--iree-hal-cuda-enable-legacy-sync=false" INPUT_TYPE "stablehlo" RUNNER_ARGS diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel index 121155f768ad..30546725a248 100644 --- a/tests/e2e/tosa_ops/BUILD.bazel +++ b/tests/e2e/tosa_ops/BUILD.bazel @@ -301,7 +301,6 @@ iree_check_single_backend_test_suite( compiler_flags = [ # TODO(#13984): memset emulation required for graphs. "--iree-stream-emulate-memset", - "--iree-hal-cuda-enable-legacy-sync=false", ], driver = "cuda", input_type = "tosa", @@ -320,9 +319,6 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_cuda_stream", srcs = CUDA_SRCS, - compiler_flags = [ - "--iree-hal-cuda-enable-legacy-sync=false", - ], driver = "cuda", input_type = "tosa", runner_args = ["--cuda_use_streams=true"], diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt index 761555162290..f3578a534834 100644 --- a/tests/e2e/tosa_ops/CMakeLists.txt +++ b/tests/e2e/tosa_ops/CMakeLists.txt @@ -272,7 +272,6 @@ iree_check_single_backend_test_suite( "cuda" COMPILER_FLAGS "--iree-stream-emulate-memset" - "--iree-hal-cuda-enable-legacy-sync=false" INPUT_TYPE "tosa" RUNNER_ARGS @@ -333,8 +332,6 @@ iree_check_single_backend_test_suite( "cuda" DRIVER "cuda" - COMPILER_FLAGS - "--iree-hal-cuda-enable-legacy-sync=false" INPUT_TYPE "tosa" RUNNER_ARGS diff --git a/third_party/llvm-project b/third_party/llvm-project index 886294a2fe59..3c25a9de91e9 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 886294a2fe5928ecf34299e02526e17be19910c6 +Subproject commit 3c25a9de91e9c76ffba82939b19eafe3e60d51f7 diff --git a/third_party/torch-mlir b/third_party/torch-mlir index e7a09440d380..135c81a4165f 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit e7a09440d380827e90b94ef33bd82f32fda8874a +Subproject commit 135c81a4165f9e4c9070d72c485efece887d64f8