Skip to content

Commit

Permalink
Convert WebGPU compiler target into a plugin. (iree-org#15612)
Browse files Browse the repository at this point in the history
Progress on iree-org#15468

I didn't do anything fancy with includes or namespaces for this, just
relative includes with `./` (same as target/CUDA).
* Unit test is at
`iree/compiler/plugins/target/WebGPU/test/smoketest.mlir.test`
* Unit tests for input plugins are at e.g.
`tosa-iree/InputConversion/test/auto_input_conversion.mlir.test`

I also moved the deps that are unique to the WebGPU target into the
plugin, keeping the root `CMakeLists.txt` and shared
`build_tools/third_party/` directories cleaner.
  • Loading branch information
ScottTodd authored Nov 16, 2023
1 parent 073bead commit ef8982c
Show file tree
Hide file tree
Showing 19 changed files with 122 additions and 140 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
[submodule "third_party/vulkan_headers"]
path = third_party/vulkan_headers
url = https://github.com/KhronosGroup/Vulkan-Headers.git
[submodule "third_party/spirv_headers"]
path = third_party/spirv_headers
url = https://github.com/KhronosGroup/SPIRV-Headers.git
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
Expand Down
18 changes: 0 additions & 18 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -927,24 +927,6 @@ if(IREE_TARGET_BACKEND_METAL_SPIRV)
add_subdirectory(third_party/spirv_cross EXCLUDE_FROM_ALL)
endif()

if(IREE_TARGET_BACKEND_WEBGPU)
# Tint is needed to compile SPIR-V into WGSL source code.
# Tint also requires SPIRV-Tools, which requires SPIRV-Headers.

# NOTE: these can be synced by referencing one of these repositories:
# * https://dawn.googlesource.com/dawn/
# * https://dawn.googlesource.com/tint/
# * https://chromium.googlesource.com/vulkan-deps/+/refs/heads/main/DEPS
# or they can be updated independently
set(IREE_TINT_TAG "fdb8787e9c1b79770bd98a8faf37fbe48a3077a4") # 2023-03-06
set(IREE_SPIRV_TOOLS_TAG "43b8886490eb6af81fc61e0ff071c51a922af864") # 2023-08-11

iree_set_spirv_headers_cmake_options()
add_subdirectory(third_party/spirv_headers EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/spirv-tools EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/tint EXCLUDE_FROM_ALL)
endif()

#-------------------------------------------------------------------------------
# IREE top-level libraries
#-------------------------------------------------------------------------------
Expand Down
6 changes: 0 additions & 6 deletions build_tools/bazel/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,6 @@ def configure_iree_submodule_deps(iree_repo_alias = "@", iree_path = "./"):
path = paths.join(iree_path, "third_party/vulkan_headers"),
)

maybe(
native.local_repository,
name = "spirv_headers",
path = paths.join(iree_path, "third_party/spirv_headers"),
)

maybe(
native.local_repository,
name = "stablehlo",
Expand Down
5 changes: 0 additions & 5 deletions build_tools/cmake/iree_external_cmake_options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ macro(iree_set_googletest_cmake_options)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
endmacro()

macro(iree_set_spirv_headers_cmake_options)
set(SPIRV_HEADERS_SKIP_EXAMPLES ON CACHE BOOL "" FORCE)
set(SPIRV_HEADERS_SKIP_INSTALL ON CACHE BOOL "" FORCE)
endmacro()

macro(iree_set_spirv_cross_cmake_options)
set(SPIRV_CROSS_ENABLE_MSL ON CACHE BOOL "" FORCE)
set(SPIRV_CROSS_ENABLE_GLSL ON CACHE BOOL "" FORCE) # Required to enable MSL
Expand Down
1 change: 0 additions & 1 deletion build_tools/scripts/git/runtime_submodules.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ third_party/googletest
third_party/libyaml
third_party/musl
third_party/spirv_cross
third_party/spirv_headers
third_party/tracy
third_party/vulkan_headers
third_party/webgpu-headers
4 changes: 4 additions & 0 deletions compiler/plugins/iree_compiler_plugin.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ endif()
if(IREE_TARGET_BACKEND_CUDA)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/CUDA target/CUDA)
endif()

if(IREE_TARGET_BACKEND_WEBGPU)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/WebGPU target/WebGPU)
endif()
58 changes: 58 additions & 0 deletions compiler/plugins/target/WebGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Tint is needed to compile SPIR-V into WGSL source code.
# Tint also requires SPIRV-Tools, which requires SPIRV-Headers.
#
# NOTE: these can be synced by referencing one of these repositories:
# * https://dawn.googlesource.com/dawn/
# * https://dawn.googlesource.com/tint/
# * https://chromium.googlesource.com/vulkan-deps/+/refs/heads/main/DEPS
# or they can be updated independently
set(IREE_TINT_TAG "fdb8787e9c1b79770bd98a8faf37fbe48a3077a4") # 2023-03-06
set(IREE_SPIRV_HEADERS_TAG "b730938c033ede3572b660ab019b438509ba24d9") # 2023-08-10
set(IREE_SPIRV_TOOLS_TAG "43b8886490eb6af81fc61e0ff071c51a922af864") # 2023-08-11
message(STATUS "Configuring WebGPU target deps (SPIRV-Headers, SPIRV-Tools, Tint)")
list(APPEND CMAKE_MESSAGE_INDENT " ")
add_subdirectory(spirv-headers EXCLUDE_FROM_ALL)
add_subdirectory(spirv-tools EXCLUDE_FROM_ALL)
add_subdirectory(tint EXCLUDE_FROM_ALL)
list(POP_BACK CMAKE_MESSAGE_INDENT)

