Skip to content

Commit

Permalink
[NFC] Move LLVMCPULowerToUKernels pass to Common/CPU (iree-org#15590)
Browse files Browse the repository at this point in the history
This removes the dependency between Codegen/LLVMCPU and Codegen/VMVX. The
next step toward untangling LLVMCPU and VMVX is to add a VMVX
SelectLoweringStrategy. After that, we will be able to remove the LLVMCPU
dependency from `compiler/Dialect/VMVX/Transforms/`.

It also moves `getCastOpOfElementWiseCast` to be a static method in
`CPULowerToUKernels.cpp`.
  • Loading branch information
hanhanW authored Nov 14, 2023
1 parent ab0dee1 commit 5c95ec0
Show file tree
Hide file tree
Showing 20 changed files with 65 additions and 72 deletions.
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_compiler_cc_library(
iree_compiler_cc_library(
name = "CommonCPUPasses",
srcs = [
"CPULowerToUKernels.cpp",
"CPUMaterializeEncodingPass.cpp",
"Passes.cpp",
],
Expand All @@ -56,12 +57,14 @@ iree_compiler_cc_library(
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineTransforms",
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ iree_cc_library(
HDRS
"Passes.h"
SRCS
"CPULowerToUKernels.cpp"
"CPUMaterializeEncodingPass.cpp"
"Passes.cpp"
DEPS
Expand Down Expand Up @@ -75,8 +76,10 @@ iree_cc_library(
MLIRVectorDialect
MLIRVectorToSCF
MLIRVectorTransforms
iree::builtins::ukernel::exported_bits
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <iree/compiler/Codegen/Utils/Utils.h>
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Common/CPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand All @@ -28,11 +28,36 @@
namespace mlir {
namespace iree_compiler {

// Returns the CastOpInterface op in the body of `genericOp` if all of the
// following hold:
//   - `genericOp` is non-null and has exactly one DPS input and one DPS init,
//   - `genericOp` is element-wise (as decided by `isElementwise`),
//   - its body holds exactly two operations (the cast and the terminator),
//   - the terminator's first operand is produced by a CastOpInterface op, and
//   - that cast op reads the body's block argument #0, i.e. the input element.
// Returns std::nullopt otherwise.
static std::optional<CastOpInterface>
getCastOpOfElementWiseCast(linalg::GenericOp genericOp) {
  if (!genericOp || genericOp.getNumDpsInputs() != 1 ||
      genericOp.getNumDpsInits() != 1 ||
      genericOp.getBody()->getOperations().size() != 2 ||
      !isElementwise(genericOp)) {
    return std::nullopt;
  }
  auto *body = genericOp.getBody();
  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
  auto castOp = yieldOp->getOperand(0).getDefiningOp<CastOpInterface>();
  if (!castOp) {
    return std::nullopt;
  }
  // The cast must consume the input element itself. Comparing against block
  // argument #0 rejects a cast of the init's block argument and also a cast of
  // any value that is not an argument of this body (e.g. one captured from the
  // enclosing region), which the previous "is a block argument with a non-zero
  // number" test let through.
  if (castOp->getOperand(0) != body->getArgument(0)) {
    return std::nullopt;
  }
  return castOp;
}

namespace {
class LLVMCPULowerToUKernelsPass
: public LLVMCPULowerToUKernelsBase<LLVMCPULowerToUKernelsPass> {
class CPULowerToUKernelsPass
: public CPULowerToUKernelsBase<CPULowerToUKernelsPass> {
public:
LLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings)
CPULowerToUKernelsPass(bool skipIntermediateRoundings)
: skipIntermediateRoundings(skipIntermediateRoundings) {}

void getDependentDialects(DialectRegistry &registry) const override {
Expand Down Expand Up @@ -545,7 +570,7 @@ struct LowerToUKernelPattern : OpRewritePattern<OpType> {

} // namespace

void LLVMCPULowerToUKernelsPass::runOnOperation() {
void CPULowerToUKernelsPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
// Enabling a lowering of an op to a microkernel is a trade-off between the
Expand Down Expand Up @@ -585,9 +610,8 @@ void LLVMCPULowerToUKernelsPass::runOnOperation() {
}

std::unique_ptr<OperationPass<>>
createLLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings) {
return std::make_unique<LLVMCPULowerToUKernelsPass>(
skipIntermediateRoundings);
createCPULowerToUKernelsPass(bool skipIntermediateRoundings) {
return std::make_unique<CPULowerToUKernelsPass>(skipIntermediateRoundings);
}

} // namespace iree_compiler
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ createCPUMaterializeUpperBoundTileSizePass(
/// Adds CPU bufferization passes to the pipeline.
void addCPUBufferizePasses(OpPassManager &passManager);

/// Pass to lower a sequence of operations to an iree_codegen.ukernel.*
/// operation.
std::unique_ptr<OperationPass<>>
createCPULowerToUKernelsPass(bool skipIntermediateRoundings = true);

void registerCodegenCommonCPUPasses();

} // namespace iree_compiler
Expand Down
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,17 @@ def CPUMaterializeUpperBoundTileSize :
let constructor = "mlir::iree_compiler::createCPUMaterializeUpperBoundTileSizePass()";
}

// Declares the pass exposed on the command line as
// `iree-codegen-cpu-lower-to-ukernels`. The second template argument (the
// operation name the pass is anchored on) is left empty.
def CPULowerToUKernels :
Pass<"iree-codegen-cpu-lower-to-ukernels", ""> {
let summary =
"Separate out parts of the IR that lower to a micro-kernel";
// The constructor is invoked without arguments, so the C++ default value of
// its parameter is what applies when the pass is created from this record.
let constructor =
"mlir::iree_compiler::createCPULowerToUKernelsPass()";
// A single boolean option, `skip-intermediate-roundings`, defaulting to true.
let options = [
Option<"optionSkipIntermediateRoundings", "skip-intermediate-roundings",
"bool", /*default=*/"true",
"Allow skipping intermediate roundings, e.g. in f16 ukernels internally doing f32 arithmetic.">,
];
}

#endif // IREE_CODEGEN_COMMON_CPU_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
# keep sorted
[
"llvmcpu_materialize_encoding.mlir",
"lower_to_ukernel_ops.mlir",
"vmvx_materialize_encoding.mlir",
],
include = ["*.mlir"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"llvmcpu_materialize_encoding.mlir"
"lower_to_ukernel_ops.mlir"
"vmvx_materialize_encoding.mlir"
TOOLS
FileCheck
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-ukernels{skip-intermediate-roundings=true},cse,canonicalize))" %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-ukernels{skip-intermediate-roundings=false},cse,canonicalize))" %s | FileCheck %s --check-prefix=NOSKIPROUND
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-lower-to-ukernels{skip-intermediate-roundings=true},cse,canonicalize))" %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-cpu-lower-to-ukernels{skip-intermediate-roundings=false},cse,canonicalize))" %s | FileCheck %s --check-prefix=NOSKIPROUND

func.func @mmt4d_f32f32f32(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> attributes {
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ iree_compiler_cc_library(
"LLVMCPUEmitVectorizationRemarks.cpp",
"LLVMCPULinkExecutables.cpp",
"LLVMCPULowerExecutableTarget.cpp",
"LLVMCPULowerToUKernels.cpp",
"LLVMCPUMmt4dVectorLowering.cpp",
"LLVMCPUPeel.cpp",
"LLVMCPUSelectLoweringStrategy.cpp",
Expand Down Expand Up @@ -89,7 +88,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
"//compiler/src/iree/compiler/Codegen/TransformStrategies/CPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
Expand All @@ -104,7 +102,6 @@ iree_compiler_cc_library(
"//llvm-external-projects/iree-dialects:IREELinalgExtUtils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialectPasses",
"//runtime/src/iree/builtins/ukernel:exported_bits",
"//runtime/src/iree/schemas:cpu_data",
"//runtime/src/iree/schemas/instruments",
"@llvm-project//llvm:BinaryFormat",
Expand Down
3 changes: 0 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ iree_cc_library(
"LLVMCPUEmitVectorizationRemarks.cpp"
"LLVMCPULinkExecutables.cpp"
"LLVMCPULowerExecutableTarget.cpp"
"LLVMCPULowerToUKernels.cpp"
"LLVMCPUMmt4dVectorLowering.cpp"
"LLVMCPUPeel.cpp"
"LLVMCPUSelectLoweringStrategy.cpp"
Expand Down Expand Up @@ -133,13 +132,11 @@ iree_cc_library(
MLIRVectorToLLVM
MLIRVectorToSCF
MLIRVectorTransforms
iree::builtins::ukernel::exported_bits
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
iree::compiler::Codegen::Common::TransformDialectInterpreterPass
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
iree::compiler::Codegen::Interfaces::UKernelOpInterface
iree::compiler::Codegen::TransformStrategies::CPU
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
nestedModulePM.addNestedPass<func::FuncOp>(
createDecomposeBatchMmt4DOpsPass());
nestedModulePM.addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
createCPULowerToUKernelsPass(clSkipIntermediateRoundings));
} else {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
Expand Down
5 changes: 0 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ createLLVMCPULowerExecutableTargetPass();
/// Can handle more operations if required in the future.
std::unique_ptr<Pass> createExpandF16OpToF32Pass();

/// Pass to lower a sequence of operations to an iree_codegen.ukernel.*
/// operation.
std::unique_ptr<OperationPass<>>
createLLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings = true);

std::unique_ptr<OperationPass<func::FuncOp>>
createLLVMCPUMmt4dVectorLoweringPass();

Expand Down
13 changes: 0 additions & 13 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,6 @@ def LLVMCPULowerExecutableTarget :
"mlir::iree_compiler::createLLVMCPULowerExecutableTargetPass()";
}

// Declares the pass exposed on the command line as
// `iree-llvmcpu-lower-to-ukernels`. The second template argument (the
// operation name the pass is anchored on) is left empty.
def LLVMCPULowerToUKernels :
Pass<"iree-llvmcpu-lower-to-ukernels", ""> {
let summary =
"Separate out parts of the IR that lower to a micro-kernel";
// The constructor is invoked without arguments, so the C++ default value of
// its parameter is what applies when the pass is created from this record.
let constructor =
"mlir::iree_compiler::createLLVMCPULowerToUKernelsPass()";
// A single boolean option, `skip-intermediate-roundings`, defaulting to true.
let options = [
Option<"optionSkipIntermediateRoundings", "skip-intermediate-roundings",
"bool", /*default=*/"true",
"Allow skipping intermediate roundings, e.g. in f16 ukernels internally doing f32 arithmetic.">,
];
}

def LLVMCPUMmt4dVectorLowering
: Pass<"iree-llvmcpu-mmt4d-vector-lowering", "func::FuncOp"> {
let summary = "Apply vector lowering logic to vector ops";
Expand Down
21 changes: 0 additions & 21 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,26 +131,5 @@ void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp,
}
}

// Returns the CastOpInterface op in the body of `genericOp` if all of the
// following hold:
//   - `genericOp` is non-null and has exactly one DPS input and one DPS init,
//   - `genericOp` is element-wise (as decided by `isElementwise`),
//   - its body holds exactly two operations (the cast and the terminator),
//   - the terminator's first operand is produced by a CastOpInterface op, and
//   - that cast op reads the body's block argument #0, i.e. the input element.
// Returns std::nullopt otherwise.
std::optional<CastOpInterface>
getCastOpOfElementWiseCast(linalg::GenericOp genericOp) {
  if (!genericOp || genericOp.getNumDpsInputs() != 1 ||
      genericOp.getNumDpsInits() != 1 ||
      genericOp.getBody()->getOperations().size() != 2 ||
      !isElementwise(genericOp)) {
    return std::nullopt;
  }
  auto *body = genericOp.getBody();
  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
  auto castOp = yieldOp->getOperand(0).getDefiningOp<CastOpInterface>();
  if (!castOp) {
    return std::nullopt;
  }
  // The cast must consume the input element itself. Comparing against block
  // argument #0 rejects a cast of the init's block argument and also a cast of
  // any value that is not an argument of this body (e.g. one captured from the
  // enclosing region), which the previous "is a block argument with a non-zero
  // number" test let through.
  if (castOp->getOperand(0) != body->getArgument(0)) {
    return std::nullopt;
  }
  return castOp;
}

} // namespace iree_compiler
} // namespace mlir
6 changes: 0 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp,
SmallVector<int64_t> tileSizes,
SmallVector<bool> tileScalableFlags);

// If the `genericOp` is element-wise with identity maps, and has only a
// CastOpInterface op, return the CastOpInterface op of the body. Otherwise,
// return std::nullopt.
std::optional<CastOpInterface>
getCastOpOfElementWiseCast(linalg::GenericOp genericOp);

} // namespace iree_compiler
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ iree_lit_test_suite(
"hal_interface_constants.mlir",
"hal_interface_workgroup_info.mlir",
"illegal_configuration.mlir",
"lower_to_ukernel_ops.mlir",
"materialize_aarch64_launch_configuration.mlir",
"materialize_configuration_without_distribution.mlir",
"materialize_riscv_launch_configuration.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iree_lit_test_suite(
"hal_interface_constants.mlir"
"hal_interface_workgroup_info.mlir"
"illegal_configuration.mlir"
"lower_to_ukernel_ops.mlir"
"materialize_aarch64_launch_configuration.mlir"
"materialize_configuration_without_distribution.mlir"
"materialize_riscv_launch_configuration.mlir"
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common/CPU:CommonCPUPasses",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ iree_cc_library(
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::CPU::CommonCPUPasses
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::Utils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -60,9 +59,8 @@ void addVMVXDefaultPassPipeline(OpPassManager &passManager,
addTileAndDistributePasses(passManager);

if (enableUKernels) {
// TODO(hanchung): Move the pass to Codegen/Common/CPU/.
passManager.nest<ModuleOp>().addPass(
createLLVMCPULowerToUKernelsPass(clSkipIntermediateRoundings));
createCPULowerToUKernelsPass(clSkipIntermediateRoundings));
}

// Tensor-level micro-kernel optimizations.
Expand Down

0 comments on commit 5c95ec0

Please sign in to comment.