forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,614 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# This sample is only configured when the local-sync HAL driver is enabled;
# otherwise processing of this file stops here.
if(NOT IREE_HAL_DRIVER_LOCAL_SYNC) | ||
return() | ||
endif() | ||
|
||
# Runtime tracing does not stop the build; it only emits a configure-time
# warning.
if(IREE_ENABLE_RUNTIME_TRACING) | ||
message(WARNING "IREE_ENABLE_RUNTIME_TRACING enabled but it currently has issues with dynamic libraries") | ||
endif() | ||
|
||
# Build module.cc as a shared library and link it against the IREE runtime
# targets listed below.
set(_NAME "iree_samples_custom_module_systolic_dynamic_module") | ||
add_library(${_NAME} SHARED module.cc) | ||
target_link_libraries(${_NAME} | ||
iree_base_base | ||
iree_hal_hal | ||
iree_modules_hal_types | ||
iree_vm_vm | ||
iree_vm_dynamic_api | ||
) | ||
|
||
# NOTE: this is only required because we want this sample to run on all | ||
# platforms without needing to change the library name (libfoo.so/foo.dll). | ||
# An empty PREFIX plus OUTPUT_NAME "module" gives the library the base name
# "module" with no "lib" prefix.
set_target_properties(${_NAME} | ||
PROPERTIES | ||
WINDOWS_EXPORT_ALL_SYMBOLS ON | ||
PREFIX "" | ||
OUTPUT_NAME "module" | ||
) | ||
|
||
# TODO(benvanik): make iree_status_annotate_f always available as a function | ||
# instead of defining it empty? otherwise optimized builds of the runtime won't | ||
# export it but external libraries may pull it in. | ||
# Compile the sample with the options held in IREE_DEFAULT_COPTS.
target_compile_options(${_NAME} PRIVATE ${IREE_DEFAULT_COPTS}) | ||
|
||
# Building the iree-sample-deps target also builds this library.
add_dependencies(iree-sample-deps ${_NAME}) | ||
|
||
# The tests for this sample live in the test/ subdirectory.
add_subdirectory(test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include <cstdio> | ||
|
||
#include "iree/base/api.h" | ||
#include "iree/hal/api.h" | ||
#include "iree/modules/hal/types.h" | ||
#include "iree/vm/api.h" | ||
#include "iree/vm/dynamic/api.h" | ||
#include "iree/vm/native_module_cc.h" | ||
|
||
// NOTE: this module is written in C++ using the native module wrapper and uses | ||
// template magic to handle marshaling arguments. For a lot of uses this is a | ||
// much friendlier way of exposing modules to the IREE VM and if performance and | ||
// code size are not a concern is a fine route to take. Here we do it for | ||
// brevity but all of the internal IREE modules are implemented in C. | ||
|
||
//===----------------------------------------------------------------------===// | ||
// !custom.string type | ||
//===----------------------------------------------------------------------===// | ||
|
||
// The "string" type we use to store and retain string data. | ||
// This could be arbitrarily complex or simply wrap another user-defined type. | ||
// The descriptor that is registered at startup defines how to manage the | ||
// lifetime of the type (such as which destruction function is called, if any). | ||
// See ref.h for more information and additional utilities. | ||
|
||
//===----------------------------------------------------------------------===// | ||
// VM module interface implementation | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace { | ||
|
||
using namespace iree; | ||
|
||
// Approximation of some external library call that populates a buffer.
//
// Computes the dense row-major matrix product C = A * B on the host, where
//   A = |source_buffer_a| holding |n| x |k| elements,
//   B = |source_buffer_b| holding |k| x |m| elements,
//   C = |target_buffer|   receiving |n| x |m| elements.
// NOTE: despite the "I32" suffix in the name, elements are read and written as
// `float`.
//
// It's assumed that when this is called the source buffers are available to
// read and the |target_buffer| is available to write (no other readers exist).
// This sample assumes that the buffers are mappable so we can do the work here
// but they will not always be. APIs like iree_hal_allocator_import_buffer and
// iree_hal_allocator_export_buffer can be used in some cases to avoid
// potentially expensive operations but real applications that care about
// performance would want to issue async transfer command buffers.
//
// Only use this as a reference for when synchronous behavior is absolutely
// required (old-style blocking file IO/etc).
//
// Returns the first mapping failure, or OUT_OF_RANGE if any mapped buffer is
// too small for the requested dimensions. Previous contents of the target are
// discarded and every output element is fully overwritten.
static Status SyncSimulatedHostOpI32(iree_hal_buffer_t* source_buffer_a,
                                     iree_hal_buffer_t* source_buffer_b,
                                     iree_hal_buffer_t* target_buffer,
                                     iree_device_size_t n, iree_device_size_t m,
                                     iree_device_size_t k) {
  Status status = OkStatus();

  // Map the source and target buffers into host memory. Note that not all
  // devices allow this but in this sample we assume they do.
  iree_hal_buffer_mapping_t source_mapping_a = {{0}};
  if (status.ok()) {
    status = iree_hal_buffer_map_range(
        source_buffer_a, IREE_HAL_MAPPING_MODE_SCOPED,
        IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &source_mapping_a);
  }
  iree_hal_buffer_mapping_t source_mapping_b = {{0}};
  if (status.ok()) {
    // The right-hand operand must come from |source_buffer_b|.
    status = iree_hal_buffer_map_range(
        source_buffer_b, IREE_HAL_MAPPING_MODE_SCOPED,
        IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &source_mapping_b);
  }
  iree_hal_buffer_mapping_t target_mapping = {{0}};
  if (status.ok()) {
    status =
        iree_hal_buffer_map_range(target_buffer, IREE_HAL_MAPPING_MODE_SCOPED,
                                  IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
                                  IREE_WHOLE_BUFFER, &target_mapping);
  }

  // Reject dimensions that would walk past the end of any mapped range. The
  // check is phrased with divisions so that rows * cols cannot overflow.
  if (status.ok()) {
    const auto holds_matrix = [](const iree_hal_buffer_mapping_t& mapping,
                                 iree_device_size_t rows,
                                 iree_device_size_t cols) {
      if (rows == 0 || cols == 0) return true;
      const iree_device_size_t capacity =
          mapping.contents.data_length / sizeof(float);
      return capacity / rows >= cols;
    };
    if (!holds_matrix(source_mapping_a, n, k) ||
        !holds_matrix(source_mapping_b, k, m) ||
        !holds_matrix(target_mapping, n, m)) {
      status = iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "mapped buffers are too small for the requested matmul dimensions");
    }
  }

  // Sad slow host work. Whenever possible it's worth it to move these into the
  // program so the IREE compiler can fuse and accelerate these operations.
  if (status.ok()) {
    const float* source_ptr_a =
        reinterpret_cast<const float*>(source_mapping_a.contents.data);
    const float* source_ptr_b =
        reinterpret_cast<const float*>(source_mapping_b.contents.data);
    float* target_ptr = reinterpret_cast<float*>(target_mapping.contents.data);

    for (iree_device_size_t i = 0; i < n; ++i) {
      for (iree_device_size_t j = 0; j < m; ++j) {
        // The target is mapped with DISCARD_WRITE so its prior contents are
        // undefined: accumulate locally and store once instead of using +=.
        float accumulator = 0.0f;
        for (iree_device_size_t l = 0; l < k; ++l) {
          accumulator += source_ptr_a[i * k + l] * source_ptr_b[l * m + j];
        }
        target_ptr[i * m + j] = accumulator;
      }
    }
  }

  // We must unmap the buffers before they will be usable.
  // Note that it's possible for these to fail in cases where the buffer
  // required emulated mapping but on basic host-local devices like CPU assumed
  // in this sample that should never happen.
  iree_status_ignore(iree_hal_buffer_unmap_range(&source_mapping_a));
  iree_status_ignore(iree_hal_buffer_unmap_range(&source_mapping_b));
  iree_status_ignore(iree_hal_buffer_unmap_range(&target_mapping));

  return status;
}
|
||
// Per-context module state. | ||
class CustomModuleState final { | ||
public: | ||
explicit CustomModuleState(vm::ref<iree_hal_device_t> device, | ||
iree_allocator_t host_allocator) | ||
: device_(std::move(device)), host_allocator_(host_allocator) {} | ||
~CustomModuleState() = default; | ||
|
||
StatusOr<vm::ref<iree_hal_buffer_view_t>> MatMul( | ||
const vm::ref<iree_hal_buffer_view_t> a, | ||
const vm::ref<iree_hal_buffer_view_t> b, | ||
vm::ref<iree_hal_buffer_view_t> c) { | ||
// We can directly access the buffer here but only for reading. | ||
// In the future it'll be possible to pass in-place buffers. | ||
auto* arg_buffer_a = iree_hal_buffer_view_buffer(a.get()); | ||
auto* arg_buffer_b = iree_hal_buffer_view_buffer(b.get()); | ||
auto* arg_buffer_c = iree_hal_buffer_view_buffer(c.get()); | ||
|
||
iree_host_size_t rank_cap_a = 2; | ||
iree_hal_dim_t out_shape_a[2] = {0, 0}; | ||
iree_host_size_t* out_shape_rank_a; | ||
|
||
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape( | ||
a.get(), rank_cap_a, out_shape_a, out_shape_rank_a)); | ||
|
||
iree_host_size_t rank_cap_b = 2; | ||
iree_hal_dim_t out_shape_b[2] = {0, 0}; | ||
iree_host_size_t* out_shape_rank_b; | ||
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape( | ||
b.get(), rank_cap_b, out_shape_b, out_shape_rank_b)); | ||
|
||
iree_host_size_t n = out_shape_a[0]; | ||
iree_host_size_t m = out_shape_b[1]; | ||
iree_host_size_t k = out_shape_a[1]; | ||
|
||
// Synchronously allocate the memory from the device allocator. We could | ||
// use queue-ordered allocations but that's unsafe to use from arbitrary | ||
// threads and we want to show how to safely do that using the thread-safe | ||
// device allocator. | ||
// | ||
// NOTE: if cloning host memory the initial_data can be passed in to | ||
// efficiently upload the memory to the device. If wrapping host memory then | ||
// iree_hal_allocator_import_buffer can be used to import the memory without | ||
// a copy (if supported). This simple example is showing an in-place style | ||
// external call. | ||
/* | ||
fprintf(stdout, "getting the device pointer\n"); | ||
fflush(stdout); | ||
iree_hal_device_t* device = device_.get(); | ||
int64_t device_int = (int64_t)device; | ||
fprintf(stdout, "debugging device pointer %ld\n", device_int); | ||
fflush(stdout); | ||
fprintf(stdout, "getting device allocator\n"); | ||
fflush(stdout); | ||
iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device); | ||
fprintf(stdout, "allocating buffer"); | ||
fflush(stdout); | ||
*/ | ||
// iree_hal_buffer_params_t buffer_params = { | ||
// /*.usage=*/IREE_HAL_BUFFER_USAGE_DEFAULT | | ||
// IREE_HAL_BUFFER_USAGE_MAPPING, | ||
// /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL, | ||
// /*.type=*/IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE | | ||
// IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, | ||
// /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY, | ||
// /*.min_alignment=*/64, | ||
// }; | ||
/* | ||
vm::ref<iree_hal_buffer_t> result_buffer; | ||
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( | ||
device_allocator, buffer_params, sizeof(float) * n * m, | ||
&result_buffer)); | ||
fprintf(stdout, "HERE3"); | ||
fflush(stdout); | ||
*/ | ||
// Hacky example accessing the source contents and producing the result | ||
// contents. This emulates what an external library the user is calling that | ||
// expects host void* buffers does. | ||
IREE_RETURN_IF_ERROR(SyncSimulatedHostOpI32(arg_buffer_a, arg_buffer_b, | ||
arg_buffer_c, n, m, k)); | ||
|
||
// Wrap the buffer in a buffer view that provides the metadata for | ||
// runtime verification. | ||
/* | ||
vm::ref<iree_hal_buffer_view_t> result_view; | ||
const iree_hal_dim_t shape[2] = {n, m}; | ||
IREE_RETURN_IF_ERROR( | ||
iree_hal_buffer_view_create(result_buffer.get(), 2, shape, | ||
iree_hal_buffer_view_element_type(a.get()), | ||
iree_hal_buffer_view_encoding_type(a.get()), | ||
host_allocator_, &result_view)); | ||
*/ | ||
|
||
// Note that the caller may immediately use the buffer contents without | ||
// waiting as by being synchronous we've indicated that we waited ourselves | ||
// (the thread join above). | ||
return c; | ||
} | ||
|
||
private: | ||
// HAL device used for scheduling work and allocations. | ||
vm::ref<iree_hal_device_t> device_; | ||
|
||
// Allocator that the caller requested we use for any allocations we need to | ||
// perform during operation. | ||
iree_allocator_t host_allocator_; | ||
}; | ||
|
||
// Table of functions provided by this module: each entry pairs a function name
// with the CustomModuleState member that implements it.
static const vm::NativeFunction<CustomModuleState> kCustomModuleFunctions[] = {
    vm::MakeNativeFunction("matmul", &CustomModuleState::MatMul),
};
|
||
// Module object that is allocated once and may be registered with several
// contexts; each context gets its own CustomModuleState via CreateState().
class CustomModule final : public vm::NativeModule<CustomModuleState> {
 public:
  using vm::NativeModule<CustomModuleState>::NativeModule;

