Skip to content

Commit

Permalink
Migrate StableHLO input conversion to a compiler plugin. (iree-org#15568
Browse files Browse the repository at this point in the history
)

Progress on iree-org#15468

StableHLO-specific changes:

* Added `input_stablehlo` plugin at
`compiler/plugins/input/StableHLO/stablehlo-iree/`
* Moved StableHLO input conversion code into the plugin
* Note file paths: `stablehlo-iree/Conversion` is used instead of
`stablehlo-iree/InputConversion` to save on lengths
* Torch and TOSA both have fewer characters and do not have a
`InputConversion/Preprocessing/` subfolder so they don't hit the same
limit

Cleanup now that all input dialect conversion is handled via plugins
consistently:

* Deleted `init_input_dialects.[h, cc]` and `init_input_passes.[h, cc]`
* Removed dialect-specific code paths from
`compiler/Pipelines/[Pipelines, Options].cpp`
* Generated more `CMakeList.txt` files from `BUILD.bazel` files (input
dialect dependencies and source files are organized into plugins with
their own top-level filtering)
* Added `auto_input_conversion.mlir` tests for each input plugin that
use `--compile-to=input` to show how the "auto" input type handles their
dialects
  • Loading branch information
ScottTodd authored Nov 14, 2023
1 parent 643b467 commit ebb5b7d
Show file tree
Hide file tree
Showing 109 changed files with 556 additions and 601 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
/compiler/src/iree/compiler/Dialect/Vulkan/ @antiagainst
/compiler/src/iree/compiler/GlobalOptimization/ @hanhanW
/compiler/src/iree/compiler/InputConversion/ @MaheshRavishankar @stellaraccident
/compiler/src/iree/compiler/InputConversion/StableHLO/ @hanhanW @MaheshRavishankar @rsuderman
/compiler/plugins/input/StableHLO/ @hanhanW @MaheshRavishankar @rsuderman
/compiler/plugins/input/TOSA/ @MaheshRavishankar @rsuderman

# Runtime
Expand Down
11 changes: 11 additions & 0 deletions compiler/plugins/input/StableHLO/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
22 changes: 22 additions & 0 deletions compiler/plugins/input/StableHLO/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}")
set(IREE_PACKAGE_ROOT_PREFIX "")
set(IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}")

add_library(stablehlo-iree_compiler_defs INTERFACE)
target_include_directories(stablehlo-iree_compiler_defs
INTERFACE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
)

# Configures all iree_cc_* targets to take this implicit dep,
# which provides common includes and copts for the tree.
set(IREE_IMPLICIT_DEFS_CC_DEPS stablehlo-iree_compiler_defs)

add_subdirectory(stablehlo-iree)
39 changes: 39 additions & 0 deletions compiler/plugins/input/StableHLO/stablehlo-iree/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_compiler_register_plugin")

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_compiler_register_plugin(
plugin_id = "input_stablehlo",
target = ":registration",
)

iree_compiler_cc_library(
name = "registration",
srcs = [
"PluginRegistration.cpp",
],
copts = [
"-Icompiler/plugins/input/StableHLO",
"-I$(GENDIR)/compiler/plugins/input/StableHLO",
],
deps = [
"//compiler/plugins/input/StableHLO/stablehlo-iree/Conversion",
"//compiler/src/iree/compiler/PluginAPI",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
],
)
22 changes: 22 additions & 0 deletions compiler/plugins/input/StableHLO/stablehlo-iree/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
iree_add_all_subdirs()

iree_compiler_register_plugin(
PLUGIN_ID
input_stablehlo
TARGET
::registration
)

