Skip to content

Commit

Permalink
custom module and dynamic module
Browse files Browse the repository at this point in the history
  • Loading branch information
montaglue committed May 22, 2024
1 parent 9fe159d commit 46f3678
Show file tree
Hide file tree
Showing 14 changed files with 1,614 additions and 0 deletions.
41 changes: 41 additions & 0 deletions samples/custom_module/systolic-dynamic/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Skip this sample entirely when the local-sync HAL driver is not part of the
# build; the sample's lit test is labeled "driver=local-sync".
if(NOT IREE_HAL_DRIVER_LOCAL_SYNC)
return()
endif()

# Tracing does not block the build but is flagged at configure time: module.cc
# also prints a runtime notice when IREE_TRACING_FEATURES is set.
if(IREE_ENABLE_RUNTIME_TRACING)
message(WARNING "IREE_ENABLE_RUNTIME_TRACING enabled but it currently has issues with dynamic libraries")
endif()

# Build the custom module as a shared library from module.cc and link it
# against the IREE runtime libraries whose headers it includes (base, HAL,
# HAL module types, VM, and the VM dynamic module API).
set(_NAME "iree_samples_custom_module_systolic_dynamic_module")
add_library(${_NAME} SHARED module.cc)
target_link_libraries(${_NAME}
iree_base_base
iree_hal_hal
iree_modules_hal_types
iree_vm_vm
iree_vm_dynamic_api
)

# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
# With an empty PREFIX and OUTPUT_NAME "module" the library file is named
# "module" plus the platform's shared library suffix.
set_target_properties(${_NAME}
PROPERTIES
WINDOWS_EXPORT_ALL_SYMBOLS ON
PREFIX ""
OUTPUT_NAME "module"
)

# TODO(benvanik): make iree_status_annotate_f always available as a function
# instead of defining it empty? otherwise optimized builds of the runtime won't
# export it but external libraries may pull it in.
target_compile_options(${_NAME} PRIVATE ${IREE_DEFAULT_COPTS})

# Make the aggregate sample target build this library too.
add_dependencies(iree-sample-deps ${_NAME})

# The lit test that loads the built library lives in test/.
add_subdirectory(test)
303 changes: 303 additions & 0 deletions samples/custom_module/systolic-dynamic/module.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdio>

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/types.h"
#include "iree/vm/api.h"
#include "iree/vm/dynamic/api.h"
#include "iree/vm/native_module_cc.h"

// NOTE: this module is written in C++ using the native module wrapper and uses
// template magic to handle marshaling arguments. For a lot of uses this is a
// much friendlier way of exposing modules to the IREE VM and if performance and
// code size are not a concern is a fine route to take. Here we do it for
// brevity but all of the internal IREE modules are implemented in C.

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

// This module does not define any custom VM types. Its only exported function
// operates on HAL buffer views (iree_hal_buffer_view_t), which is why the HAL
// types are resolved against the VM instance when the module is created.
// See ref.h for more information on how reference-counted VM types work.

//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//

