diff --git a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index fdfdc7b00b1882..bbc32250444b0f 100644 --- a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -73,8 +73,10 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() + ->Arg(64) ->Arg(128) ->Arg(256) - ->Arg(512); + ->Arg(512) + ->Arg(1024); } // namespace xla::cpu diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 85d65c0ec323fd..a8d793d1076071 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "xla/runtime/buffer_use.h" @@ -86,67 +87,71 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector buffers_data; - buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size()); + absl::InlinedVector kernel_args; + kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < buffers_data.size(); ++i) { - auto ptr = reinterpret_cast(buffers_data[i].opaque()); + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, buffers_data[i].opaque(), *min_alignment_); + info().op_name, i, kernel_args[i].data, *min_alignment_); } } } // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk // initialization stage. - SE_HOST_Kernel* kernel_ptr = kernel_ptr_.load(); + se::host::HostKernel* kernel = kernel_ptr_.load(); // Because thunks are owned by a parent CpuExecutable, we can safely assume // that kernel pointer will not change after we find it the first time. - if (kernel_ptr == nullptr) { - TF_ASSIGN_OR_RETURN(kernel_ptr, params.host_kernels->Find(kernel_name_)); - kernel_ptr_.store(kernel_ptr); - } + if (ABSL_PREDICT_FALSE(kernel == nullptr)) { + TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn, + params.host_kernels->Find(kernel_name_)); - se::host::HostKernel kernel(buffers_data.size(), kernel_ptr, nullptr); + absl::MutexLock lock(&mutex_); + kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); + kernel_ptr_.store(kernel = &kernel_.value()); + } // If intra-op thread pool is not nullptr, we launch HostKernel in async mode // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. - if (params.intra_op_threadpool && use_task_runner_) { - return kernel.Launch(thread_dim_, buffers_data, - [¶ms](se::host::HostKernel::Task task) { - params.intra_op_threadpool->getPool()->Schedule( - ToCopyableTask(std::move(task))); - }); + if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { + return kernel->Launch(thread_dim_, kernel_args, + [¶ms](se::host::HostKernel::Task task) { + params.intra_op_threadpool->getPool()->Schedule( + ToCopyableTask(std::move(task))); + }); } - TF_RETURN_IF_ERROR(kernel.Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); return OkExecuteEvent(); } diff --git a/xla/service/cpu/runtime/kernel_thunk.h b/xla/service/cpu/runtime/kernel_thunk.h index 72cd1be097ac25..708f918d342c96 100644 --- a/xla/service/cpu/runtime/kernel_thunk.h +++ b/xla/service/cpu/runtime/kernel_thunk.h @@ -23,11 +23,13 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -64,12 +66,10 @@ class KernelThunk final : public Thunk { // launch the kernel directly in the caller thread. bool use_task_runner_; - // Pointer to the host kernel corresponding to `kernel_name_`. Initialized - // lazily at run time by looking it up in the HostKernels passed via params. - // - // TODO(ezhulenev): This should be moved to initialization stage when we'll - // have it for CPU thunks. - std::atomic kernel_ptr_; + // Lazily loaded host kernel corresponding to `kernel_name_`. + absl::Mutex mutex_; + std::optional kernel_ ABSL_GUARDED_BY(mutex_); + std::atomic kernel_ptr_; // pointer to `kernel_` }; } // namespace xla::cpu diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc index ceb148cdaaf918..04586b5272432b 100644 --- a/xla/stream_executor/host/host_kernel.cc +++ b/xla/stream_executor/host/host_kernel.cc @@ -69,7 +69,7 @@ class HostKernelExecuteState HostKernelExecuteState(HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, ThreadDim thread_dims, - absl::Span buffers); + absl::Span args); // Notify of a completion of a host kernel task. void Notify(absl::Status status); @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, absl::Status HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers) const { - SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, - thread_dims.z}; + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); +} + +absl::Status HostKernel::Launch( + const ThreadDim& thread_dims, + absl::Span args) const { + SE_HOST_KernelThreadDim kernel_thread_dims = { + thread_dims.x, + thread_dims.y, + thread_dims.z, + }; SE_HOST_Kernel* kernel = function_->kernel(); - auto args = ConvertBuffersToKernelArgs(buffers); for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelError* error = (*kernel)(&call_frame); - if (error != nullptr) { + if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); } } @@ -147,12 +155,19 @@ absl::Status HostKernel::Launch( tsl::AsyncValueRef HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const { + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), + std::move(task_runner)); +} + +tsl::AsyncValueRef HostKernel::Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, buffers); + absl::Status launched = Launch(thread_dims, args); return ABSL_PREDICT_TRUE(launched.ok()) ? OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -160,7 +175,7 @@ tsl::AsyncValueRef HostKernel::Launch( // Allocate a control structure that will orchestrate kernel execution. auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, buffers); + std::move(task_runner), function_.get(), thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -169,12 +184,12 @@ tsl::AsyncValueRef HostKernel::Launch( HostKernelExecuteState::HostKernelExecuteState( HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, - ThreadDim thread_dims, absl::Span buffers) + ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), kernel_(function->kernel()), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(ConvertBuffersToKernelArgs(buffers)), + args_(args.begin(), args.end()), abort_(false), counter_(num_tasks_), event_(tsl::MakeConstructedAsyncValueRef()) {} diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h index e8e040a2d86173..9d278b2b79c357 100644 --- a/xla/stream_executor/host/host_kernel.h +++ b/xla/stream_executor/host/host_kernel.h @@ -80,6 +80,8 @@ class HostKernel : public Kernel { // `thread_dims` and calling the kernel function. absl::Status Launch(const ThreadDim& thread_dims, absl::Span buffers) const; + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide @@ -93,6 +95,9 @@ class HostKernel : public Kernel { tsl::AsyncValueRef Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const; + tsl::AsyncValueRef Launch( + const ThreadDim& thread_dims, absl::Span args, + TaskRunner task_runner) const; // For host platform, we assume that a core is a thread, and we can run at // most one instance of a kernel on a given thread. diff --git a/xla/stream_executor/host/host_kernel_c_api.h b/xla/stream_executor/host/host_kernel_c_api.h index 6768706abc2800..30f710cb44b264 100644 --- a/xla/stream_executor/host/host_kernel_c_api.h +++ b/xla/stream_executor/host/host_kernel_c_api.h @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame { SE_HOST_KernelThread* thread; size_t num_args; - SE_HOST_KernelArg* args; + const SE_HOST_KernelArg* args; } SE_HOST_KernelCallFrame; // Error reporting for host kernels. NULL means success.