diff --git a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 7521dcda5b0f86..bbc32250444b0f 100644 --- a/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { BENCHMARK(BM_SelectAndScatterF32) ->MeasureProcessCPUTime() + ->Arg(64) ->Arg(128) ->Arg(256) ->Arg(512) diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 247c995649a525..a8d793d1076071 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,38 +87,40 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data; - buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size()); + absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args; + kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { - TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(), + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); + kernel_args.push_back( + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
- buffer.ToString(), - buffers_data.back().opaque()); + buffer.ToString(), kernel_args.back().data); } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < buffers_data.size(); ++i) { - auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque()); + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, buffers_data[i].opaque(), *min_alignment_); + info().op_name, i, kernel_args[i].data, *min_alignment_); } } } @@ -134,7 +136,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(buffers_data.size(), kernel_fn, nullptr); + kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } @@ -142,14 +144,14 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion.
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) { - return kernel->Launch(thread_dim_, buffers_data, + return kernel->Launch(thread_dim_, kernel_args, [&params](se::host::HostKernel::Task task) { params.intra_op_threadpool->getPool()->Schedule( ToCopyableTask(std::move(task))); }); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data)); + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); return OkExecuteEvent(); } diff --git a/xla/service/cpu/runtime/kernel_thunk_test.cc b/xla/service/cpu/runtime/kernel_thunk_test.cc index a80db35857e86f..9f6993dc9a73a7 100644 --- a/xla/service/cpu/runtime/kernel_thunk_test.cc +++ b/xla/service/cpu/runtime/kernel_thunk_test.cc @@ -40,8 +40,8 @@ class AddF32HostKernels : public Thunk::HostKernels { public: absl::StatusOr<SE_HOST_Kernel*> Find(std::string_view name) override { return +[](const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& in = call_frame->args[0]; - SE_HOST_KernelArg& out = call_frame->args[1]; + const SE_HOST_KernelArg& in = call_frame->args[0]; + const SE_HOST_KernelArg& out = call_frame->args[1]; float* in_ptr = reinterpret_cast<float*>(in.data); float* out_ptr = reinterpret_cast<float*>(out.data); diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc index ceb148cdaaf918..04586b5272432b 100644 --- a/xla/stream_executor/host/host_kernel.cc +++ b/xla/stream_executor/host/host_kernel.cc @@ -69,7 +69,7 @@ class HostKernelExecuteState HostKernelExecuteState(HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, ThreadDim thread_dims, - absl::Span<const DeviceMemoryBase> buffers); + absl::Span<const SE_HOST_KernelArg> args); // Notify of a completion of a host kernel task.
void Notify(absl::Status status); @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, absl::Status HostKernel::Launch( const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers) const { - SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y, - thread_dims.z}; + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); +} + +absl::Status HostKernel::Launch( + const ThreadDim& thread_dims, + absl::Span<const SE_HOST_KernelArg> args) const { + SE_HOST_KernelThreadDim kernel_thread_dims = { + thread_dims.x, + thread_dims.y, + thread_dims.z, + }; SE_HOST_Kernel* kernel = function_->kernel(); - auto args = ConvertBuffersToKernelArgs(buffers); for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelError* error = (*kernel)(&call_frame); - if (error != nullptr) { + if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); } } @@ -147,12 +155,19 @@ absl::Status HostKernel::Launch( tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch( const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers, TaskRunner task_runner) const { + return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), + std::move(task_runner)); +} + +tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch( + const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args, + TaskRunner task_runner) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, buffers); + absl::Status launched = Launch(thread_dims, args); return ABSL_PREDICT_TRUE(launched.ok()) ?
OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -160,7 +175,7 @@ tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch( // Allocate a control structure that will orchestrate kernel execution. auto state = tsl::MakeRef<HostKernelExecuteState>( - std::move(task_runner), function_.get(), thread_dims, buffers); + std::move(task_runner), function_.get(), thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -169,12 +184,12 @@ tsl::AsyncValueRef<HostKernel::LaunchEvent> HostKernel::Launch( HostKernelExecuteState::HostKernelExecuteState( HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, - ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers) + ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), kernel_(function->kernel()), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(ConvertBuffersToKernelArgs(buffers)), + args_(args.begin(), args.end()), abort_(false), counter_(num_tasks_), event_(tsl::MakeConstructedAsyncValueRef<HostKernel::LaunchEvent>()) {} diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h index e8e040a2d86173..9d278b2b79c357 100644 --- a/xla/stream_executor/host/host_kernel.h +++ b/xla/stream_executor/host/host_kernel.h @@ -80,6 +80,8 @@ class HostKernel : public Kernel { // `thread_dims` and calling the kernel function.
absl::Status Launch(const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers) const; + absl::Status Launch(const ThreadDim& thread_dims, + absl::Span<const SE_HOST_KernelArg> args) const; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide @@ -93,6 +95,9 @@ class HostKernel : public Kernel { tsl::AsyncValueRef<LaunchEvent> Launch( const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers, TaskRunner task_runner) const; + tsl::AsyncValueRef<LaunchEvent> Launch( + const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args, + TaskRunner task_runner) const; // For host platform, we assume that a core is a thread, and we can run at // most one instance of a kernel on a given thread. diff --git a/xla/stream_executor/host/host_kernel_c_api.h b/xla/stream_executor/host/host_kernel_c_api.h index 6768706abc2800..30f710cb44b264 100644 --- a/xla/stream_executor/host/host_kernel_c_api.h +++ b/xla/stream_executor/host/host_kernel_c_api.h @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame { SE_HOST_KernelThread* thread; size_t num_args; - SE_HOST_KernelArg* args; + const SE_HOST_KernelArg* args; } SE_HOST_KernelCallFrame; // Error reporting for host kernels. NULL means success.
diff --git a/xla/stream_executor/host/host_kernel_test.cc b/xla/stream_executor/host/host_kernel_test.cc index 5a121bf17cb5b7..aff9e1ed19ce7b 100644 --- a/xla/stream_executor/host/host_kernel_test.cc +++ b/xla/stream_executor/host/host_kernel_test.cc @@ -53,9 +53,9 @@ static auto ToCopyableTask(HostKernel::Task task) { } static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) { - SE_HOST_KernelArg& lhs = call_frame->args[0]; - SE_HOST_KernelArg& rhs = call_frame->args[1]; - SE_HOST_KernelArg& out = call_frame->args[2]; + const SE_HOST_KernelArg& lhs = call_frame->args[0]; + const SE_HOST_KernelArg& rhs = call_frame->args[1]; + const SE_HOST_KernelArg& out = call_frame->args[2]; int32_t* lhs_ptr = reinterpret_cast<int32_t*>(lhs.data); int32_t* rhs_ptr = reinterpret_cast<int32_t*>(rhs.data); @@ -217,7 +217,9 @@ TEST(HostKernelTest, LaunchAsync) { }; HostKernel host_kernel(/*arity=*/0, no_op); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span<const SE_HOST_KernelArg>(), + std::move(runner)); tsl::BlockUntilReady(event); EXPECT_TRUE(event.IsConcrete()); @@ -245,7 +247,9 @@ TEST(HostKernelTest, LaunchAsyncError) { }; HostKernel host_kernel(/*arity=*/0, maybe_error); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner)); + auto event = host_kernel.Launch(ThreadDim(4, 4, 4), + absl::Span<const SE_HOST_KernelArg>(), + std::move(runner)); tsl::BlockUntilReady(event); ASSERT_TRUE(event.IsError()); @@ -269,7 +273,8 @@ static void BM_HostKernelSyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), /*buffers=*/{})); + benchmark::DoNotOptimize(kernel.Launch( + ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>())); } } @@ -281,9 +286,11 @@ static void BM_HostKernelAsyncLaunch(benchmark::State& state) { HostKernel kernel(/*arity=*/0, NoOp); for (auto _ : state) { - auto event = kernel.Launch(ThreadDim(tdim_x), {}, [&](auto task)
{ - thread_pool->Schedule(ToCopyableTask(std::move(task))); - }); + auto event = + kernel.Launch(ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>(), + [&](auto task) { + thread_pool->Schedule(ToCopyableTask(std::move(task))); + }); tsl::BlockUntilReady(event); } }