namespace {

using namespace iree;

// Approximation of some external library call that populates a buffer: a
// naive host-side matrix multiply C = A * B over dense row-major f32 data.
//
// NOTE: despite the "I32" suffix (kept so existing callers continue to work)
// the buffer contents are interpreted as 32-bit floats.
//
// Arguments:
//   source_buffer_a: holds A with |n| rows and |k| columns; read-only.
//   source_buffer_b: holds B with |k| rows and |m| columns; read-only.
//   target_buffer:   receives C with |n| rows and |m| columns. Any previous
//                    contents are discarded and every element is overwritten.
//
// Returns a failure if any buffer cannot be mapped or if a mapped range is too
// small for the requested dimensions.
//
// It's assumed that when this is called the source buffers are available to
// read and the |target_buffer| is available to write (no other readers exist).
// This sample assumes that the buffers are mappable so we can do the work here
// but they will not always be. APIs like iree_hal_allocator_import_buffer and
// iree_hal_allocator_export_buffer can be used in some cases to avoid
// potentially expensive operations but real applications that care about
// performance would want to issue async transfer command buffers.
//
// Only use this as a reference for when synchronous behavior is absolutely
// required (old-style blocking file IO/etc).
static Status SyncSimulatedHostOpI32(iree_hal_buffer_t* source_buffer_a,
                                     iree_hal_buffer_t* source_buffer_b,
                                     iree_hal_buffer_t* target_buffer,
                                     iree_device_size_t n, iree_device_size_t m,
                                     iree_device_size_t k) {
  Status status = OkStatus();

  // Map the source and target buffers into host memory. Note that not all
  // devices allow this but in this sample we assume they do.
  iree_hal_buffer_mapping_t source_mapping_a = {{0}};
  if (status.ok()) {
    status = iree_hal_buffer_map_range(
        source_buffer_a, IREE_HAL_MAPPING_MODE_SCOPED,
        IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &source_mapping_a);
  }
  iree_hal_buffer_mapping_t source_mapping_b = {{0}};
  if (status.ok()) {
    // This must map B; mapping A twice silently computes the wrong product.
    status = iree_hal_buffer_map_range(
        source_buffer_b, IREE_HAL_MAPPING_MODE_SCOPED,
        IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &source_mapping_b);
  }
  iree_hal_buffer_mapping_t target_mapping = {{0}};
  if (status.ok()) {
    status =
        iree_hal_buffer_map_range(target_buffer, IREE_HAL_MAPPING_MODE_SCOPED,
                                  IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
                                  IREE_WHOLE_BUFFER, &target_mapping);
  }

  // The loops below index purely from n/m/k, so make sure each mapped range
  // really holds that many floats before touching it.
  if (status.ok()) {
    const bool a_fits =
        source_mapping_a.contents.data_length >= n * k * sizeof(float);
    const bool b_fits =
        source_mapping_b.contents.data_length >= k * m * sizeof(float);
    const bool c_fits =
        target_mapping.contents.data_length >= n * m * sizeof(float);
    if (!a_fits || !b_fits || !c_fits) {
      status = iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "matmul operand buffer is smaller than its declared dimensions");
    }
  }

  // Sad slow host work. Whenever possible it's worth it to move these into the
  // program so the IREE compiler can fuse and accelerate these operations.
  if (status.ok()) {
    const float* source_ptr_a =
        reinterpret_cast<const float*>(source_mapping_a.contents.data);
    const float* source_ptr_b =
        reinterpret_cast<const float*>(source_mapping_b.contents.data);
    float* target_ptr = reinterpret_cast<float*>(target_mapping.contents.data);

    for (iree_device_size_t i = 0; i < n; ++i) {
      for (iree_device_size_t j = 0; j < m; ++j) {
        // Accumulate locally and store once: the target was mapped with
        // DISCARD_WRITE so its prior contents are undefined and must not be
        // read back.
        float sum = 0.0f;
        for (iree_device_size_t l = 0; l < k; ++l) {
          sum += source_ptr_a[i * k + l] * source_ptr_b[l * m + j];
        }
        target_ptr[i * m + j] = sum;
      }
    }
  }

  // We must unmap the buffers before they will be usable.
  // Note that it's possible for these to fail in cases where the buffer
  // required emulated mapping but on basic host-local devices like CPU assumed
  // in this sample that should never happen.
  // NOTE(review): on early failure some mappings are still zero-initialized;
  // this relies on unmap tolerating that - confirm against the HAL API.
  iree_status_ignore(iree_hal_buffer_unmap_range(&source_mapping_a));
  iree_status_ignore(iree_hal_buffer_unmap_range(&source_mapping_b));
  iree_status_ignore(iree_hal_buffer_unmap_range(&target_mapping));

  return status;
}

// Per-context module state.
class CustomModuleState final {
public:
explicit CustomModuleState(vm::ref<iree_hal_device_t> device,
iree_allocator_t host_allocator)
: device_(std::move(device)), host_allocator_(host_allocator) {}
~CustomModuleState() = default;

StatusOr<vm::ref<iree_hal_buffer_view_t>> MatMul(
const vm::ref<iree_hal_buffer_view_t> a,
const vm::ref<iree_hal_buffer_view_t> b,
vm::ref<iree_hal_buffer_view_t> c) {
// We can directly access the buffer here but only for reading.
// In the future it'll be possible to pass in-place buffers.
auto* arg_buffer_a = iree_hal_buffer_view_buffer(a.get());
auto* arg_buffer_b = iree_hal_buffer_view_buffer(b.get());
auto* arg_buffer_c = iree_hal_buffer_view_buffer(c.get());

iree_host_size_t rank_cap_a = 2;
iree_hal_dim_t out_shape_a[2] = {0, 0};
iree_host_size_t* out_shape_rank_a;

IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
a.get(), rank_cap_a, out_shape_a, out_shape_rank_a));

iree_host_size_t rank_cap_b = 2;
iree_hal_dim_t out_shape_b[2] = {0, 0};
iree_host_size_t* out_shape_rank_b;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
b.get(), rank_cap_b, out_shape_b, out_shape_rank_b));

