diff --git a/iree/integrations/pjrt/common/BUILD b/iree/integrations/pjrt/common/BUILD
index e36f7bd01304..32fb46d7299b 100644
--- a/iree/integrations/pjrt/common/BUILD
+++ b/iree/integrations/pjrt/common/BUILD
@@ -44,6 +44,7 @@ iree_pjrt_cc_library(
     hdrs = [
         "api_impl.h",
         "dylib_entry_point.cc.inc",
+        "iree_helpers.h",
         "platform.h",
     ],
     deps = [
diff --git a/iree/integrations/pjrt/common/api_impl.cc b/iree/integrations/pjrt/common/api_impl.cc
index 5a3621bbee0b..0a8b47f2e28a 100644
--- a/iree/integrations/pjrt/common/api_impl.cc
+++ b/iree/integrations/pjrt/common/api_impl.cc
@@ -11,6 +11,7 @@
 
 #include "iree/base/tracing.h"
 #include "iree/hal/api.h"
+#include "iree/integrations/pjrt/common/iree_helpers.h"
 #include "iree/integrations/pjrt/common/tensor_utils.h"
 
 using iree::vm::retain_ref;
@@ -461,7 +462,7 @@ iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) {
 
 iree_status_t BufferInstance::AsyncDeallocate() {
   IREE_TRACE_SCOPE();
-  return iree_hal_device_queue_dealloca(
+  return IreeApi::hal_device_queue_dealloca(
       device().device(), IREE_HAL_QUEUE_AFFINITY_ANY,
       /*wait_semaphore_list=*/iree_hal_fence_semaphore_list(done_fence()),
       /*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
@@ -535,7 +536,7 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
     copy_data->event->ExternalSignalReady(iree_ok_status());
     delete copy_data;
   };
-  IREE_RETURN_IF_ERROR(iree_hal_allocator_import_buffer(
+  IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_import_buffer(
       device_.device_allocator(), dst_buffer_params, &dst_external_buffer,
       /*release_callback=*/{release_callback, copy_to_host_data},
       &dst_buffer));
@@ -556,7 +557,7 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
       /*transfer_count=*/1, &transfer_command, &transfer_cb));
 
   dst_buffer.reset();
-  IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
       device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
       /*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
       /*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
@@ -567,12 +568,12 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
 
 iree_status_t BufferInstance::AdvanceReadyFence(iree_hal_semaphore_t* semaphore,
                                                 uint64_t timepoint) {
-  return iree_hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
+  return IreeApi::hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
 }
 
 iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
                                                uint64_t timepoint) {
-  return iree_hal_fence_insert(done_fence_.get(), semaphore, timepoint);
+  return IreeApi::hal_fence_insert(done_fence_.get(), semaphore, timepoint);
 }
 
 //===----------------------------------------------------------------------===//
@@ -634,8 +635,8 @@ void DeviceInstance::BindApi(PJRT_Api* api) {
 }
 
 iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
-  return iree_hal_fence_create(/*capacity=*/2, client_.host_allocator(),
-                               out_fence);
+  return IreeApi::hal_fence_create(/*capacity=*/2, client_.host_allocator(),
+                                   out_fence);
 }
 
 iree_status_t DeviceInstance::OpenDevice() {
@@ -767,7 +768,7 @@ iree_status_t DeviceInstance::HostBufferToDevice(
   params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
   params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT |
                  IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
-  IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
+  IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
      device(), IREE_HAL_QUEUE_AFFINITY_ANY,
      /*wait_semaphore_list=*/
      {1, &transfer_timeline_, &wait_transfer_start},
@@ -803,14 +804,14 @@ iree_status_t DeviceInstance::HostBufferToDevice(
   }
 
   if (has_zero_length) {
-    IREE_RETURN_IF_ERROR(iree_hal_device_queue_barrier(
+    IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_barrier(
         device(), IREE_HAL_QUEUE_AFFINITY_ANY,
         /*wait_semaphore_list=*/
         {1, &transfer_timeline_, &signal_alloca_complete},
         /*signal_semaphore_list=*/
        {1, &transfer_timeline_, &signal_copy_complete}));
   } else {
-    IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
+    IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
         device(), IREE_HAL_QUEUE_AFFINITY_ANY,
         /*wait_semaphore_list=*/
        {1, &transfer_timeline_, &signal_alloca_complete},
@@ -853,7 +854,7 @@ iree_status_t DeviceInstance::AcquireHostStagingBuffer(
   memset(&params, 0, sizeof(params));
   params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
   params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
-  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+  IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_allocate_buffer(
      device_allocator(), params, initial_contents.data_length,
      initial_contents, out_buffer));
   // We did a synchronous snapshot (memcpy).
@@ -1494,13 +1495,13 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
       // semaphores.
       IREE_RETURN_IF_ERROR(
           inv.res_exe->device_instance->CreateFence(&inv.wait_fence));
-      IREE_RETURN_IF_ERROR(iree_hal_fence_insert(
+      IREE_RETURN_IF_ERROR(IreeApi::hal_fence_insert(
          inv.wait_fence.get(), inv.res_exe->device_instance->main_timeline(),
          wait_timepoint));
 
       // Signal fence. This signals the next tick on the main execution
       // timeline.
-      IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
+      IREE_RETURN_IF_ERROR(IreeApi::hal_fence_create_at(
          inv.res_exe->device_instance->main_timeline(), signal_timepoint,
          client_.host_allocator(), &inv.signal_fence));
 
@@ -1520,8 +1521,8 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
           iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref));
 
       // Extend the execute wait to include the input's ready signal.
-      IREE_RETURN_IF_ERROR(
-          iree_hal_fence_extend(inv.wait_fence.get(), buffer->ready_fence()));
+      IREE_RETURN_IF_ERROR(IreeApi::hal_fence_extend(inv.wait_fence.get(),
+                                                     buffer->ready_fence()));
 
       // And extend the buffer's done fence to close over this execution.
       buffer->AdvanceDoneFence(inv.res_exe->device_instance->main_timeline(),
@@ -1541,17 +1542,29 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
     iree_status_t status = iree_ok_status();
     for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
       auto& inv = invs[dev_index];
-      auto new_status = iree_vm_invoke(
-          inv.res_exe->vm_context.get(), inv.res_exe->main_function,
-          IREE_VM_INVOCATION_FLAG_NONE,
-          /*policy=*/nullptr, inv.inputs.get(), inv.outputs.get(), allocator);
+      if (IreeApi::LOGGING_ENABLED) {
+        IreeApi::LogInvoke(
+            "vm_invoke[async]",
+            "context=%p, f=%d, wait_fence=%p {%s}, signal_fence=%p {%s}",
+            inv.res_exe->vm_context.get(),
+            (int)inv.res_exe->main_function.ordinal, inv.wait_fence.get(),
+            IreeApi::FenceToString(inv.wait_fence.get()).c_str(),
+            inv.signal_fence.get(),
+            IreeApi::FenceToString(inv.signal_fence.get()).c_str());
+      }
+      auto new_status = IreeApi::HandleStatus(
+          "vm_invoke[async]",
+          iree_vm_invoke(inv.res_exe->vm_context.get(),
+                         inv.res_exe->main_function, IREE_VM_INVOCATION_FLAG_NONE,
+                         /*policy=*/nullptr, inv.inputs.get(), inv.outputs.get(),
+                         allocator));
 
       // Any invocation that fails needs a barrier so that signal fence is
      // incremented otherwise future waits will fail. We do this instead of
      // incrementing as only a subset of devices may fail.
      if (!iree_status_is_ok(new_status)) {
        status = new_status;
        // We can ignore the error as we are already erroring out earlier.
-        IREE_IGNORE_ERROR(iree_hal_device_queue_barrier(
+        IREE_IGNORE_ERROR(IreeApi::hal_device_queue_barrier(
            inv.res_exe->device_instance->device(), IREE_HAL_QUEUE_AFFINITY_ANY,
            iree_hal_fence_semaphore_list(inv.wait_fence.get()),
            iree_hal_fence_semaphore_list(inv.signal_fence.get())));
diff --git a/iree/integrations/pjrt/common/iree_helpers.h b/iree/integrations/pjrt/common/iree_helpers.h
new file mode 100644
index 000000000000..b0daea05824c
--- /dev/null
+++ b/iree/integrations/pjrt/common/iree_helpers.h
@@ -0,0 +1,199 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/api.h"
+
+namespace iree::pjrt {
+
+// Anonymous namespace containing helpers and wrappers for IREE API
+// functions which can perform verbose logging when enabled. These all
+// match an IREE api but will have the |iree_| prefix elided, so they are
+// used as IreeApi::hal_allocator_allocate_buffer(...), which should be a
+// drop-in for iree_hal_allocator_allocate_buffer(...).
+namespace IreeApi {
+namespace {
+
+// Controls whether logging is printed to stderr. We may want to make this
+// more configurable in the future.
+const bool LOGGING_ENABLED = false;
+
+IREE_PRINTF_ATTRIBUTE(2, 3)
+void LogInvoke(const char* func, const char* fmt, ...) {
+  if (LOGGING_ENABLED) {
+    fprintf(stderr, ":: IREE INVOKE (%s): ", func);
+    va_list args;
+    va_start(args, fmt);
+    vfprintf(stderr, fmt, args);
+    va_end(args);
+    fflush(stderr);
+  }
+}
+iree_status_t HandleStatus(const char* func, iree_status_t status) {
+  if (LOGGING_ENABLED) {
+    if (!iree_status_is_ok(status)) {
+      fprintf(stderr, " (");
+      iree_status_fprint(stderr, status);
+      fprintf(stderr, ")\n");
+    } else {
+      fprintf(stderr, " (OK)\n");
+    }
+  }
+  return status;
+}
+std::string SemaphoreListToString(const iree_hal_semaphore_list_t sl) {
+  std::string result;
+  char fmtBuffer[64];
+  for (iree_host_size_t i = 0; i < sl.count; ++i) {
+    snprintf(fmtBuffer, sizeof(fmtBuffer), "%p:%" PRIu64, sl.semaphores[i],
+             sl.payload_values[i]);
+    if (i > 0) {
+      result.append(", ");
+    }
+    result.append(fmtBuffer);
+  }
+  return result;
+}
+std::string FenceToString(iree_hal_fence_t* fence) {
+  return SemaphoreListToString(iree_hal_fence_semaphore_list(fence));
+}
+
+iree_status_t hal_allocator_allocate_buffer(
+    iree_hal_allocator_t* IREE_RESTRICT allocator,
+    iree_hal_buffer_params_t params, iree_device_size_t allocation_size,
+    iree_const_byte_span_t initial_data, iree_hal_buffer_t** out_buffer) {
+  auto status = iree_hal_allocator_allocate_buffer(
+      allocator, params, allocation_size, initial_data, out_buffer);
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "allocator=%p, size=%zu, buffer=%p", allocator,
+              (size_t)allocation_size, *out_buffer);
+  }
+  return HandleStatus(__func__, status);
+}
+
+iree_status_t hal_allocator_import_buffer(
+    iree_hal_allocator_t* IREE_RESTRICT allocator,
+    iree_hal_buffer_params_t params,
+    iree_hal_external_buffer_t* IREE_RESTRICT external_buffer,
+    iree_hal_buffer_release_callback_t release_callback,
+    iree_hal_buffer_t** out_buffer) {
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "external_buffer=%p", external_buffer);
+  }
+  return HandleStatus(__func__,
+                      iree_hal_allocator_import_buffer(
+                          allocator, params, external_buffer,
+                          release_callback, out_buffer));
+}
+
+iree_status_t hal_device_queue_alloca(
+    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
+    iree_device_size_t allocation_size,
+    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "device=%p, size=%zd, wait={%s}, signal={%s}", device,
+              (size_t)allocation_size,
+              SemaphoreListToString(wait_semaphore_list).c_str(),
+              SemaphoreListToString(signal_semaphore_list).c_str());
+  }
+  return HandleStatus(__func__, iree_hal_device_queue_alloca(
+                                    device, queue_affinity, wait_semaphore_list,
+                                    signal_semaphore_list, pool, params,
+                                    allocation_size, out_buffer));
+}
+
+iree_status_t hal_device_queue_dealloca(
+    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* buffer) {
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "device=%p, buffer=%p, wait={%s}, signal={%s}", device,
+              buffer, SemaphoreListToString(wait_semaphore_list).c_str(),
+              SemaphoreListToString(signal_semaphore_list).c_str());
+  }
+  return HandleStatus(__func__, iree_hal_device_queue_dealloca(
+                                    device, queue_affinity, wait_semaphore_list,
+                                    signal_semaphore_list, buffer));
+}
+
+iree_status_t hal_device_queue_barrier(
+    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list) {
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
+              SemaphoreListToString(wait_semaphore_list).c_str(),
+              SemaphoreListToString(signal_semaphore_list).c_str());
+  }
+  return HandleStatus(__func__, iree_hal_device_queue_barrier(
+                                    device, queue_affinity, wait_semaphore_list,
+                                    signal_semaphore_list));
+}
+
+iree_status_t hal_device_queue_execute(
+    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t command_buffer_count,
+    iree_hal_command_buffer_t* const* command_buffers) {
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
+              SemaphoreListToString(wait_semaphore_list).c_str(),
+              SemaphoreListToString(signal_semaphore_list).c_str());
+  }
+  return HandleStatus(__func__, iree_hal_device_queue_execute(
+                                    device, queue_affinity, wait_semaphore_list,
+                                    signal_semaphore_list, command_buffer_count,
+                                    command_buffers));
+}
+
+iree_status_t hal_fence_create(iree_host_size_t capacity,
+                               iree_allocator_t host_allocator,
+                               iree_hal_fence_t** out_fence) {
+  auto status = iree_hal_fence_create(capacity, host_allocator, out_fence);
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "capacity=%zu, fence=%p", (size_t)capacity, *out_fence);
+  }
+  return HandleStatus(__func__, status);
+}
+
+iree_status_t hal_fence_create_at(iree_hal_semaphore_t* semaphore,
+                                  uint64_t value,
+                                  iree_allocator_t host_allocator,
+                                  iree_hal_fence_t** out_fence) {
+  auto status =
+      iree_hal_fence_create_at(semaphore, value, host_allocator, out_fence);
+  if (LOGGING_ENABLED) {
+    LogInvoke(__func__, "semaphore=%p, value=%" PRIu64 ", fence=%p", semaphore,
+              value, *out_fence);
*out_fence); + } + return HandleStatus(__func__, status); +} + +iree_status_t hal_fence_extend(iree_hal_fence_t* into_fence, + iree_hal_fence_t* from_fence) { + if (LOGGING_ENABLED) { + LogInvoke(__func__, "into_fence=%p, from_fence=%p", into_fence, from_fence); + } + return HandleStatus(__func__, iree_hal_fence_extend(into_fence, from_fence)); +} + +iree_status_t hal_fence_insert(iree_hal_fence_t* fence, + iree_hal_semaphore_t* semaphore, + uint64_t value) { + if (LOGGING_ENABLED) { + LogInvoke(__func__, "fence=%p, semaphore=%p, value=%" PRIu64, fence, + semaphore, value); + } + return HandleStatus(__func__, iree_hal_fence_insert(fence, semaphore, value)); +} + +} // namespace +} // namespace IreeApi + +} // namespace iree::pjrt diff --git a/iree/integrations/pjrt/cpu/BUILD b/iree/integrations/pjrt/cpu/BUILD index 3f449403b9f7..e5510c1ad2ae 100644 --- a/iree/integrations/pjrt/cpu/BUILD +++ b/iree/integrations/pjrt/cpu/BUILD @@ -16,6 +16,7 @@ iree_pjrt_cc_library( ], deps = [ "//iree/integrations/pjrt/common:impl", + "@iree_core//runtime/src/iree/hal/drivers/local_sync:sync_driver", "@iree_core//runtime/src/iree/hal/drivers/local_task:task_driver", "@iree_core//runtime/src/iree/hal/local", "@iree_core//runtime/src/iree/hal/local:executable_loader", diff --git a/iree/integrations/pjrt/cpu/client.cc b/iree/integrations/pjrt/cpu/client.cc index 6d831d6178c0..94f09dadbbf5 100644 --- a/iree/integrations/pjrt/cpu/client.cc +++ b/iree/integrations/pjrt/cpu/client.cc @@ -6,6 +6,7 @@ #include "iree/integrations/pjrt/cpu/client.h" +#include "iree/hal/drivers/local_sync/sync_driver.h" #include "iree/hal/drivers/local_task/task_driver.h" #include "iree/hal/local/plugins/registration/init.h" #include "iree/task/api.h" @@ -19,30 +20,21 @@ CPUClientInstance::CPUClientInstance(std::unique_ptr platform) // TODO: Get this when constructing the client so it is guaranteed to // match. cached_platform_name_ = "iree_cpu"; - iree_hal_task_device_params_initialize(&device_params_); - iree_task_executor_options_initialize(&task_executor_options_); iree_task_topology_initialize(&task_topology_options_); } CPUClientInstance::~CPUClientInstance() { iree_hal_allocator_release(device_allocator_); - iree_task_executor_release(executor_); + if (executor_) iree_task_executor_release(executor_); for (iree_host_size_t i = 0; i < loader_count_; ++i) { iree_hal_executable_loader_release(loaders_[i]); } - iree_hal_executable_plugin_manager_release(plugin_manager_); + if (plugin_manager_) + iree_hal_executable_plugin_manager_release(plugin_manager_); iree_task_topology_deinitialize(&task_topology_options_); } iree_status_t CPUClientInstance::InitializeDeps() { - // executor options and topology options. Getting these from flags is not - // great for this use since there is no way to set the flags :/ - IREE_RETURN_IF_ERROR(iree_task_executor_options_initialize_from_flags( - &task_executor_options_)); - // TODO: Do something smarter than pinning to NUMA node 0. 
-  IREE_RETURN_IF_ERROR(iree_task_topology_initialize_from_flags(
-      /*node_id=*/0, &task_topology_options_));
-
   // plugin_manager_
   IREE_RETURN_IF_ERROR(iree_hal_executable_plugin_manager_create(
       /*capacity=*/0, host_allocator_, &plugin_manager_));
@@ -56,12 +48,6 @@ iree_status_t CPUClientInstance::InitializeDeps() {
   IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
       iree_make_cstring_view("local"), host_allocator_, host_allocator_,
       &device_allocator_));
-
-  // executor_
-  IREE_RETURN_IF_ERROR(iree_task_executor_create(task_executor_options_,
-                                                 &task_topology_options_,
-                                                 host_allocator_, &executor_));
-
   return iree_ok_status();
 }
 
@@ -73,9 +59,36 @@ iree_status_t CPUClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
   IREE_RETURN_IF_ERROR(InitializeDeps());
 
   // driver
-  IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
-      IREE_SV("local-task"), &device_params_, /*queue_count=*/1, &executor_,
-      loader_count_, loaders_, device_allocator_, host_allocator_, out_driver));
+  if (single_threaded_debug_) {
+    logger().debug("Creating single threaded CPU driver (debugging)");
+    iree_hal_sync_device_params_t sync_params;
+    iree_hal_sync_device_params_initialize(&sync_params);
+    IREE_RETURN_IF_ERROR(iree_hal_sync_driver_create(
+        IREE_SV("local-sync"), &sync_params, loader_count_, loaders_,
+        device_allocator_, host_allocator_, out_driver));
+  } else {
+    iree_task_executor_options_t task_executor_options;
+    iree_hal_task_device_params_t task_params;
+    iree_task_executor_options_initialize(&task_executor_options);
+    iree_hal_task_device_params_initialize(&task_params);
+
+    // executor options and topology options. Getting these from flags is not
+    // great for this use since there is no way to set the flags :/
+    IREE_RETURN_IF_ERROR(iree_task_executor_options_initialize_from_flags(
+        &task_executor_options));
+    // TODO: Do something smarter than pinning to NUMA node 0.
+    IREE_RETURN_IF_ERROR(iree_task_topology_initialize_from_flags(
+        /*node_id=*/0, &task_topology_options_));
+
+    IREE_RETURN_IF_ERROR(iree_task_executor_create(
+        task_executor_options, &task_topology_options_, host_allocator_,
+        &executor_));
+
+    IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
+        IREE_SV("local-task"), &task_params, /*queue_count=*/1, &executor_,
+        loader_count_, loaders_, device_allocator_, host_allocator_,
+        out_driver));
+  }
 
   logger().debug("CPU driver created");
   return iree_ok_status();
diff --git a/iree/integrations/pjrt/cpu/client.h b/iree/integrations/pjrt/cpu/client.h
index 787e70a88cbb..b1f73931d328 100644
--- a/iree/integrations/pjrt/cpu/client.h
+++ b/iree/integrations/pjrt/cpu/client.h
@@ -23,9 +23,8 @@ class CPUClientInstance final : public ClientInstance {
  private:
  iree_status_t InitializeDeps();
 
-  // Options.
-  iree_hal_task_device_params_t device_params_;
-  iree_task_executor_options_t task_executor_options_;
+  // Instance scoped options.
+  bool single_threaded_debug_ = false;
   iree_task_topology_t task_topology_options_;
 
   // Deps.
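The wrappers added in iree_helpers.h all follow the same drop-in shim shape: keep the signature of the underlying iree_hal_* call (with the iree_ prefix elided), optionally log the arguments, forward to the real function, and route the result through HandleStatus so call sites only gain an IreeApi:: qualifier. The sketch below illustrates that shape in isolation; it is a minimal, self-contained example with hypothetical demo_*/demo:: names, not IREE's actual types or the patch's code.

```cpp
// Illustrative sketch only: demo_*/demo:: names are hypothetical stand-ins.
// The real wrappers in iree_helpers.h forward to iree_hal_* functions with
// identical signatures.
#include <cstdarg>
#include <cstdint>
#include <cstdio>

typedef int demo_status_t;  // 0 == OK; loosely mirrors iree_status_t usage.

// Hypothetical underlying C API being wrapped.
static demo_status_t demo_fence_insert(void* fence, void* semaphore,
                                        uint64_t value) {
  (void)fence;
  (void)semaphore;
  (void)value;
  return 0;
}

namespace demo {
namespace {

// Compile-time switch, analogous to IreeApi::LOGGING_ENABLED in the patch.
constexpr bool kLoggingEnabled = true;

// Prints the wrapped call and its arguments to stderr when logging is on.
void LogInvoke(const char* func, const char* fmt, ...) {
  if (!kLoggingEnabled) return;
  std::fprintf(stderr, ":: DEMO INVOKE (%s): ", func);
  va_list args;
  va_start(args, fmt);
  std::vfprintf(stderr, fmt, args);
  va_end(args);
  std::fprintf(stderr, "\n");
}

// Funnel for the wrapped call's result; returns the status unchanged.
demo_status_t HandleStatus(const char* func, demo_status_t status) {
  if (kLoggingEnabled) {
    std::fprintf(stderr, ":: DEMO RESULT (%s): %s\n", func,
                 status == 0 ? "OK" : "ERROR");
  }
  return status;
}

// Drop-in wrapper: same signature with the prefix elided; logs, forwards,
// and routes the status through HandleStatus.
demo_status_t fence_insert(void* fence, void* semaphore, uint64_t value) {
  if (kLoggingEnabled) {
    LogInvoke(__func__, "fence=%p, semaphore=%p, value=%llu", fence, semaphore,
              static_cast<unsigned long long>(value));
  }
  return HandleStatus(__func__, demo_fence_insert(fence, semaphore, value));
}

}  // namespace
}  // namespace demo

int main() {
  int fence = 0, semaphore = 0;
  // The call site changes only by the namespace qualifier, which is what lets
  // the patch swap iree_hal_fence_insert(...) for IreeApi::hal_fence_insert(...).
  return demo::fence_insert(&fence, &semaphore, 42);
}
```

Because the wrappers live in a header-only anonymous namespace and gate on a const bool, the logging paths should compile away entirely while LOGGING_ENABLED stays false.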