Skip to content

Commit

Permalink
Adds IREE API logging and single threaded CPU debug options. (#63)
Browse files Browse the repository at this point in the history
I had built these out while working on root causing a regression.
Cleaned them up and am mainlining them. They are controlled by compile time
variables for the moment and we can do something smarter later.
  • Loading branch information
Stella Laurenzo authored Apr 20, 2023
1 parent e9b15e0 commit dace0f0
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 44 deletions.
1 change: 1 addition & 0 deletions iree/integrations/pjrt/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ iree_pjrt_cc_library(
hdrs = [
"api_impl.h",
"dylib_entry_point.cc.inc",
"iree_helpers.h",
"platform.h",
],
deps = [
Expand Down
53 changes: 33 additions & 20 deletions iree/integrations/pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/integrations/pjrt/common/iree_helpers.h"
#include "iree/integrations/pjrt/common/tensor_utils.h"

using iree::vm::retain_ref;
Expand Down Expand Up @@ -461,7 +462,7 @@ iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) {

iree_status_t BufferInstance::AsyncDeallocate() {
IREE_TRACE_SCOPE();
return iree_hal_device_queue_dealloca(
return IreeApi::hal_device_queue_dealloca(
device().device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(done_fence()),
/*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
Expand Down Expand Up @@ -535,7 +536,7 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
copy_data->event->ExternalSignalReady(iree_ok_status());
delete copy_data;
};
IREE_RETURN_IF_ERROR(iree_hal_allocator_import_buffer(
IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_import_buffer(
device_.device_allocator(), dst_buffer_params, &dst_external_buffer,
/*release_callback=*/{release_callback, copy_to_host_data}, &dst_buffer));

Expand All @@ -556,7 +557,7 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
/*transfer_count=*/1, &transfer_command, &transfer_cb));
dst_buffer.reset();

IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
device_.device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
/*signal_semaphore_list=*/iree_hal_semaphore_list_empty(),
Expand All @@ -567,12 +568,12 @@ iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,

iree_status_t BufferInstance::AdvanceReadyFence(iree_hal_semaphore_t* semaphore,
uint64_t timepoint) {
return iree_hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
return IreeApi::hal_fence_insert(ready_fence_.get(), semaphore, timepoint);
}

iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
uint64_t timepoint) {
return iree_hal_fence_insert(done_fence_.get(), semaphore, timepoint);
return IreeApi::hal_fence_insert(done_fence_.get(), semaphore, timepoint);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -634,8 +635,8 @@ void DeviceInstance::BindApi(PJRT_Api* api) {
}

iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
return iree_hal_fence_create(/*capacity=*/2, client_.host_allocator(),
out_fence);
return IreeApi::hal_fence_create(/*capacity=*/2, client_.host_allocator(),
out_fence);
}

iree_status_t DeviceInstance::OpenDevice() {
Expand Down Expand Up @@ -767,7 +768,7 @@ iree_status_t DeviceInstance::HostBufferToDevice(
params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
params.usage =
IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET;
IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_alloca(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &wait_transfer_start},
Expand Down Expand Up @@ -803,14 +804,14 @@ iree_status_t DeviceInstance::HostBufferToDevice(
}

if (has_zero_length) {
IREE_RETURN_IF_ERROR(iree_hal_device_queue_barrier(
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_barrier(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_copy_complete}));
} else {
IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
IREE_RETURN_IF_ERROR(IreeApi::hal_device_queue_execute(
device(), IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/
{1, &transfer_timeline_, &signal_alloca_complete},
Expand Down Expand Up @@ -853,7 +854,7 @@ iree_status_t DeviceInstance::AcquireHostStagingBuffer(
memset(&params, 0, sizeof(params));
params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER;
IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
IREE_RETURN_IF_ERROR(IreeApi::hal_allocator_allocate_buffer(
device_allocator(), params, initial_contents.data_length,
initial_contents, out_buffer));
// We did a synchronous snapshot (memcpy).
Expand Down Expand Up @@ -1494,13 +1495,13 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
// semaphores.
IREE_RETURN_IF_ERROR(
inv.res_exe->device_instance->CreateFence(&inv.wait_fence));
IREE_RETURN_IF_ERROR(iree_hal_fence_insert(
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_insert(
inv.wait_fence.get(), inv.res_exe->device_instance->main_timeline(),
wait_timepoint));

// Signal fence. This signals the next tick on the main execution
// timeline.
IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_create_at(
inv.res_exe->device_instance->main_timeline(), signal_timepoint,
client_.host_allocator(), &inv.signal_fence));

Expand All @@ -1520,8 +1521,8 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref));

// Extend the execute wait to include the input's ready signal.
IREE_RETURN_IF_ERROR(
iree_hal_fence_extend(inv.wait_fence.get(), buffer->ready_fence()));
IREE_RETURN_IF_ERROR(IreeApi::hal_fence_extend(inv.wait_fence.get(),
buffer->ready_fence()));

// And extend the buffer's done fence to close over this execution.
buffer->AdvanceDoneFence(inv.res_exe->device_instance->main_timeline(),
Expand All @@ -1541,17 +1542,29 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
iree_status_t status = iree_ok_status();
for (size_t dev_index = 0; dev_index < args->num_devices; ++dev_index) {
auto& inv = invs[dev_index];
auto new_status = iree_vm_invoke(
inv.res_exe->vm_context.get(), inv.res_exe->main_function,
IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inv.inputs.get(), inv.outputs.get(), allocator);
if (IreeApi::LOGGING_ENABLED) {
IreeApi::LogInvoke(
"vm_invoke[async]",
"context=%p, f=%d, wait_fence=%p {%s}, signal_fence=%p {%s}",
inv.res_exe->vm_context.get(),
(int)inv.res_exe->main_function.ordinal, inv.wait_fence.get(),
IreeApi::FenceToString(inv.wait_fence.get()).c_str(),
inv.signal_fence.get(),
IreeApi::FenceToString(inv.signal_fence.get()).c_str());
}
auto new_status = IreeApi::HandleStatus(
"vm_invoke[async]",
iree_vm_invoke(inv.res_exe->vm_context.get(),
inv.res_exe->main_function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inv.inputs.get(), inv.outputs.get(),
allocator));
// Any invocation that fails needs a barrier so that signal fence is
// incremented otherwise future waits will fail. We do this instead of
// incrementing as only a subset of devices may fail.
if (!iree_status_is_ok(new_status)) {
status = new_status;
// We can ignore the error as we are already erroring out earlier.
IREE_IGNORE_ERROR(iree_hal_device_queue_barrier(
IREE_IGNORE_ERROR(IreeApi::hal_device_queue_barrier(
inv.res_exe->device_instance->device(), IREE_HAL_QUEUE_AFFINITY_ANY,
iree_hal_fence_semaphore_list(inv.wait_fence.get()),
iree_hal_fence_semaphore_list(inv.signal_fence.get())));
Expand Down
199 changes: 199 additions & 0 deletions iree/integrations/pjrt/common/iree_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cinttypes>
#include <cstdarg>
#include <cstdio>
#include <string>

#include "iree/hal/api.h"

namespace iree::pjrt {

// Anonymous namespace containing helpers and wrappers for IREE API
// functions which can perform verbose logging when enabled. These all
// match an IREE api but will have the |iree_| prefix elided, so they are
// used as IreeApi::hal_allocator_allocate_buffer(...), which should be a
// drop-in for iree_hal_allocator_allocate_buffer(...).
namespace IreeApi {
namespace {

// Controls whether verbose logging of wrapped IREE API calls is printed to
// stderr. Compile-time constant for now (flip to true when debugging); we may
// want to make this more configurable in the future.
const bool LOGGING_ENABLED = false;

// Emits a single ":: IREE INVOKE (<func>): ..." trace line to stderr when
// logging is enabled. |fmt| and the trailing arguments are printf-style and
// describe the call's interesting parameters. No-op when logging is disabled.
IREE_PRINTF_ATTRIBUTE(2, 3)
void LogInvoke(const char* func, const char* fmt, ...) {
  if (!LOGGING_ENABLED) return;
  fprintf(stderr, ":: IREE INVOKE (%s): ", func);
  va_list va;
  va_start(va, fmt);
  vfprintf(stderr, fmt, va);
  va_end(va);
  // Flush so interleaved traces stay ordered with any crash output.
  fflush(stderr);
}
// Completes a trace line started by LogInvoke by appending "(OK)" or the
// full formatted status. Always returns |status| unchanged so the helper can
// transparently wrap a call expression. |func| is currently unused but kept
// for interface symmetry with LogInvoke.
iree_status_t HandleStatus(const char* func, iree_status_t status) {
  if (!LOGGING_ENABLED) return status;
  if (iree_status_is_ok(status)) {
    fprintf(stderr, " (OK)\n");
  } else {
    fprintf(stderr, " (");
    iree_status_fprint(stderr, status);
    fprintf(stderr, ")\n");
  }
  return status;
}
// Renders a semaphore list as "ptr:value, ptr:value, ..." for log output.
std::string SemaphoreListToString(const iree_hal_semaphore_list_t sl) {
  std::string out;
  for (iree_host_size_t i = 0; i < sl.count; ++i) {
    if (!out.empty()) out.append(", ");
    char entry[64];
    snprintf(entry, sizeof(entry), "%p:%" PRIu64, sl.semaphores[i],
             sl.payload_values[i]);
    out.append(entry);
  }
  return out;
}
// Renders every semaphore timepoint tracked by |fence| for log output.
std::string FenceToString(iree_hal_fence_t* fence) {
  const iree_hal_semaphore_list_t timepoints =
      iree_hal_fence_semaphore_list(fence);
  return SemaphoreListToString(timepoints);
}

// Logging wrapper for iree_hal_allocator_allocate_buffer(). Logs after the
// call so the resulting buffer pointer can be included in the trace.
// FIX: |*out_buffer| is only meaningful on success; the original read it
// unconditionally, logging an indeterminate pointer when allocation failed.
iree_status_t hal_allocator_allocate_buffer(
    iree_hal_allocator_t* IREE_RESTRICT allocator,
    iree_hal_buffer_params_t params, iree_device_size_t allocation_size,
    iree_const_byte_span_t initial_data, iree_hal_buffer_t** out_buffer) {
  auto status = iree_hal_allocator_allocate_buffer(
      allocator, params, allocation_size, initial_data, out_buffer);
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "allocator=%p, size=%zu, buffer=%p", allocator,
              (size_t)allocation_size,
              iree_status_is_ok(status) ? (void*)*out_buffer : nullptr);
  }
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_allocator_import_buffer(). Logs before the
// call (the imported handle is the interesting datum, not the result).
iree_status_t hal_allocator_import_buffer(
    iree_hal_allocator_t* IREE_RESTRICT allocator,
    iree_hal_buffer_params_t params,
    iree_hal_external_buffer_t* IREE_RESTRICT external_buffer,
    iree_hal_buffer_release_callback_t release_callback,
    iree_hal_buffer_t** out_buffer) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "external_buffer=%p", external_buffer);
  }
  iree_status_t status = iree_hal_allocator_import_buffer(
      allocator, params, external_buffer, release_callback, out_buffer);
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_device_queue_alloca(): traces the device,
// requested size, and the wait/signal semaphore timepoints.
iree_status_t hal_device_queue_alloca(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params,
    iree_device_size_t allocation_size,
    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, size=%zd, wait={%s}, signal={%s}", device,
              (size_t)allocation_size, waits.c_str(), signals.c_str());
  }
  iree_status_t status = iree_hal_device_queue_alloca(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list, pool,
      params, allocation_size, out_buffer);
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_device_queue_dealloca(): traces the device,
// the buffer being released, and the wait/signal semaphore timepoints.
iree_status_t hal_device_queue_dealloca(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_buffer_t* buffer) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, buffer=%p, wait={%s}, signal={%s}", device,
              buffer, waits.c_str(), signals.c_str());
  }
  iree_status_t status = iree_hal_device_queue_dealloca(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
      buffer);
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_device_queue_barrier(): traces the device and
// the wait/signal semaphore timepoints bridged by the barrier.
iree_status_t hal_device_queue_barrier(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list) {
  if (LOGGING_ENABLED) {
    const std::string waits = SemaphoreListToString(wait_semaphore_list);
    const std::string signals = SemaphoreListToString(signal_semaphore_list);
    LogInvoke(__func__, "device=%p, wait={%s}, signal={%s}", device,
              waits.c_str(), signals.c_str());
  }
  iree_status_t status = iree_hal_device_queue_barrier(
      device, queue_affinity, wait_semaphore_list, signal_semaphore_list);
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_device_queue_execute().
// IMPROVED: the trace now includes |command_buffer_count| — the parameter
// that distinguishes an execute from a bare barrier — which the original log
// line omitted, making zero-work submissions indistinguishable in traces.
iree_status_t hal_device_queue_execute(
    iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_host_size_t command_buffer_count,
    iree_hal_command_buffer_t* const* command_buffers) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__,
              "device=%p, command_buffers=%zu, wait={%s}, signal={%s}", device,
              (size_t)command_buffer_count,
              SemaphoreListToString(wait_semaphore_list).c_str(),
              SemaphoreListToString(signal_semaphore_list).c_str());
  }
  return HandleStatus(__func__, iree_hal_device_queue_execute(
                                    device, queue_affinity,
                                    wait_semaphore_list, signal_semaphore_list,
                                    command_buffer_count, command_buffers));
}

// Logging wrapper for iree_hal_fence_create(). Logs after the call so the
// created fence pointer can be included in the trace.
// FIX: |*out_fence| is only meaningful on success; the original read it
// unconditionally, logging an indeterminate pointer when creation failed.
iree_status_t hal_fence_create(iree_host_size_t capacity,
                               iree_allocator_t host_allocator,
                               iree_hal_fence_t** out_fence) {
  auto status = iree_hal_fence_create(capacity, host_allocator, out_fence);
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "capacity=%zu, fence=%p", (size_t)capacity,
              iree_status_is_ok(status) ? (void*)*out_fence : nullptr);
  }
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_fence_create_at(). Logs after the call so the
// created fence pointer can be included in the trace.
// FIX: |*out_fence| is only meaningful on success; the original read it
// unconditionally, logging an indeterminate pointer when creation failed.
iree_status_t hal_fence_create_at(iree_hal_semaphore_t* semaphore,
                                  uint64_t value,
                                  iree_allocator_t host_allocator,
                                  iree_hal_fence_t** out_fence) {
  auto status =
      iree_hal_fence_create_at(semaphore, value, host_allocator, out_fence);
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "semaphore=%p, value=%" PRIu64 ", fence=%p", semaphore,
              value, iree_status_is_ok(status) ? (void*)*out_fence : nullptr);
  }
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_fence_extend(): traces both fence pointers
// before merging |from_fence|'s timepoints into |into_fence|.
iree_status_t hal_fence_extend(iree_hal_fence_t* into_fence,
                               iree_hal_fence_t* from_fence) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "into_fence=%p, from_fence=%p", into_fence, from_fence);
  }
  iree_status_t status = iree_hal_fence_extend(into_fence, from_fence);
  return HandleStatus(__func__, status);
}

// Logging wrapper for iree_hal_fence_insert(): traces the fence and the
// semaphore timepoint being inserted into it.
iree_status_t hal_fence_insert(iree_hal_fence_t* fence,
                               iree_hal_semaphore_t* semaphore,
                               uint64_t value) {
  if (LOGGING_ENABLED) {
    LogInvoke(__func__, "fence=%p, semaphore=%p, value=%" PRIu64, fence,
              semaphore, value);
  }
  iree_status_t status = iree_hal_fence_insert(fence, semaphore, value);
  return HandleStatus(__func__, status);
}

} // namespace
} // namespace IreeApi

} // namespace iree::pjrt
1 change: 1 addition & 0 deletions iree/integrations/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_pjrt_cc_library(
],
deps = [
"//iree/integrations/pjrt/common:impl",
"@iree_core//runtime/src/iree/hal/drivers/local_sync:sync_driver",
"@iree_core//runtime/src/iree/hal/drivers/local_task:task_driver",
"@iree_core//runtime/src/iree/hal/local",
"@iree_core//runtime/src/iree/hal/local:executable_loader",
Expand Down
Loading

0 comments on commit dace0f0

Please sign in to comment.