Skip to content

Commit

Permalink
[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 646311089
  • Loading branch information
ezhulenev authored and copybara-github committed Jun 25, 2024
1 parent 223bb06 commit 8a91f59
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ static void BM_SelectAndScatterF32(benchmark::State& state) {

BENCHMARK(BM_SelectAndScatterF32)
->MeasureProcessCPUTime()
->Arg(64)
->Arg(128)
->Arg(256)
->Arg(512);
->Arg(512)
->Arg(1024);

} // namespace xla::cpu
53 changes: 29 additions & 24 deletions xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/numeric/bits.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
#include "xla/runtime/buffer_use.h"
Expand Down Expand Up @@ -86,67 +87,71 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size());
absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < buffers_data.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
info().op_name, i, kernel_args[i].data, *min_alignment_);
}
}
}

// TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
// initialization stage.
SE_HOST_Kernel* kernel_ptr = kernel_ptr_.load();
se::host::HostKernel* kernel = kernel_ptr_.load();

// Because thunks are owned by a parent CpuExecutable, we can safely assume
// that kernel pointer will not change after we find it the first time.
if (kernel_ptr == nullptr) {
TF_ASSIGN_OR_RETURN(kernel_ptr, params.host_kernels->Find(kernel_name_));
kernel_ptr_.store(kernel_ptr);
}
if (ABSL_PREDICT_FALSE(kernel == nullptr)) {
TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn,
params.host_kernels->Find(kernel_name_));

se::host::HostKernel kernel(buffers_data.size(), kernel_ptr, nullptr);
absl::MutexLock lock(&mutex_);
kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}

// If intra-op thread pool is not nullptr, we launch HostKernel in async mode
// by scheduling tasks into it. HostKernel launch completion will
// automatically signal KernelThunk execute completion.
if (params.intra_op_threadpool && use_task_runner_) {
return kernel.Launch(thread_dim_, buffers_data,
[&params](se::host::HostKernel::Task task) {
params.intra_op_threadpool->getPool()->Schedule(
ToCopyableTask(std::move(task)));
});
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) {
return kernel->Launch(thread_dim_, kernel_args,
[&params](se::host::HostKernel::Task task) {
params.intra_op_threadpool->getPool()->Schedule(
ToCopyableTask(std::move(task)));
});
}

TF_RETURN_IF_ERROR(kernel.Launch(thread_dim_, buffers_data));
TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args));
return OkExecuteEvent();
}

Expand Down
14 changes: 7 additions & 7 deletions xla/service/cpu/runtime/kernel_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/runtime/thunk.h"
#include "xla/stream_executor/host/host_kernel_c_api.h"
#include "xla/stream_executor/host/host_kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/tsl/concurrency/async_value_ref.h"

Expand Down Expand Up @@ -64,12 +66,10 @@ class KernelThunk final : public Thunk {
// launch the kernel directly in the caller thread.
bool use_task_runner_;

// Pointer to the host kernel corresponding to `kernel_name_`. Initialized
// lazily at run time by looking it up in the HostKernels passed via params.
//
// TODO(ezhulenev): This should be moved to initialization stage when we'll
// have it for CPU thunks.
std::atomic<SE_HOST_Kernel*> kernel_ptr_;
// Lazily loaded host kernel corresponding to `kernel_name_`.
absl::Mutex mutex_;
std::optional<se::host::HostKernel> kernel_ ABSL_GUARDED_BY(mutex_);
std::atomic<se::host::HostKernel*> kernel_ptr_; // pointer to `kernel_`
};

} // namespace xla::cpu
Expand Down
33 changes: 24 additions & 9 deletions xla/stream_executor/host/host_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HostKernelExecuteState
HostKernelExecuteState(HostKernel::TaskRunner task_runner,
HostKernel::KernelFunction* function,
ThreadDim thread_dims,
absl::Span<const DeviceMemoryBase> buffers);
absl::Span<const SE_HOST_KernelArg> args);

// Notify of a completion of a host kernel task.
void Notify(absl::Status status);
Expand Down Expand Up @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y,
thread_dims.z};
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers));
}

absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {
thread_dims.x,
thread_dims.y,
thread_dims.z,
};

SE_HOST_Kernel* kernel = function_->kernel();
auto args = ConvertBuffersToKernelArgs(buffers);

for (uint64_t z = 0; z < thread_dims.z; ++z) {
for (uint64_t y = 0; y < thread_dims.y; ++y) {
Expand All @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch(

SE_HOST_KernelError* error = (*kernel)(&call_frame);

if (error != nullptr) {
if (ABSL_PREDICT_FALSE(error != nullptr)) {
return absl::InternalError("Failed to call host kernel");
}
}
Expand All @@ -147,20 +155,27 @@ absl::Status HostKernel::Launch(
tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const {
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers),
std::move(task_runner));
}

tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const {
size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z;
CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok

// Short-circuit launch with a single task and run it in the caller thread.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
absl::Status launched = Launch(thread_dims, buffers);
absl::Status launched = Launch(thread_dims, args);
return ABSL_PREDICT_TRUE(launched.ok())
? OkLaunchEvent()
: tsl::MakeErrorAsyncValueRef(std::move(launched));
}

// Allocate a control structure that will orchestrate kernel execution.
auto state = tsl::MakeRef<HostKernelExecuteState>(
std::move(task_runner), function_.get(), thread_dims, buffers);
std::move(task_runner), function_.get(), thread_dims, args);

state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);

Expand All @@ -169,12 +184,12 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(

HostKernelExecuteState::HostKernelExecuteState(
HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers)
ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
: task_runner_(std::move(task_runner)),
num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
kernel_(function->kernel()),
thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
args_(ConvertBuffersToKernelArgs(buffers)),
args_(args.begin(), args.end()),
abort_(false),
counter_(num_tasks_),
event_(tsl::MakeConstructedAsyncValueRef<LaunchEvent>()) {}
Expand Down
5 changes: 5 additions & 0 deletions xla/stream_executor/host/host_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class HostKernel : public Kernel {
// `thread_dims` and calling the kernel function.
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const;
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const;

// Launches the kernel by iterating over all threads in `thread_dims` and
// calling `task_runner` to run individual tasks (implementation might decide
Expand All @@ -93,6 +95,9 @@ class HostKernel : public Kernel {
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const;
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const;

// For host platform, we assume that a core is a thread, and we can run at
// most one instance of a kernel on a given thread.
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/host/host_kernel_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame {
SE_HOST_KernelThread* thread;

size_t num_args;
SE_HOST_KernelArg* args;
const SE_HOST_KernelArg* args;
} SE_HOST_KernelCallFrame;

// Error reporting for host kernels. NULL means success.
Expand Down

0 comments on commit 8a91f59

Please sign in to comment.