From 48b640c0abf1cebde8bb0106151101179bb327ba Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jun 2024 20:20:46 -0700 Subject: [PATCH] [xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel PiperOrigin-RevId: 646311089 --- .../select_and_scatter_benchmark_test.cc | 1 + xla/service/cpu/runtime/kernel_thunk.cc | 30 +++++++++-------- xla/stream_executor/host/host_kernel.cc | 33 ++++++++++++++----- xla/stream_executor/host/host_kernel.h | 5 +++ xla/stream_executor/host/host_kernel_c_api.h | 2 +- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 7521dcda5b0f86..bbc32250444b0f 100644 --- a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() + ->Arg(64) ->Arg(128) ->Arg(256) ->Arg(512) diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 247c995649a525..a8d793d1076071 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,38 +87,40 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector buffers_data; - buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size()); + absl::InlinedVector kernel_args; + kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < buffers_data.size(); ++i) { - auto ptr = reinterpret_cast(buffers_data[i].opaque()); + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, buffers_data[i].opaque(), *min_alignment_); + info().op_name, i, kernel_args[i].data, *min_alignment_); } } } @@ -134,7 +136,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } @@ -142,14 +144,14 @@ tsl::AsyncValueRef KernelThunk::Execute( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { - return kernel->Launch(thread_dim_, buffers_data, + return kernel->Launch(thread_dim_, kernel_args, [¶ms](se::host::HostKernel::Task task) { params.intra_op_threadpool->getPool()->Schedule( ToCopyableTask(std::move(task))); }); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); return OkExecuteEvent(); } diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc index ceb148cdaaf918..04586b5272432b 100644 --- a/xla/stream_executor/host/host_kernel.cc +++ b/xla/stream_executor/host/host_kernel.cc @@ -69,7 +69,7 @@ class HostKernelExecuteState HostKernelExecuteState(HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, ThreadDim thread_dims, - absl::Span buffers); + absl::Span args); // Notify of a completion of a host kernel task. void Notify(absl::Status status); @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, absl::Status HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers) const { - SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, - thread_dims.z}; + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); +} + +absl::Status HostKernel::Launch( + const ThreadDim& thread_dims, + absl::Span args) const { + SE_HOST_KernelThreadDim kernel_thread_dims = { + thread_dims.x, + thread_dims.y, + thread_dims.z, + }; SE_HOST_Kernel* kernel = function_->kernel(); - auto args = ConvertBuffersToKernelArgs(buffers); for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelError* error = (*kernel)(&call_frame); - if (error != nullptr) { + if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); } } @@ -147,12 +155,19 @@ absl::Status HostKernel::Launch( tsl::AsyncValueRef HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const { + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), + std::move(task_runner)); +} + +tsl::AsyncValueRef HostKernel::Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, buffers); + absl::Status launched = Launch(thread_dims, args); return ABSL_PREDICT_TRUE(launched.ok()) ? OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -160,7 +175,7 @@ tsl::AsyncValueRef HostKernel::Launch( // Allocate a control structure that will orchestrate kernel execution. auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, buffers); + std::move(task_runner), function_.get(), thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -169,12 +184,12 @@ tsl::AsyncValueRef HostKernel::Launch( HostKernelExecuteState::HostKernelExecuteState( HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, - ThreadDim thread_dims, absl::Span buffers) + ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), kernel_(function->kernel()), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(ConvertBuffersToKernelArgs(buffers)), + args_(args.begin(), args.end()), abort_(false), counter_(num_tasks_), event_(tsl::MakeConstructedAsyncValueRef()) {} diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h index e8e040a2d86173..9d278b2b79c357 100644 --- a/xla/stream_executor/host/host_kernel.h +++ b/xla/stream_executor/host/host_kernel.h @@ -80,6 +80,8 @@ class HostKernel : public Kernel { // `thread_dims` and calling the kernel function. absl::Status Launch(const ThreadDim& thread_dims, absl::Span buffers) const; + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide @@ -93,6 +95,9 @@ class HostKernel : public Kernel { tsl::AsyncValueRef Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const; + tsl::AsyncValueRef Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const; // For host platform, we assume that a core is a thread, and we can run at // most one instance of a kernel on a given thread. diff --git a/xla/stream_executor/host/host_kernel_c_api.h b/xla/stream_executor/host/host_kernel_c_api.h index 6768706abc2800..30f710cb44b264 100644 --- a/xla/stream_executor/host/host_kernel_c_api.h +++ b/xla/stream_executor/host/host_kernel_c_api.h @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame { SE_HOST_KernelThread* thread; size_t num_args; - SE_HOST_KernelArg* args; + const SE_HOST_KernelArg* args; } SE_HOST_KernelCallFrame; // Error reporting for host kernels. NULL means success.