iree_host_size_t n = out_shape_a[0];
iree_host_size_t m = out_shape_b[1];
iree_host_size_t k = out_shape_a[1];

// Synchronously allocate the memory from the device allocator. We could
// use queue-ordered allocations but that's unsafe to use from arbitrary
// threads and we want to show how to safely do that using the thread-safe
// device allocator.
//
// NOTE: if cloning host memory the initial_data can be passed in to
// efficiently upload the memory to the device. If wrapping host memory then
// iree_hal_allocator_import_buffer can be used to import the memory without
// a copy (if supported). This simple example is showing an in-place style
// external call.
/*
fprintf(stdout, "getting the device pointer\n");
fflush(stdout);
iree_hal_device_t* device = device_.get();
int64_t device_int = (int64_t)device;
fprintf(stdout, "debugging device pointer %ld\n", device_int);
fflush(stdout);
fprintf(stdout, "getting device allocator\n");
fflush(stdout);
iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(device);
fprintf(stdout, "allocating buffer");
fflush(stdout);
*/
// iree_hal_buffer_params_t buffer_params = {
// /*.usage=*/IREE_HAL_BUFFER_USAGE_DEFAULT |
// IREE_HAL_BUFFER_USAGE_MAPPING,
// /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL,
// /*.type=*/IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE |
// IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
// /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY,
// /*.min_alignment=*/64,
// };
/*
vm::ref<iree_hal_buffer_t> result_buffer;
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
device_allocator, buffer_params, sizeof(float) * n * m,
&result_buffer));
fprintf(stdout, "HERE3");
fflush(stdout);
*/
// Hacky example accessing the source contents and producing the result
// contents. This emulates what an external library the user is calling that
// expects host void* buffers does.
IREE_RETURN_IF_ERROR(SyncSimulatedHostOpI32(arg_buffer_a, arg_buffer_b,
arg_buffer_c, n, m, k));

// Wrap the buffer in a buffer view that provides the metadata for
// runtime verification.
/*
vm::ref<iree_hal_buffer_view_t> result_view;
const iree_hal_dim_t shape[2] = {n, m};
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_create(result_buffer.get(), 2, shape,
iree_hal_buffer_view_element_type(a.get()),
iree_hal_buffer_view_encoding_type(a.get()),
host_allocator_, &result_view));
*/

// Note that the caller may immediately use the buffer contents without
// waiting as by being synchronous we've indicated that we waited ourselves
// (the thread join above).
return c;
}

private:
// HAL device used for scheduling work and allocations.
vm::ref<iree_hal_device_t> device_;

// Allocator that the caller requested we use for any allocations we need to
// perform during operation.
iree_allocator_t host_allocator_;
};

// Function table mapping the function names that compiled programs import
// from this module to their native implementations on CustomModuleState.
// Names are relative to the module name passed when the module is created, so
// with the module name "systolic" the entry below is imported by programs as
// `systolic.matmul`.
static const vm::NativeFunction<CustomModuleState> kCustomModuleFunctions[] = {
vm::MakeNativeFunction("matmul", &CustomModuleState::MatMul),
};

// Module object that is allocated once and then shared by every context the
// module gets added to. Per-context data lives in CustomModuleState.
class CustomModule final : public vm::NativeModule<CustomModuleState> {
 public:
  using vm::NativeModule<CustomModuleState>::NativeModule;

  // Builds the per-context state when the module is added to a new context.
  // May be called from any thread.
  StatusOr<std::unique_ptr<CustomModuleState>> CreateState(
      iree_allocator_t host_allocator) override {
    // Each state holds its own retained reference to the module's device.
    vm::ref<iree_hal_device_t> retained_device = vm::retain_ref(device_);
    std::unique_ptr<CustomModuleState> new_state =
        std::make_unique<CustomModuleState>(std::move(retained_device),
                                            host_allocator);
    return new_state;
  }

  // Replaces the device handed to states created after this call.
  void SetDevice(vm::ref<iree_hal_device_t> device) {
    device_ = std::move(device);
  }

 private:
  // Device passed on to each CustomModuleState; null until SetDevice is
  // called.
  vm::ref<iree_hal_device_t> device_;
};

} // namespace

