Skip to content

Commit

Permalink
[xla:cpu] NFC: Micro-optimizations for KernelThunk
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646316826
  • Loading branch information
ezhulenev authored and copybara-github committed Jun 26, 2024
1 parent 6a009c8 commit cbab42f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32)
->Arg(64)
->Arg(128)
->Arg(256)
->Arg(512)
->Arg(1024);
->Arg(512);

} // namespace xla::cpu
30 changes: 18 additions & 12 deletions xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,46 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());
int64_t num_args = arguments_buffers_.size() + results_buffers_.size();
absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args(num_args);

// We initialize `kernel_args` array using pointer to the first argument,
// because individual elements access adds up measurable overhead, and this
// code is on the critical path.
SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data();
int64_t kernel_arg_idx = 0;

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(), kernel_args.back().data);
buffer.ToString(), arg_data.opaque());
kernel_args_ptr[kernel_arg_idx++] =
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()};
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
buffer.ToString(), kernel_args.back().data);
buffer.ToString(), result_data.opaque());
kernel_args_ptr[kernel_arg_idx++] =
SE_HOST_KernelArg{result_data.opaque(), result_data.size()};
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
for (int64_t i = 0; i < num_args; ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args_ptr[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, kernel_args[i].data, *min_alignment_);
info().op_name, i, kernel_args_ptr[i].data, *min_alignment_);
}
}
}
Expand All @@ -136,7 +142,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
params.host_kernels->Find(kernel_name_));

absl::MutexLock lock(&mutex_);
kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
kernel_.emplace(num_args, kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}

Expand Down
16 changes: 7 additions & 9 deletions xla/stream_executor/host/host_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ class HostKernelExecuteState
: public tsl::ReferenceCounted<HostKernelExecuteState> {
public:
HostKernelExecuteState(HostKernel::TaskRunner task_runner,
HostKernel::KernelFunction* function,
ThreadDim thread_dims,
SE_HOST_Kernel* kernel, ThreadDim thread_dims,
absl::Span<const SE_HOST_KernelArg> args);

// Notify of a completion of a host kernel task.
Expand Down Expand Up @@ -112,6 +111,7 @@ HostKernel::HostKernel(std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
std::shared_ptr<tsl::thread::ThreadPool> thread_pool)
: function_(std::make_unique<KernelFunctionPtr>(kernel)),
kernel_(function_->kernel()),
arity_(arity),
thread_pool_(thread_pool) {}

Expand All @@ -130,8 +130,6 @@ absl::Status HostKernel::Launch(
thread_dims.z,
};

SE_HOST_Kernel* kernel = function_->kernel();

for (uint64_t z = 0; z < thread_dims.z; ++z) {
for (uint64_t y = 0; y < thread_dims.y; ++y) {
for (uint64_t x = 0; x < thread_dims.x; ++x) {
Expand All @@ -140,7 +138,7 @@ absl::Status HostKernel::Launch(
SE_HOST_KernelCallFrame call_frame = {
&kernel_thread_dims, &kernel_thread, args.size(), args.data()};

SE_HOST_KernelError* error = (*kernel)(&call_frame);
SE_HOST_KernelError* error = (*kernel_)(&call_frame);

if (ABSL_PREDICT_FALSE(error != nullptr)) {
return absl::InternalError("Failed to call host kernel");
Expand Down Expand Up @@ -174,20 +172,20 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
}

// Allocate a control structure that will orchestrate kernel execution.
auto state = tsl::MakeRef<HostKernelExecuteState>(
std::move(task_runner), function_.get(), thread_dims, args);
auto state = tsl::MakeRef<HostKernelExecuteState>(std::move(task_runner),
kernel_, thread_dims, args);

state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);

return state->event();
}

HostKernelExecuteState::HostKernelExecuteState(
HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
HostKernel::TaskRunner task_runner, SE_HOST_Kernel kernel,
ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
: task_runner_(std::move(task_runner)),
num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
kernel_(function->kernel()),
kernel_(kernel),
thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
args_(args.begin(), args.end()),
abort_(false),
Expand Down
2 changes: 2 additions & 0 deletions xla/stream_executor/host/host_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ class HostKernel : public Kernel {
std::enable_if_t<std::is_base_of_v<KernelFunction, T>>* = nullptr>
void SetKernelFunction(std::unique_ptr<T> function) {
function_ = std::move(function);
kernel_ = function_->kernel();
}

private:
std::unique_ptr<KernelFunction> function_;
SE_HOST_Kernel* kernel_; // pointer to the kernel owned by `function_`

unsigned arity_;
std::shared_ptr<tsl::thread::ThreadPool> thread_pool_;
Expand Down

0 comments on commit cbab42f

Please sign in to comment.