add_subdirectory(test)

iree_compiler_register_plugin(
PLUGIN_ID
hal_target_webgpu
TARGET
::WebGPU
)

iree_cc_library(
NAME
WebGPU
HDRS
"SPIRVToWGSL.h"
SRCS
"SPIRVToWGSL.cpp"
"WebGPUTarget.cpp"
DEPS
LLVMSupport
MLIRGPUDialect
MLIRIR
MLIRSPIRVDialect
MLIRSPIRVSerialization
MLIRSPIRVTransforms
SPIRV-Tools
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::SPIRV
iree::compiler::Dialect::HAL::Target
iree::compiler::PluginAPI
iree::compiler::Utils
iree::schemas::wgsl_executable_def_c_fbs
libtint
PUBLIC
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
#include "./SPIRVToWGSL.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.h"

#include "./SPIRVToWGSL.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/WGSL/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Target/WebGPU/SPIRVToWGSL.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/wgsl_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
Expand All @@ -31,18 +30,19 @@ namespace iree_compiler {
namespace IREE {
namespace HAL {

WebGPUTargetOptions getWebGPUTargetOptionsFromFlags() {
static llvm::cl::opt<bool> clDebugSymbols(
"iree-webgpu-debug-symbols",
llvm::cl::desc(
"Include debug information like variable names in outputs"),
llvm::cl::init(true));
namespace {

WebGPUTargetOptions targetOptions;
targetOptions.debugSymbols = clDebugSymbols;
struct WebGPUOptions {
bool debugSymbols = true;

return targetOptions;
}
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("WebGPU HAL Target");
binder.opt<bool>(
"iree-webgpu-debug-symbols", debugSymbols, llvm::cl::cat(category),
llvm::cl::desc(
"Include debug information like variable names in outputs."));
}
};

// TODO(scotttodd): provide a proper target environment for WebGPU.
static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {
Expand All @@ -58,8 +58,7 @@ static spirv::TargetEnvAttr getWebGPUTargetEnv(MLIRContext *context) {

class WebGPUTargetBackend : public TargetBackend {
public:
WebGPUTargetBackend(WebGPUTargetOptions options)
: options_(std::move(options)) {}
WebGPUTargetBackend(WebGPUOptions options) : options_(std::move(options)) {}

// NOTE: we could vary this based on the options such as 'webgpu-v2'.
std::string name() const override { return "webgpu"; }
Expand Down Expand Up @@ -284,22 +283,36 @@ class WebGPUTargetBackend : public TargetBackend {
configAttr);
}

WebGPUTargetOptions options_;
WebGPUOptions options_;
};

void registerWebGPUTargetBackends(
std::function<WebGPUTargetOptions()> queryOptions) {
getWebGPUTargetOptionsFromFlags();
auto backendFactory = [=]() {
return std::make_shared<WebGPUTargetBackend>(queryOptions());
};
// #hal.device.target<"webgpu", ...
static TargetBackendRegistration registration0("webgpu", backendFactory);
// #hal.executable.target<"webgpu-wgsl", ...
static TargetBackendRegistration registration1("webgpu-wgsl", backendFactory);
}
struct WebGPUSession
: public PluginSession<WebGPUSession, WebGPUOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
auto backendFactory = [=]() {
return std::make_shared<WebGPUTargetBackend>(options);
};
// #hal.device.target<"webgpu", ...
targets.add("webgpu", backendFactory);
// #hal.executable.target<"webgpu-wgsl", ...
targets.add("webgpu-wgsl", backendFactory);
}
};

} // namespace

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

IREE_DEFINE_COMPILER_OPTION_FLAGS(
mlir::iree_compiler::IREE::HAL::WebGPUOptions);

extern "C" bool iree_register_compiler_plugin_hal_target_webgpu(
mlir::iree_compiler::PluginRegistrar *registrar) {
registrar->registerPlugin<mlir::iree_compiler::IREE::HAL::WebGPUSession>(
"hal_target_webgpu");
return true;
}
19 changes: 19 additions & 0 deletions compiler/plugins/target/WebGPU/spirv-headers/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

include(FetchContent)

FetchContent_Declare(
spirv-headers
GIT_REPOSITORY https://github.com/KhronosGroup/SPIRV-Headers.git
GIT_TAG ${IREE_SPIRV_HEADERS_TAG}
)

set(SPIRV_HEADERS_SKIP_EXAMPLES ON CACHE BOOL "" FORCE)
set(SPIRV_HEADERS_SKIP_INSTALL ON CACHE BOOL "" FORCE)

FetchContent_MakeAvailable(spirv-headers)
FetchContent_GetProperties(spirv-headers SOURCE_DIR SPIRV_HEADERS_SOURCE)
File renamed without changes.

This file was deleted.

This file was deleted.

4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ if(IREE_TARGET_BACKEND_VULKAN_SPIRV)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VulkanSPIRV)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VULKANSPIRV_TARGET")
endif()
if(IREE_TARGET_BACKEND_WEBGPU)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::WebGPU)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_WEBGPU_TARGET")
endif()
if(IREE_TARGET_BACKEND_ROCM)
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::ROCM)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_ROCM_TARGET")
Expand Down
1 change: 0 additions & 1 deletion third_party/spirv_headers
Submodule spirv_headers deleted from b73093

0 comments on commit ef8982c

Please sign in to comment.