Add pass for lowering to accel ukernels.
Add lit test and fix build.

Fixes to LowerToAccelUKernelsPass

Tweaks to LowerToAccelUKernelsPass.

Add AccelMatmulExpert pass pipeline

Apply clang-format to new C++ files. (#3)

- Apply clang-format.

Use 'accel' identifier

Use parameter struct calling convention

Tweaks to KernelDispatch and AccelMatmulExpert pipeline

Co-authored-by: Sungsoon Cho <[email protected]>
monorimet and godot73 committed Nov 16, 2023
1 parent 9c5af19 commit 326100b
Showing 13 changed files with 245 additions and 2 deletions.
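
At a high level, the commit routes linalg.matmul dispatches to a new AccelMatmulExpert pipeline and, when accel microkernels are enabled (the new --iree-llvmcpu-enable-accel-ukernels flag), rewrites each matmul into an iree_codegen.ukernel.generic call named "accel_matmul_f32". The sketch below illustrates that rewrite; the printed IR form, SSA names, and omitted attributes are assumptions for illustration, not output captured from this commit (the import attributes appear in a second sketch after the new pass source below).

// Before: a plain matmul on dynamic f32 tensors.
%0 = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                   outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>

// After iree-llvmcpu-lower-to-accel-ukernels: the matmul becomes a generic
// microkernel call that forwards the operands plus the M/N/K extents.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%m = tensor.dim %lhs, %c0 : tensor<?x?xf32>   // M
%n = tensor.dim %rhs, %c1 : tensor<?x?xf32>   // N
%k = tensor.dim %rhs, %c0 : tensor<?x?xf32>   // K
%1 = iree_codegen.ukernel.generic "accel_matmul_f32"
       ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%acc : tensor<?x?xf32>)
       (%m, %n, %k : index, index, index) -> tensor<?x?xf32>
// (fn_def_attrs and strided_outer_dims omitted here; see the sketch after
// the pass implementation below.)
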
@@ -64,6 +64,8 @@ def SPIRV_WinogradVectorize

def VMVX_Default : I32EnumAttrCase<"VMVXDefault", 300>;

def CPU_AccelMatmulExpert
: I32EnumAttrCase<"AccelMatmulExpert", 25>;

def Linalg_TransformDialectCodegen
: I32EnumAttrCase<"TransformDialectCodegen", 1000>;
@@ -79,7 +81,7 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<
CPU_Default, CPU_DoubleTilingExpert, CPU_DoubleTilingPadExpert,
CPU_DoubleTilingPeelingExpert, CPU_ConvTileAndDecomposeExpert,
CPU_Mmt4dTilingExpert, CPU_BufferOpsTileAndVectorize,
CPU_DataTiling,
CPU_DataTiling, CPU_AccelMatmulExpert,

// LLVMGPU CodeGen pipelines
LLVMGPU_Default, LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize,
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
@@ -60,6 +60,7 @@ iree_compiler_cc_library(
"LLVMCPUFoldVectorContractUnitDims.cpp",
"LLVMCPULinkExecutables.cpp",
"LLVMCPULowerExecutableTarget.cpp",
"LLVMCPULowerToAccelUKernels.cpp",
"LLVMCPUMmt4dVectorLowering.cpp",
"LLVMCPUPeel.cpp",
"LLVMCPUSelectLoweringStrategy.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -61,6 +61,7 @@ iree_cc_library(
"LLVMCPUFoldVectorContractUnitDims.cpp"
"LLVMCPULinkExecutables.cpp"
"LLVMCPULowerExecutableTarget.cpp"
"LLVMCPULowerToAccelUKernels.cpp"
"LLVMCPUMmt4dVectorLowering.cpp"
"LLVMCPUPeel.cpp"
"LLVMCPUSelectLoweringStrategy.cpp"
14 changes: 13 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1267,6 +1267,18 @@ static LogicalResult setRootConfig(func::FuncOp entryPointFn,
DispatchLoweringPassPipeline::Mmt4dTilingExpert);
}

/// Sets the lowering configuration for a dispatch region with a linalg.matmul
/// root op.
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
linalg::MatmulOp matmulOp) {
assert(!getLoweringConfig(matmulOp) && "expected lowering_config is not set");
SmallVector<int64_t> tileSizes;
tileSizes.push_back(1);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, matmulOp, tileSizes,
DispatchLoweringPassPipeline::AccelMatmulExpert);
}

/// Sets the lowering configuration for dispatch region for linalg.batch_mmt4d
/// root op
static LogicalResult setRootConfig(func::FuncOp entryPointFn,
@@ -2103,7 +2115,7 @@ setRootConfigImpl(func::FuncOp entryPointFn, Operation *op,
targetMLTransInfo);
})
.Case<IREE::LinalgExt::FftOp, tensor::PackOp, tensor::PadOp,
linalg::Mmt4DOp, linalg::BatchMmt4DOp>(
linalg::Mmt4DOp, linalg::MatmulOp, linalg::BatchMmt4DOp>(
[&](auto op) { return setRootConfig(entryPointFn, op); })
.Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNchwFchwOp,
linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
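
With this hook in KernelDispatch, a linalg.matmul root gets a single distribution tile size and is routed to the new pipeline. The effect on the dispatch is roughly the annotation sketched below; the exact lowering_config / translation_info attribute spellings are assumptions based on the usual iree_codegen attributes and may differ in detail.

// Hypothetical root op after setRootConfig (illustrative only).
%0 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1]]>}
       ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
// ...with the surrounding export tagged:
//   translation_info = #iree_codegen.translation_info<AccelMatmulExpert>
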
compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -234,6 +234,13 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
default:
moduleOp.emitOpError("Unsupported pipeline on CPU target.");
return signalPassFailure();

case IREE::Codegen::DispatchLoweringPassPipeline::AccelMatmulExpert: {
TilingConfig tilingConfig = getTilingConfigForPipeline(moduleOp);
addAccelMatmulExpertPassPipeline(executableLoweringPipeline,
tilingConfig,
enableAccelMicrokernels);
break;
}
}

142 changes: 142 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToAccelUKernels.cpp
@@ -0,0 +1,142 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {

namespace {

class LLVMCPULowerToAccelUKernelsPass
: public LLVMCPULowerToAccelUKernelsBase<LLVMCPULowerToAccelUKernelsPass> {
public:
LLVMCPULowerToAccelUKernelsPass() = default;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Codegen::IREECodegenDialect>();
}

void runOnOperation() override;

LogicalResult initializeOptions(StringRef options) override {
if (failed(Pass::initializeOptions(options))) {
return failure();
}
return success();
}
};

/// Holds a function name and attributes.
struct FnNameAndDefAttrs {
std::string name;
SmallVector<NamedAttribute> defAttrs;
};

/// Returns the function name and attributes to use for a ukernel with given
/// `ukernelName` on the target described by `targetAttr`.
static FnNameAndDefAttrs
getFnNameAndDefAttrs(const char *ukernelName, RewriterBase &rewriter,
IREE::HAL::ExecutableTargetAttr targetAttr) {
FnNameAndDefAttrs result;
result.name = ukernelName;
result.defAttrs.emplace_back(
rewriter.getStringAttr("hal.import.fields"),
rewriter.getArrayAttr({rewriter.getStringAttr("processor_data"),
rewriter.getStringAttr("processor_id")}));
result.defAttrs.emplace_back(
rewriter.getStringAttr("hal.import.cconv"),
IREE::HAL::CallingConventionAttr::get(
rewriter.getContext(),
IREE::HAL::CallingConvention::ParameterStruct));
return result;
}

/// Matches a (linalg.fill -> )? linalg.matmul operation sequence and converts
/// it into an iree_codegen.ukernel.generic "accel_matmul_f32" operation, which
/// is later lowered into a call to the accelerator microkernel.
static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, linalg::MatmulOp op) {
Value lhs = op.getDpsInputOperand(0)->get();
Value rhs = op.getDpsInputOperand(1)->get();
Value out = op.getDpsInitOperand(0)->get();
auto outType = llvm::cast<ShapedType>(out.getType());

Location loc = op.getLoc();
Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
// For A(MxK) * B(KxN): N is dim 1 of the RHS and K is dim 0 of the RHS.
Value n = rewriter.create<tensor::DimOp>(loc, rhs, 1);
Value k = rewriter.create<tensor::DimOp>(loc, rhs, 0);

auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("accel_matmul_f32", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
loc, outType, fn.name, ValueRange{lhs, rhs}, out, ValueRange{m, n, k},
/*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
}

template <typename OpType>
struct LowerToAccelUKernelPattern : OpRewritePattern<OpType> {
LowerToAccelUKernelPattern(MLIRContext *context)
: OpRewritePattern<OpType>(context) {}

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
FailureOr<IREE::Codegen::UKernelOpInterface> ukernelOp =
matchDAGForUKernel(rewriter, op);
if (failed(ukernelOp)) {
return rewriter.notifyMatchFailure(
op, "failed to find microkernel op to replace with");
}
rewriter.replaceOp(op, ukernelOp.value()->getResults());
return success();
}
};

void LLVMCPULowerToAccelUKernelsPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
// Enabling a lowering of an op to a microkernel is a trade-off between the
// potential performance advantage of a microkernel over pure code generation
// for that op, and the potential benefits of fusions. Indeed, once an op is
// lowered into a microkernel, it will never be fused at any MLIR level.
// Since microkernels are linked as bitcode, they will still undergo LTO-like
// optimization in their calling contexts, but we shouldn't expect this to
// achieve similar results as fusing structured ops.
patterns.insert<LowerToAccelUKernelPattern<linalg::MatmulOp>>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // namespace

std::unique_ptr<OperationPass<>> createLLVMCPULowerToAccelUKernelsPass() {
return std::make_unique<LLVMCPULowerToAccelUKernelsPass>();
}

} // namespace iree_compiler
} // namespace mlir
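
getFnNameAndDefAttrs attaches the import metadata that the HAL side needs: the parameter-struct calling convention plus the processor_data/processor_id fields. On the generated op this surfaces as fn_def_attrs, roughly as sketched below; the attribute spellings are assumptions, and only the field names and the ParameterStruct convention come from the code above.

// Illustrative result of matchDAGForUKernel (attribute syntax approximate).
%4 = iree_codegen.ukernel.generic "accel_matmul_f32"
       ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%acc : tensor<?x?xf32>)
       (%m, %n, %k : index, index, index)
       fn_def_attrs {hal.import.cconv = #hal.calling_convention<ParameterStruct>,
                     hal.import.fields = ["processor_data", "processor_id"]}
       strided_outer_dims(0) -> tensor<?x?xf32>
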
31 changes: 31 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -58,6 +58,11 @@ static llvm::cl::opt<bool> clEnablePadConsumerFusion(
llvm::cl::desc("Flag to enable the fusion for pad + consumer"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableAccelMicrokernels(
"iree-llvmcpu-enable-accel-ukernels",
llvm::cl::desc("Flag to enable lowering to accelUkernels"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableReassociateFpReductions(
"iree-llvmcpu-reassociate-fp-reductions",
llvm::cl::desc("Enables reassociation for FP reductions"),
@@ -562,6 +567,32 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
}
}

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels) {
addTileAndDistributePasses(passManager);

OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();

if (enableAccelMicrokernels) {
nestedModulePM.addPass(createLLVMCPULowerToAccelUKernelsPass());
} else {
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTileAndFusePass(
static_cast<int64_t>(tilingConfig.getVectorCommonParallelLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(createLLVMCPUTilePass(
static_cast<int64_t>(tilingConfig.getVectorReductionLevel())));
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createHoistRedundantVectorTransfersPass());
}

nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());

addBufferizePasses(nestedModulePM);
}

void addCPUDataTilingPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableVectorMasking) {
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -53,6 +53,15 @@ createLLVMCPULowerExecutableTargetPass();
/// Can handle more operations if required in the future.
std::unique_ptr<Pass> createExpandF16OpToF32Pass();

/// Pass to lower a sequence of operations to an iree_codegen.ukernel.*
/// operation.
std::unique_ptr<OperationPass<>>
createLLVMCPULowerToUKernelsPass(bool skipIntermediateRoundings = true);

/// Pass to lower a sequence of operations to an iree_codegen.ukernel.*
/// operation targeting an accelerator microkernel.
std::unique_ptr<OperationPass<>> createLLVMCPULowerToAccelUKernelsPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createLLVMCPUMmt4dVectorLoweringPass();

@@ -165,6 +174,10 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableMicrokernels);

void addAccelMatmulExpertPassPipeline(OpPassManager &passManager,
TilingConfig &tilingConfig,
bool enableAccelMicrokernels);

void addMultiTilingExpertPassPipeline(
OpPassManager &passManager, TilingConfig &tilingConfig, bool enablePeeling,
bool enableVectorMasking, bool lowerToAVX2, bool enableAArch64SSVE = false);
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -92,6 +92,14 @@ def LLVMCPULowerExecutableTarget :
"mlir::iree_compiler::createLLVMCPULowerExecutableTargetPass()";
}

def LLVMCPULowerToAccelUKernels :
Pass<"iree-llvmcpu-lower-to-accel-ukernels", ""> {
let summary =
"Separate out parts of the IR that lower to an accel-micro-kernel";
let constructor =
"mlir::iree_compiler::createLLVMCPULowerToAccelUKernelsPass()";
}

def LLVMCPUMmt4dVectorLowering
: Pass<"iree-llvmcpu-mmt4d-vector-lowering", "func::FuncOp"> {
let summary = "Apply vector lowering logic to vector ops";
compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
@@ -35,6 +35,7 @@ iree_lit_test_suite(
"hal_interface_constants.mlir",
"hal_interface_workgroup_info.mlir",
"illegal_configuration.mlir",
"lower_to_accel_ukernel_ops.mlir",
"materialize_aarch64_launch_configuration.mlir",
"materialize_configuration_without_distribution.mlir",
"materialize_riscv_launch_configuration.mlir",
compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -30,6 +30,7 @@ iree_lit_test_suite(
"hal_interface_constants.mlir"
"hal_interface_workgroup_info.mlir"
"illegal_configuration.mlir"
"lower_to_accel_ukernel_ops.mlir"
"materialize_aarch64_launch_configuration.mlir"
"materialize_configuration_without_distribution.mlir"
"materialize_riscv_launch_configuration.mlir"
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_accel_ukernel_ops.mlir
@@ -0,0 +1,23 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-lower-to-accel-ukernels,cse,canonicalize))" %s | FileCheck %s

func.func @matmul_f32f32f32(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK: func @matmul_f32f32f32(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "accel_matmul_f32"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]] :
// CHECK-DAG: "processor_id"
// CHECK-DAG: "processor_data"
// CHECK: return %[[MICRO_KERNEL]]
1 change: 1 addition & 0 deletions samples/custom_dispatch/cpu/plugin/CMakeLists.txt
@@ -21,6 +21,7 @@ target_include_directories(iree_samples_custom_dispatch_cpu_system_plugin
${IREE_SOURCE_DIR}/runtime/src/
)

iree_add_all_subdirs()
# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
set_target_properties(iree_samples_custom_dispatch_cpu_system_plugin