  // Replaces the HAL device that is handed to states created by CreateState().
  void SetDevice(vm::ref<iree_hal_device_t> device) {
    device_ = std::move(device);
  }

  // Creates per-context state when the module is added to a new context.
  // The new state receives a retained reference to the current device and the
  // provided |host_allocator|. May be called from any thread.
  StatusOr<std::unique_ptr<CustomModuleState>> CreateState(
      iree_allocator_t host_allocator) override {
    return std::make_unique<CustomModuleState>(vm::retain_ref(device_),
                                               host_allocator);
  }

 private:
  // Device passed to each new state; remains empty until SetDevice() is called.
  vm::ref<iree_hal_device_t> device_;
};
|
||
} // namespace | ||
|
||
// Entry point exported from the dynamic library: creates the native "systolic"
// module, which can be reused in multiple contexts, and returns it through
// |out_module|.
//
// The module itself may hold state that can be shared by all instantiated
// copies but it would then need to provide synchronization; usually it's safer
// to treat the module as immutable and keep state within the per-context
// module states instead.
//
// Although the implementation uses the C++ bindings, the module is exposed as
// a C instance. This hides the implementation details and is required for
// working across the dynamic library boundary.
//
// Returns UNIMPLEMENTED when |max_version| differs from the version this
// library was compiled against, and propagates any failure to resolve the HAL
// types. |param_count| and |params| are not used.
extern "C" IREE_VM_DYNAMIC_MODULE_EXPORT iree_status_t create_custom_module(
    iree_vm_dynamic_module_version_t max_version, iree_vm_instance_t* instance,
    iree_host_size_t param_count, const iree_string_pair_t* params,
    iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
  // The version changes whenever the VM module interface changes in a way that
  // makes existing libraries incompatible, so only an exact match is accepted.
  const bool version_supported =
      max_version == IREE_VM_DYNAMIC_MODULE_VERSION_LATEST;
  if (!version_supported) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "unsupported runtime version %u, module compiled with version %u",
        max_version, IREE_VM_DYNAMIC_MODULE_VERSION_LATEST);
  }