iree_cc_library(
NAME
registration
SRCS
"PluginRegistration.cpp"
DEPS
MLIRIR
MLIRPass
MLIRTransforms
iree::compiler::PluginAPI
stablehlo-iree::Conversion::Conversion
PUBLIC
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ iree_compiler_cc_library(
"Passes.h.inc",
"Rewriters.h",
],
copts = [
"-Icompiler/plugins/input/StableHLO",
"-I$(GENDIR)/compiler/plugins/input/StableHLO",
],
deps = [
":PassesIncGen",
"@llvm-project//mlir:Pass",
Expand Down Expand Up @@ -81,13 +85,17 @@ iree_compiler_cc_library(
"TypeConversion.h",
"VerifyCompilerInputLegality.cpp",
],
copts = [
"-Icompiler/plugins/input/StableHLO",
"-I$(GENDIR)/compiler/plugins/input/StableHLO",
],
deps = [
":CHLODecompositionPatterns",
":PassHeaders",
"//compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing",
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -122,24 +130,25 @@ iree_compiler_cc_library(
)

iree_compiler_cc_library(
name = "StableHLO",
name = "Conversion",
srcs = [
"Passes.cpp",
],
hdrs = [
"Passes.h",
],
defines = [
"IREE_HAVE_STABLEHLO_INPUT",
copts = [
"-Icompiler/plugins/input/StableHLO",
"-I$(GENDIR)/compiler/plugins/input/StableHLO",
],
deps = [
":PassHeaders",
":StableHLOLegalization",
"//compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/InputConversion/Common",
"//compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MLProgramDialect",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
# TODO(scotttodd): generate this file
# Need a mapping for stablehlo-iree::Conversion::Preprocessing

# Add this tablegen include to support CHLO rewrites with DRR.
list(APPEND IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${IREE_SOURCE_DIR}/third_party/stablehlo")

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_ABOVE_THIS_LINE ###
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_tablegen_library(
Expand All @@ -26,6 +18,9 @@ iree_tablegen_library(
iree_cc_library(
NAME
PassHeaders
COPTS
"-Icompiler/plugins/input/StableHLO"
"-I$(GENDIR)/compiler/plugins/input/StableHLO"
HDRS
"PassDetail.h"
"Passes.h"
Expand All @@ -50,6 +45,9 @@ iree_tablegen_library(
iree_cc_library(
NAME
StableHLOLegalization
COPTS
"-Icompiler/plugins/input/StableHLO"
"-I$(GENDIR)/compiler/plugins/input/StableHLO"
SRCS
"ConvertCollectives.cpp"
"LegalizeCHLO.cpp"
Expand Down Expand Up @@ -104,14 +102,17 @@ iree_cc_library(
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::InputConversion::StableHLO::Preprocessing
iree::compiler::Utils
stablehlo-iree::Conversion::Preprocessing
PUBLIC
)

iree_cc_library(
NAME
StableHLO
Conversion
COPTS
"-Icompiler/plugins/input/StableHLO"
"-I$(GENDIR)/compiler/plugins/input/StableHLO"
HDRS
"Passes.h"
SRCS
Expand All @@ -133,10 +134,6 @@ iree_cc_library(
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::InputConversion::Common
iree::compiler::InputConversion::StableHLO::Preprocessing
DEFINES
"IREE_HAVE_STABLEHLO_INPUT"
stablehlo-iree::Conversion::Preprocessing
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "iree/compiler/Utils/IndexSet.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops,
// taking care of CHLO's broadcasting semantics

#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -21,14 +18,17 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo-iree/Conversion/Passes.h"
#include "stablehlo-iree/Conversion/Preprocessing/Rewriters.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
#include "stablehlo/dialect/BroadcastUtils.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {

#define GEN_PASS_DEF_LEGALIZECHLO
#include "iree/compiler/InputConversion/StableHLO/Passes.h.inc"
#include "stablehlo-iree/Conversion/Passes.h.inc"

namespace {

Expand Down Expand Up @@ -2225,7 +2225,7 @@ struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
} // namespace

namespace {
#include "iree/compiler/InputConversion/StableHLO/CHLODecompositionPatterns.h.inc"
#include "stablehlo-iree/Conversion/CHLODecompositionPatterns.h.inc"
} // end anonymous namespace

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@

// Implements logic for lowering StableHLO dialect ops to the SCF dialect.

#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo-iree/Conversion/Passes.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {

#define GEN_PASS_DEF_LEGALIZECONTROLFLOW
#include "iree/compiler/InputConversion/StableHLO/Passes.h.inc"
#include "stablehlo-iree/Conversion/Passes.h.inc"

namespace {
// All transformations in this file take stablehlo blocks which end with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@

// Implements logic for lowering StableHLO dialect to scalar shape operations.

#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h"
#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo-iree/Conversion/MapStableHLOToScalarOp.h"
#include "stablehlo-iree/Conversion/Passes.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {

#define GEN_PASS_DEF_LEGALIZESHAPECOMPUTATIONS
#include "iree/compiler/InputConversion/StableHLO/Passes.h.inc"
#include "stablehlo-iree/Conversion/Passes.h.inc"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

// Implements utilities for lowering StableHLO dialect to Linalg dialect.

#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
#include "stablehlo-iree/Conversion/LegalizeToLinalgUtils.h"

#include <algorithm>
#include <numeric>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@

// Utils for lowering of the StableHLO dialect to the Linalg dialect.

#ifndef IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
#define IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
#ifndef STABLEHLO_IREE_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
#define STABLEHLO_IREE_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_

#include <algorithm>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
Expand All @@ -35,6 +34,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo-iree/Conversion/MapStableHLOToScalarOp.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
Expand Down Expand Up @@ -110,4 +110,4 @@ inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
#endif // STABLEHLO_IREE_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
Loading

0 comments on commit ebb5b7d

Please sign in to comment.