Skip to content

Commit

Permalink
Convert MetalSPIRV compiler target into a plugin. (iree-org#15635)
Browse files Browse the repository at this point in the history
Progress on iree-org#15468
  • Loading branch information
ScottTodd authored Nov 17, 2023
1 parent 5b2cb64 commit f5792fd
Show file tree
Hide file tree
Showing 17 changed files with 88 additions and 68 deletions.
4 changes: 4 additions & 0 deletions compiler/plugins/iree_compiler_plugin.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ if(IREE_TARGET_BACKEND_CUDA)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/CUDA target/CUDA)
endif()

if(IREE_TARGET_BACKEND_METAL_SPIRV)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/MetalSPIRV target/MetalSPIRV)
endif()

if(IREE_TARGET_BACKEND_WEBGPU)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/WebGPU target/WebGPU)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_compiler_cc_library")
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_compiler_register_plugin")

package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_cmake_extra_content(
content = """
if(NOT IREE_TARGET_BACKEND_METAL_SPIRV)
return()
endif()
""",
iree_compiler_register_plugin(
plugin_id = "hal_target_metal_spirv",
target = ":MetalSPIRV",
)

iree_compiler_cc_library(
Expand All @@ -26,12 +23,14 @@ iree_compiler_cc_library(
hdrs = ["MetalSPIRVTarget.h"],
deps = [
":MSLToMetalLib",
":MetalTargetPlatform",
":SPIRVToMSL",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/SPIRV",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:metal_executable_def_c_fbs",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/BUILD.bazel #
# compiler/plugins/target/MetalSPIRV/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

if(NOT IREE_TARGET_BACKEND_METAL_SPIRV)
return()
endif()

iree_add_all_subdirs()

iree_compiler_register_plugin(
PLUGIN_ID
hal_target_metal_spirv
TARGET
::MetalSPIRV
)

iree_cc_library(
NAME
MetalSPIRV
Expand All @@ -23,6 +26,7 @@ iree_cc_library(
"MetalSPIRVTarget.cpp"
DEPS
::MSLToMetalLib
::MetalTargetPlatform
::SPIRVToMSL
LLVMSupport
LLVMTargetParser
Expand All @@ -37,6 +41,7 @@ iree_cc_library(
iree::compiler::Codegen::SPIRV
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::Target
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::metal_executable_def_c_fbs
PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MSLToMetalLib.h"
#include "./MSLToMetalLib.h"

#include <stdlib.h>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_METALSPIRV_MSLTOMETALLIB_H_
#define IREE_COMPILER_DIALECT_HAL_TARGET_METALSPIRV_MSLTOMETALLIB_H_

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalTargetPlatform.h"
#include "./MetalTargetPlatform.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
#include "./MetalSPIRVTarget.h"

#include "./MSLToMetalLib.h"
#include "./MetalTargetPlatform.h"
#include "./SPIRVToMSL.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MSLToMetalLib.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/metal_executable_def_builder.h"
#include "llvm/Support/MemoryBuffer.h"
Expand All @@ -31,21 +33,30 @@ namespace iree_compiler {
namespace IREE {
namespace HAL {

llvm::cl::opt<MetalTargetPlatform> clTargetPlatform(
"iree-metal-target-platform", llvm::cl::desc("Apple platform to target"),
llvm::cl::values(
clEnumValN(MetalTargetPlatform::macOS, "macos", "macOS platform"),
clEnumValN(MetalTargetPlatform::iOS, "ios", "iOS platform"),
clEnumValN(MetalTargetPlatform::iOSSimulator, "ios-simulator",
"iOS simulator platform")),
llvm::cl::init(MetalTargetPlatform::macOS));

static llvm::cl::opt<bool> clCompileToMetalLib(
"iree-metal-compile-to-metallib",
llvm::cl::desc(
"Compile to .metallib and embed in IREE deployable flatbuffer if true; "
"otherwise stop at and embed MSL source code"),
llvm::cl::init(true));
struct MetalSPIRVOptions {
MetalTargetPlatform clTargetPlatform;
bool clCompileToMetalLib;

void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("MetalSPIRV HAL Target");
binder.opt<MetalTargetPlatform>(
"iree-metal-target-platform", clTargetPlatform, llvm::cl::cat(category),
llvm::cl::desc("Apple platform to target"),
llvm::cl::values(
clEnumValN(MetalTargetPlatform::macOS, "macos", "macOS platform"),
clEnumValN(MetalTargetPlatform::iOS, "ios", "iOS platform"),
clEnumValN(MetalTargetPlatform::iOSSimulator, "ios-simulator",
"iOS simulator platform")),
llvm::cl::init(MetalTargetPlatform::macOS));
binder.opt<bool>(
"iree-metal-compile-to-metallib", clCompileToMetalLib,
llvm::cl::cat(category),
llvm::cl::desc("Compile to .metallib and embed in IREE deployable "
"flatbuffer if true; "
"otherwise stop at and embed MSL source code"),
llvm::cl::init(true));
}
};

static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {
using spirv::Capability;
Expand Down Expand Up @@ -103,7 +114,8 @@ static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {

class MetalSPIRVTargetBackend : public TargetBackend {
public:
MetalSPIRVTargetBackend() = default;
MetalSPIRVTargetBackend(MetalSPIRVOptions options)
: options_(std::move(options)) {}

// NOTE: we could vary this based on the options such as 'metal-v2'.
std::string name() const override { return "metal"; }
Expand Down Expand Up @@ -188,7 +200,8 @@ class MetalSPIRVTargetBackend : public TargetBackend {
// We can use ArrayRef here given spvBinary reserves 0 bytes on stack.
ArrayRef spvData(spvBinary.data(), spvBinary.size());
std::optional<std::pair<MetalShader, std::string>> msl =
crossCompileSPIRVToMSL(clTargetPlatform, spvData, entryPoint);
crossCompileSPIRVToMSL(options_.clTargetPlatform, spvData,
entryPoint);
if (!msl) {
return variantOp.emitError()
<< "failed to cross compile SPIR-V to Metal shader";
Expand All @@ -208,16 +221,16 @@ class MetalSPIRVTargetBackend : public TargetBackend {

// 3. Compile MSL to MTLLibrary.
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> metalLibs;
if (clCompileToMetalLib) {
if (options_.clCompileToMetalLib) {
// We need to use offline Metal shader compilers.
// TODO(#14048): The toolchain can also exist on other platforms. Probe
// the PATH instead.
auto hostTriple = llvm::Triple(llvm::sys::getProcessTriple());
if (hostTriple.isMacOSX()) {
for (auto [shader, entryPoint] :
llvm::zip(mslShaders, mslEntryPointNames)) {
std::unique_ptr<llvm::MemoryBuffer> lib =
compileMSLToMetalLib(clTargetPlatform, shader.source, entryPoint);
std::unique_ptr<llvm::MemoryBuffer> lib = compileMSLToMetalLib(
options_.clTargetPlatform, shader.source, entryPoint);
if (!lib) {
return variantOp.emitError()
<< "failed to compile to MTLLibrary from MSL:\n\n"
Expand Down Expand Up @@ -300,19 +313,35 @@ class MetalSPIRVTargetBackend : public TargetBackend {
context, b.getStringAttr("metal"), b.getStringAttr("metal-msl-fb"),
configAttr);
}

MetalSPIRVOptions options_;
};

void registerMetalSPIRVTargetBackends() {
auto backendFactory = [=]() {
return std::make_shared<MetalSPIRVTargetBackend>();
};
// #hal.device.target<"metal", ...
static TargetBackendRegistration registration0("metal", backendFactory);
// #hal.executable.target<"metal-spirv", ...
static TargetBackendRegistration registration1("metal-spirv", backendFactory);
}
struct MetalSPIRVSession
: public PluginSession<MetalSPIRVSession, MetalSPIRVOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
auto backendFactory = [=]() {
return std::make_shared<MetalSPIRVTargetBackend>(options);
};
// #hal.device.target<"metal", ...
targets.add("metal", backendFactory);
// #hal.executable.target<"metal-spirv", ...
targets.add("metal-spirv", backendFactory);
}
};

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

extern "C" bool iree_register_compiler_plugin_hal_target_metal_spirv(
mlir::iree_compiler::PluginRegistrar *registrar) {
registrar->registerPlugin<mlir::iree_compiler::IREE::HAL::MetalSPIRVSession>(
"hal_target_metal_spirv");
return true;
}

IREE_DEFINE_COMPILER_OPTION_FLAGS(
mlir::iree_compiler::IREE::HAL::MetalSPIRVOptions);
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "./SPIRVToMSL.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <optional>
#include <string>

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalTargetPlatform.h"
#include "./MetalTargetPlatform.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Support/LLVM.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/test/BUILD.bazel #
# compiler/plugins/target/MetalSPIRV/test/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/PluginAPI/Config/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ genrule(
cmd = (
"echo '" +
"HANDLE_PLUGIN_ID(hal_target_cuda)\n" +
"HANDLE_PLUGIN_ID(hal_target_metal_spirv)\n" +
"HANDLE_PLUGIN_ID(input_tosa)\n" +
"HANDLE_PLUGIN_ID(input_stablehlo)\n" +
# Samples
Expand All @@ -42,6 +43,7 @@ iree_compiler_cc_library(
# generates its deps from the environment.
# For now, we just hard include all in-tree plugins.
"//compiler/plugins/target/CUDA",
"//compiler/plugins/target/MetalSPIRV",
"//compiler/plugins/input/StableHLO/stablehlo-iree:registration",
"//compiler/plugins/input/TOSA/tosa-iree:registration",
"//samples/compiler_plugins/example:registration",
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,12 @@ iree_compiler_cc_library(
hdrs = ["init_targets.h"],
local_defines = [
"IREE_HAVE_LLVM_CPU_TARGET",
"IREE_HAVE_METALSPIRV_TARGET",
"IREE_HAVE_ROCM_TARGET",
"IREE_HAVE_VMVX_TARGET",
"IREE_HAVE_VULKANSPIRV_TARGET",
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU",
"//compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV",
"//compiler/src/iree/compiler/Dialect/HAL/Target/ROCM",
"//compiler/src/iree/compiler/Dialect/HAL/Target/VMVX",
"//compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV",
Expand Down
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ if(IREE_TARGET_BACKEND_LLVM_CPU)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::LLVMCPU)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_LLVM_CPU_TARGET")
endif()
if(IREE_TARGET_BACKEND_METAL_SPIRV)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::MetalSPIRV)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_METALSPIRV_TARGET")
endif()
if(IREE_TARGET_BACKEND_VMVX)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VMVX)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VMVX_TARGET")
Expand Down
13 changes: 0 additions & 13 deletions compiler/src/iree/compiler/Tools/init_targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
#ifdef IREE_HAVE_LLVM_CPU_TARGET
#include "iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.h"
#endif // IREE_HAVE_LLVM_CPU_TARGET
#ifdef IREE_HAVE_METALSPIRV_TARGET
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
#endif // IREE_HAVE_METALSPIRV_TARGET
#ifdef IREE_HAVE_ROCM_TARGET
#include "iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h"
#endif // IREE_HAVE_ROCM_TARGET
Expand All @@ -23,9 +20,6 @@
#ifdef IREE_HAVE_VULKANSPIRV_TARGET
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
#endif // IREE_HAVE_VULKANSPIRV_TARGET
#ifdef IREE_HAVE_WEBGPU_TARGET
#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"
#endif // IREE_HAVE_WEBGPU_TARGET

namespace mlir {
namespace iree_compiler {
Expand All @@ -41,9 +35,6 @@ void registerHALTargetBackends() {
IREE::HAL::registerLLVMCPUTargetBackends(
[]() { return IREE::HAL::LLVMTargetOptions::getFromFlags(); });
#endif // IREE_HAVE_LLVM_CPU_TARGET
#ifdef IREE_HAVE_METALSPIRV_TARGET
IREE::HAL::registerMetalSPIRVTargetBackends();
#endif // IREE_HAVE_METALSPIRV_TARGET
#ifdef IREE_HAVE_ROCM_TARGET
IREE::HAL::registerROCMTargetBackends();
#endif // IREE_HAVE_ROCM_TARGET
Expand All @@ -54,10 +45,6 @@ void registerHALTargetBackends() {
IREE::HAL::registerVulkanSPIRVTargetBackends(
[]() { return IREE::HAL::getVulkanSPIRVTargetOptionsFromFlags(); });
#endif // IREE_HAVE_VULKANSPIRV_TARGET
#ifdef IREE_HAVE_WEBGPU_TARGET
IREE::HAL::registerWebGPUTargetBackends(
[]() { return IREE::HAL::getWebGPUTargetOptionsFromFlags(); });
#endif // IREE_HAVE_WEBGPU_TARGET
return true;
}();
(void)init_once;
Expand Down

0 comments on commit f5792fd

Please sign in to comment.