// Entry point resolved by the runtime when this shared library is loaded as a
// dynamic VM module. Creates the native "systolic" module, which can then be
// reused in multiple contexts.
//
// The module itself may hold state that can be shared by all instantiated
// copies but it will require the module to provide synchronization; usually
// it's safer to just treat the module as immutable and keep state within the
// instantiated module states instead.
//
// Note that while we are using C++ bindings internally we still expose the
// module as a C instance. This hides the details of our implementation and
// is required for working across the dynamic library boundary.
//
// |param_count| and |params| are accepted for interface compatibility and are
// not consulted. On success |out_module| receives the new module and ownership
// passes to the caller.
extern "C" IREE_VM_DYNAMIC_MODULE_EXPORT iree_status_t create_custom_module(
    iree_vm_dynamic_module_version_t max_version, iree_vm_instance_t* instance,
    iree_host_size_t param_count, const iree_string_pair_t* params,
    iree_allocator_t host_allocator, iree_vm_module_t** out_module) {
  // The version changes whenever the VM module interface changes in a way
  // that makes existing libraries incompatible, so require an exact match.
  const bool version_supported =
      (max_version == IREE_VM_DYNAMIC_MODULE_VERSION_LATEST);
  if (!version_supported) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "unsupported runtime version %u, module compiled with version %u",
        max_version, IREE_VM_DYNAMIC_MODULE_VERSION_LATEST);
  }

#if IREE_TRACING_FEATURES
  // Today Tracy cannot be used with custom dynamic modules as it'll try to
  // create a new tracing context distinct from the hosting application. Custom
  // module libraries should be built with tracing disabled.
  fprintf(stderr,
          "Tracy is not currently supported in custom dynamic modules\n");
#endif  // IREE_TRACING_FEATURES

  // Ensure HAL types are available. We need to do this as we're being
  // dynamically loaded and can't automatically access the hosting process
  // variables.
  IREE_RETURN_IF_ERROR(iree_hal_module_resolve_all_types(instance));

  // Create the custom module and return it to the runtime.
  // NOTE: this isn't using the allocator here and that's bad as it leaves
  // untracked allocations and pulls in the system allocator that may differ
  // from the one requested by the user.
  // TODO(benvanik): std::allocator wrapper around iree_allocator_t so this can
  // use that instead.
  const iree::span<const vm::NativeFunction<CustomModuleState>> function_table(
      kCustomModuleFunctions);
  std::unique_ptr<CustomModule> custom_module = std::make_unique<CustomModule>(
      "systolic", /*version=*/0, instance, host_allocator, function_table);

  // Hand the raw module interface to the caller; the unique_ptr gives up
  // ownership here.
  *out_module = custom_module.release()->interface();
  return iree_ok_status();
}
20 changes: 20 additions & 0 deletions samples/custom_module/systolic-dynamic/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Lit test that runs example.mlir. TOOLS lists what the test depends on being
# built first, including the custom module shared library target defined in
# the parent directory. The labels mark the test as host-only and as requiring
# the local-sync driver.
# NOTE(review): example.mlir as committed has no RUN/CHECK lines - confirm the
# suite actually exercises the module.
iree_lit_test_suite(
NAME
lit
SRCS
"example.mlir"
TOOLS
FileCheck
iree-compile
iree-run-module
iree_samples_custom_module_systolic_dynamic_module
LABELS
"driver=local-sync"
"hostonly"
)
9 changes: 9 additions & 0 deletions samples/custom_module/systolic-dynamic/test/example.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Example program calling the native `systolic.matmul` function provided by
// the custom dynamic module.
//
// NOTE(review): this file has no `// RUN:` or `// CHECK:` lines, so lit has
// nothing to execute or verify - add them once the intended invocation
// (flags, module path, inputs) is settled.
module @example {
  // Imported from the custom module: (a, b, c) -> c where c = a * b.
  // The third operand is the output storage and is returned as the result.
  func.func private @systolic.matmul(tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

  func.func @forward(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>) -> tensor<?x?xf32> {
    // Size the output from the operands ([rows(a), cols(b)]) instead of
    // hard-coding 2x2 so that any compatible input shapes work.
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %n = tensor.dim %a, %c0 : tensor<?x?xf32>
    %m = tensor.dim %b, %c1 : tensor<?x?xf32>
    %empty = tensor.empty(%n, %m) : tensor<?x?xf32>

    // tensor.empty has undefined contents; zero it so the result is well
    // defined even if the native implementation accumulates into it.
    %zero = arith.constant 0.0 : f32
    %init = linalg.fill ins(%zero : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>

    %result = call @systolic.matmul(%a, %b, %init) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %result : tensor<?x?xf32>
  }
}
Loading

0 comments on commit 46f3678

Please sign in to comment.