#if IREE_TRACING_FEATURES
  // Today Tracy cannot be used with custom dynamic modules as it'll try to
  // create a new tracing context distinct from the hosting application. Custom
  // module libraries should be built with tracing disabled.
  fprintf(stderr,
          "Tracy is not currently supported in custom dynamic modules\n");
#endif  // IREE_TRACING_FEATURES

  // Being dynamically loaded, this library cannot automatically see the
  // hosting process's variables, so the HAL types are resolved explicitly.
  IREE_RETURN_IF_ERROR(iree_hal_module_resolve_all_types(instance));

  // Create the custom module and hand ownership to the runtime.
  // NOTE: this allocates with the system allocator rather than
  // |host_allocator|, which leaves untracked allocations and may differ from
  // the allocator requested by the user.
  // TODO(benvanik): std::allocator wrapper around iree_allocator_t so this can
  // use that instead.
  const iree::span<const vm::NativeFunction<CustomModuleState>>
      exported_functions(kCustomModuleFunctions);
  auto owned_module = std::make_unique<CustomModule>(
      "systolic", /*version=*/0, instance, host_allocator, exported_functions);
  *out_module = owned_module.release()->interface();
  return iree_ok_status();
}
20 changes: 20 additions & 0 deletions
20
samples/custom_module/systolic-dynamic/test/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# Declares a lit test suite named "lit" whose only source is example.mlir.
# TOOLS lists what is handed to iree_lit_test_suite(), including this sample's
# custom module library target; LABELS tags the suite as local-sync and
# host-only. See the iree_lit_test_suite() definition for how each argument is
# consumed.
iree_lit_test_suite( | ||
NAME | ||
lit | ||
SRCS | ||
"example.mlir" | ||
TOOLS | ||
FileCheck | ||
iree-compile | ||
iree-run-module | ||
iree_samples_custom_module_systolic_dynamic_module | ||
LABELS | ||
"driver=local-sync" | ||
"hostonly" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// Example program that forwards a matrix multiplication to the external
// `systolic.matmul` function.
module @example {
  // Externally provided function: takes the two operands plus a destination
  // tensor and returns the result tensor.
  func.func private @systolic.matmul(tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

  // Returns the product of %a and %b as computed by @systolic.matmul.
  func.func @forward(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>) -> tensor<?x?xf32> {
    // Size the destination from the operands (rows of %a by columns of %b)
    // instead of hard-coding 2x2, so inputs of any compatible shape get a
    // destination of the right size.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %n = tensor.dim %a, %c0 : tensor<?x?xf32>
    %m = tensor.dim %b, %c1 : tensor<?x?xf32>
    %result = tensor.empty(%n, %m) : tensor<?x?xf32>
    %4 = call @systolic.matmul(%a, %b, %result) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %4 : tensor<?x?xf32>
  }
}
Oops, something went wrong.