Skip to content

Commit

Permalink
[xla:cpu] Optimize KernelThunk by passing SE_HOST_KernelArg directly to the kernel
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 646664927
  • Loading branch information
ezhulenev authored and copybara-github committed Jun 26, 2024
1 parent e317907 commit ec96f90
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) {

BENCHMARK(BM_SelectAndScatterF32)
->MeasureProcessCPUTime()
->Arg(64)
->Arg(128)
->Arg(256)
->Arg(512)
Expand Down
30 changes: 16 additions & 14 deletions xla/service/cpu/runtime/kernel_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,40 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
thread_dim_.ToString());

absl::InlinedVector<se::DeviceMemoryBase, 8> buffers_data;
buffers_data.reserve(arguments_buffers_.size() + results_buffers_.size());
absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args;
kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size());

int64_t arg_num = 0;
for (BufferAllocation::Slice& buffer : arguments_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()});
VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

int64_t res_num = 0;
for (BufferAllocation::Slice& buffer : results_buffers_) {
TF_ASSIGN_OR_RETURN(buffers_data.emplace_back(),
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data,
params.buffer_allocations->GetDeviceAddress(buffer));
kernel_args.push_back(
SE_HOST_KernelArg{result_data.opaque(), result_data.size()});
VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++,
buffer.ToString(),
buffers_data.back().opaque());
buffer.ToString(), kernel_args.back().data);
}

// Check that all buffers are aligned to the minimum alignment. We codegen
// with the assumption that all buffers are aligned, and if they are not, we
// will crash with a segmentation fault, or worse, produce incorrect results.
if (min_alignment_.has_value()) {
for (int64_t i = 0; i < buffers_data.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
for (int64_t i = 0; i < kernel_args.size(); ++i) {
auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
return Internal(
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
"required minimum alignment of %d bytes",
info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
info().op_name, i, kernel_args[i].data, *min_alignment_);
}
}
}
Expand All @@ -134,22 +136,22 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
params.host_kernels->Find(kernel_name_));

absl::MutexLock lock(&mutex_);
kernel_.emplace(buffers_data.size(), kernel_fn, nullptr);
kernel_.emplace(kernel_args.size(), kernel_fn, nullptr);
kernel_ptr_.store(kernel = &kernel_.value());
}

// If intra-op thread pool is not nullptr, we launch HostKernel in async mode
// by scheduling tasks into it. HostKernel launch completion will
// automatically signal KernelThunk execute completion.
if (ABSL_PREDICT_FALSE(params.intra_op_threadpool && use_task_runner_)) {
return kernel->Launch(thread_dim_, buffers_data,
return kernel->Launch(thread_dim_, kernel_args,
[&params](se::host::HostKernel::Task task) {
params.intra_op_threadpool->getPool()->Schedule(
ToCopyableTask(std::move(task)));
});
}

TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, buffers_data));
TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args));
return OkExecuteEvent();
}

Expand Down
4 changes: 2 additions & 2 deletions xla/service/cpu/runtime/kernel_thunk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class AddF32HostKernels : public Thunk::HostKernels {
public:
absl::StatusOr<SE_HOST_Kernel*> Find(std::string_view name) override {
return +[](const SE_HOST_KernelCallFrame* call_frame) {
SE_HOST_KernelArg& in = call_frame->args[0];
SE_HOST_KernelArg& out = call_frame->args[1];
const SE_HOST_KernelArg& in = call_frame->args[0];
const SE_HOST_KernelArg& out = call_frame->args[1];

float* in_ptr = reinterpret_cast<float*>(in.data);
float* out_ptr = reinterpret_cast<float*>(out.data);
Expand Down
33 changes: 24 additions & 9 deletions xla/stream_executor/host/host_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HostKernelExecuteState
HostKernelExecuteState(HostKernel::TaskRunner task_runner,
HostKernel::KernelFunction* function,
ThreadDim thread_dims,
absl::Span<const DeviceMemoryBase> buffers);
absl::Span<const SE_HOST_KernelArg> args);

// Notify of a completion of a host kernel task.
void Notify(absl::Status status);
Expand Down Expand Up @@ -118,11 +118,19 @@ HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel,
absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {thread_dims.x, thread_dims.y,
thread_dims.z};
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers));
}

