Skip to content

Commit

Permalink
[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly …
Browse files Browse the repository at this point in the history
…to the kernel

PiperOrigin-RevId: 646311089
  • Loading branch information
ezhulenev authored and copybara-github committed Jun 25, 2024
1 parent f888ff7 commit 48b640c
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) {

BENCHMARK(BM_SelectAndScatterF32)
->MeasureProcessCPUTime()
->Arg(64)
->Arg(128)
->Arg(256)
->Arg(512)
Expand Down
30 changes: 16 additions & 14 deletions xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,40 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size());
absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < buffers_data.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
info().op_name, i, kernel_args[i].data, *min_alignment_);
}
}
}
Expand All @@ -134,22 +136,22 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
params.host_kernels->Find(kernel_name_));

absl::MutexLock lock(&mutex_);
kernel_.emplace(buffers_data.size(), kernel_fn, nullptr);
kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}

// If intra-op thread pool is not nullptr, we launch HostKernel in async mode
// by scheduling tasks into it. HostKernel launch completion will
// automatically signal KernelThunk execute completion.
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) {
return kernel->Launch(thread_dim_, buffers_data,
return kernel->Launch(thread_dim_, kernel_args,
[&params](se::host::HostKernel::Task task) {
params.intra_op_threadpool->getPool()->Schedule(
ToCopyableTask(std::move(task)));
});
}

TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data));
TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args));
return OkExecuteEvent();
}

Expand Down
33 changes: 24 additions & 9 deletions xla/stream_executor/host/host_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HostKernelExecuteState
HostKernelExecuteState(HostKernel::TaskRunner task_runner,
HostKernel::KernelFunction* function,
ThreadDim thread_dims,
absl::Span<const DeviceMemoryBase> buffers);
absl::Span<const SE_HOST_KernelArg> args);

// Notify of a completion of a host kernel task.
void Notify(absl::Status status);
Expand Down Expand Up @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y,
thread_dims.z};
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers));
}

absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {
thread_dims.x,
thread_dims.y,
thread_dims.z,
};

SE_HOST_Kernel* kernel = function_->kernel();
auto args = ConvertBuffersToKernelArgs(buffers);

for (uint64_t z = 0; z < thread_dims.z; ++z) {
for (uint64_t y = 0; y < thread_dims.y; ++y) {
Expand All @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch(

SE_HOST_KernelError* error = (*kernel)(&call_frame);

if (error != nullptr) {
if (ABSL_PREDICT_FALSE(error != nullptr)) {
return absl::InternalError("Failed to call host kernel");
}
}
Expand All @@ -147,20 +155,27 @@ absl::Status HostKernel::Launch(
tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const {
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers),
std::move(task_runner));
}

tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const {
size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z;
CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok

// Short-circuit launch with a single task and run it in the caller thread.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
absl::Status launched = Launch(thread_dims, buffers);
absl::Status launched = Launch(thread_dims, args);
return ABSL_PREDICT_TRUE(launched.ok())
? OkLaunchEvent()
: tsl::MakeErrorAsyncValueRef(std::move(launched));
}

// Allocate a control structure that will orchestrate kernel execution.
auto state = tsl::MakeRef<HostKernelExecuteState>(
std::move(task_runner), function_.get(), thread_dims, buffers);
std::move(task_runner), function_.get(), thread_dims, args);

state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);

Expand All @@ -169,12 +184,12 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(

HostKernelExecuteState::HostKernelExecuteState(
HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers)
ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
: task_runner_(std::move(task_runner)),
num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
kernel_(function->kernel()),
thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
args_(ConvertBuffersToKernelArgs(buffers)),
args_(args.begin(), args.end()),
abort_(false),
counter_(num_tasks_),
event_(tsl::MakeConstructedAsyncValueRef<LaunchEvent>()) {}
Expand Down
5 changes: 5 additions & 0 deletions xla/stream_executor/host/host_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class HostKernel : public Kernel {
// `thread_dims` and calling the kernel function.
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const;
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const;

// Launches the kernel by iterating over all threads in `thread_dims` and
// calling `task_runner` to run individual task (implementation might decide
Expand All @@ -93,6 +95,9 @@ class HostKernel : public Kernel {
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const;
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const;

// For host platform, we assume that a core is a thread, and we can run at
// most one instance of a kernel on a given thread.
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/host/host_kernel_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame {
SE_HOST_KernelThread* thread;

size_t num_args;
SE_HOST_KernelArg* args;
const SE_HOST_KernelArg* args;
} SE_HOST_KernelCallFrame;

// Error reporting for host kernels. NULL means success.
Expand Down

0 comments on commit 48b640c

Please sign in to comment.