absl::Status HostKernel::Launch(
const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const {
SE_HOST_KernelThreadDim kernel_thread_dims = {
thread_dims.x,
thread_dims.y,
thread_dims.z,
};

SE_HOST_Kernel* kernel = function_->kernel();
auto args = ConvertBuffersToKernelArgs(buffers);

for (uint64_t z = 0; z < thread_dims.z; ++z) {
for (uint64_t y = 0; y < thread_dims.y; ++y) {
Expand All @@ -134,7 +142,7 @@ absl::Status HostKernel::Launch(

SE_HOST_KernelError* error = (*kernel)(&call_frame);

if (error != nullptr) {
if (ABSL_PREDICT_FALSE(error != nullptr)) {
return absl::InternalError("Failed to call host kernel");
}
}
Expand All @@ -147,20 +155,27 @@ absl::Status HostKernel::Launch(
tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const {
return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers),
std::move(task_runner));
}

tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const {
size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z;
CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok

// Short-circuit launch with a single task and run it in the caller thread.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
absl::Status launched = Launch(thread_dims, buffers);
absl::Status launched = Launch(thread_dims, args);
return ABSL_PREDICT_TRUE(launched.ok())
? OkLaunchEvent()
: tsl::MakeErrorAsyncValueRef(std::move(launched));
}

// Allocate a control structure that will orchestrate kernel execution.
auto state = tsl::MakeRef<HostKernelExecuteState>(
std::move(task_runner), function_.get(), thread_dims, buffers);
std::move(task_runner), function_.get(), thread_dims, args);

state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks);

Expand All @@ -169,12 +184,12 @@ tsl::AsyncValueRef<LaunchEvent> HostKernel::Launch(

HostKernelExecuteState::HostKernelExecuteState(
HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function,
ThreadDim thread_dims, absl::Span<const DeviceMemoryBase> buffers)
ThreadDim thread_dims, absl::Span<const SE_HOST_KernelArg> args)
: task_runner_(std::move(task_runner)),
num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z),
kernel_(function->kernel()),
thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}),
args_(ConvertBuffersToKernelArgs(buffers)),
args_(args.begin(), args.end()),
abort_(false),
counter_(num_tasks_),
event_(tsl::MakeConstructedAsyncValueRef<LaunchEvent>()) {}
Expand Down
5 changes: 5 additions & 0 deletions xla/stream_executor/host/host_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class HostKernel : public Kernel {
// `thread_dims` and calling the kernel function.
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const DeviceMemoryBase> buffers) const;
absl::Status Launch(const ThreadDim& thread_dims,
absl::Span<const SE_HOST_KernelArg> args) const;

// Launches the kernel by iterating over all threads in `thread_dims` and
// calling `task_runner` to run individual task (implementation might decide
Expand All @@ -93,6 +95,9 @@ class HostKernel : public Kernel {
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const DeviceMemoryBase> buffers,
TaskRunner task_runner) const;
tsl::AsyncValueRef<LaunchEvent> Launch(
const ThreadDim& thread_dims, absl::Span<const SE_HOST_KernelArg> args,
TaskRunner task_runner) const;

// For host platform, we assume that a core is a thread, and we can run at
// most one instance of a kernel on a given thread.
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/host/host_kernel_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ typedef struct SE_HOST_KernelCallFrame {
SE_HOST_KernelThread* thread;

size_t num_args;
SE_HOST_KernelArg* args;
const SE_HOST_KernelArg* args;
} SE_HOST_KernelCallFrame;

// Error reporting for host kernels. NULL means success.
Expand Down
25 changes: 16 additions & 9 deletions xla/stream_executor/host/host_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ static auto ToCopyableTask(HostKernel::Task task) {
}

static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) {
SE_HOST_KernelArg& lhs = call_frame->args[0];
SE_HOST_KernelArg& rhs = call_frame->args[1];
SE_HOST_KernelArg& out = call_frame->args[2];
const SE_HOST_KernelArg& lhs = call_frame->args[0];
const SE_HOST_KernelArg& rhs = call_frame->args[1];
const SE_HOST_KernelArg& out = call_frame->args[2];

int32_t* lhs_ptr = reinterpret_cast<int32_t*>(lhs.data);
int32_t* rhs_ptr = reinterpret_cast<int32_t*>(rhs.data);
Expand Down Expand Up @@ -217,7 +217,9 @@ TEST(HostKernelTest, LaunchAsync) {
};

HostKernel host_kernel(/*arity=*/0, no_op);
auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner));
auto event = host_kernel.Launch(ThreadDim(4, 4, 4),
absl::Span<const SE_HOST_KernelArg>(),
std::move(runner));

tsl::BlockUntilReady(event);
EXPECT_TRUE(event.IsConcrete());
Expand Down Expand Up @@ -245,7 +247,9 @@ TEST(HostKernelTest, LaunchAsyncError) {
};

HostKernel host_kernel(/*arity=*/0, maybe_error);
auto event = host_kernel.Launch(ThreadDim(4, 4, 4), {}, std::move(runner));
auto event = host_kernel.Launch(ThreadDim(4, 4, 4),
absl::Span<const SE_HOST_KernelArg>(),
std::move(runner));

tsl::BlockUntilReady(event);
ASSERT_TRUE(event.IsError());
Expand All @@ -269,7 +273,8 @@ static void BM_HostKernelSyncLaunch(benchmark::State& state) {

HostKernel kernel(/*arity=*/0, NoOp);
for (auto _ : state) {
benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), /*buffers=*/{}));
benchmark::DoNotOptimize(kernel.Launch(
ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>()));
}
}

Expand All @@ -281,9 +286,11 @@ static void BM_HostKernelAsyncLaunch(benchmark::State& state) {

HostKernel kernel(/*arity=*/0, NoOp);
for (auto _ : state) {
auto event = kernel.Launch(ThreadDim(tdim_x), {}, [&](auto task) {
thread_pool->Schedule(ToCopyableTask(std::move(task)));
});
auto event =
kernel.Launch(ThreadDim(tdim_x), absl::Span<const SE_HOST_KernelArg>(),
[&](auto task) {
thread_pool->Schedule(ToCopyableTask(std::move(task)));
});
tsl::BlockUntilReady(event);
}
}
Expand Down

0 comments on commit ec96f90

Please sign in to comment.