Pull cuda command buffer kernel definitions behind methods that return MultiKernelLoaderSpec.

PiperOrigin-RevId: 671492158
IllogicalMoose authored and copybara-github committed Sep 5, 2024
1 parent a5bc8f6 commit dfefe81
Showing 9 changed files with 158 additions and 140 deletions.
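Before the per-file diffs, a brief sketch (not part of the commit) of the interface change being made: backends no longer expose raw PTX strings through std::string_view accessors that every call site wraps into a spec by hand; instead each backend returns a ready-made MultiKernelLoaderSpec. The "before" lines are reconstructed from the code this commit removes; the "after" declaration is taken verbatim from the diff below.

// Sketch of the API change; assumes the declarations live at gpu scope as in
// this commit.
#include <string_view>

#include "absl/status/statusor.h"
#include "xla/stream_executor/kernel_spec.h"

namespace stream_executor::gpu {

// Before: the backend exposed raw PTX, and each call site built its own spec.
//   std::string_view GetSetIfConditionKernel();
//   MultiKernelLoaderSpec spec(/*arity=*/2);
//   spec.AddCudaPtxInMemory(GetSetIfConditionKernel(), "set_if_condition");

// After: the backend owns both the PTX and the spec construction, and can
// report a missing kernel through the status instead of a preprocessor guard.
absl::StatusOr<MultiKernelLoaderSpec> GetSetIfConditionKernelLoaderSpec();

}  // namespace stream_executor::gpu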
8 changes: 6 additions & 2 deletions xla/stream_executor/cuda/BUILD
@@ -524,8 +524,12 @@ cuda_only_cc_library(
)

cc_library(
name = "cuda_conditional_kernels",
srcs = ["cuda_conditional_kernels.cc"],
name = "command_buffer_kernels",
srcs = ["command_buffer_kernels.cc"],
deps = [
"//xla/stream_executor:kernel_spec",
"@com_google_absl//absl/status:statusor",
],
)

# TODO(leary) we likely need to canonicalize/eliminate this.
xla/stream_executor/cuda/{cuda_conditional_kernels.cc → command_buffer_kernels.cc}
@@ -15,7 +15,12 @@ limitations under the License.

#include <string_view>

namespace stream_executor::gpu {
#include "absl/status/statusor.h"
#include "xla/stream_executor/kernel_spec.h"

namespace stream_executor {
namespace cuda {
namespace {

// Collection of helper kernels required by command buffers on CUDA backends. We
// use pre-compiled PTX instead of a CUDA C++ because conditional nodes require
@@ -41,8 +46,7 @@ namespace stream_executor::gpu {
// }
//
// Easiest way to get PTX from C++ is to use https://godbolt.org.
std::string_view GetSetIfConditionKernel() {
return R"(
inline constexpr std::string_view kSetIfConditionKernel = R"(
.version 4.0
.target sm_50
.address_size 64
@@ -108,7 +112,6 @@ std::string_view GetSetIfConditionKernel() {
ret;
})";
}

// PTX kernel compiled from:
//
@@ -125,8 +128,7 @@ std::string_view GetSetIfConditionKernel() {
// }
//
// Easiest way to get PTX from C++ is to use https://godbolt.org.
std::string_view GetSetIfElseConditionKernel() {
return R"(
inline constexpr std::string_view kSetIfElseConditionKernel = R"(
.version 4.0
.target sm_50
.address_size 64
@@ -222,7 +224,6 @@ std::string_view GetSetIfElseConditionKernel() {
ret;
})";
}

// PTX kernel compiled from:
//
@@ -257,8 +258,7 @@ std::string_view GetSetIfElseConditionKernel() {
// }
//
// Easiest way to get PTX from C++ is to use https://godbolt.org.
std::string_view GetSetCaseConditionKernel() {
return R"(
inline constexpr std::string_view kSetCaseConditionKernel = R"(
.version 4.0
.target sm_50
.address_size 64
@@ -578,7 +578,6 @@ std::string_view GetSetCaseConditionKernel() {
ret;
})";
}

// PTX kernel compiled from:
//
@@ -594,8 +593,7 @@ std::string_view GetSetCaseConditionKernel() {
// }
//
// Easiest way to get PTX from C++ is to use https://godbolt.org.
std::string_view GetSetForConditionKernel() {
return R"(
inline constexpr std::string_view kSetForConditionKernel = R"(
.version 4.0
.target sm_50
.address_size 64
@@ -669,11 +667,9 @@ std::string_view GetSetForConditionKernel() {
ret;
})";
}

std::string_view GetSetWhileConditionKernel() {
// While condition kernel is the same as an `If` with a single branch.
return R"(
// While condition kernel is the same as an `If` with a single branch.
inline constexpr std::string_view kSetWhileConditionKernel = R"(
.version 4.0
.target sm_50
.address_size 64
@@ -739,6 +735,69 @@ std::string_view GetSetWhileConditionKernel() {
ret;
})";

// PTX kernel compiled from:
//
// __global__ void noop() {}
//
// Easiest way to get PTX from C++ is to use https://godbolt.org.
inline constexpr std::string_view kNoOpKernel = R"(
.version 4.0
.target sm_50
.address_size 64
.visible .entry noop()
{
.loc 1 1 0
.loc 1 4 1
ret;
})";

} // namespace
} // namespace cuda

namespace gpu {

absl::StatusOr<MultiKernelLoaderSpec> GetSetIfConditionKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/2);
  spec.AddCudaPtxInMemory(cuda::kSetIfConditionKernel, "set_if_condition");
  return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetIfElseConditionKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/3);
  spec.AddCudaPtxInMemory(cuda::kSetIfElseConditionKernel,
                          "set_if_else_condition");
  return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetCaseConditionKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/10);
  spec.AddCudaPtxInMemory(cuda::kSetCaseConditionKernel, "set_case_condition");
  return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetForConditionKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/3);
  spec.AddCudaPtxInMemory(cuda::kSetForConditionKernel, "set_for_condition");
  return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetSetWhileConditionKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/2);
  spec.AddCudaPtxInMemory(cuda::kSetWhileConditionKernel,
                          "set_while_condition");
  return spec;
}

absl::StatusOr<MultiKernelLoaderSpec> GetNoOpKernelLoaderSpec() {
  MultiKernelLoaderSpec spec(/*arity=*/0);
  spec.AddCudaPtxInMemory(cuda::kNoOpKernel, "noop");
  return spec;
}

} // namespace stream_executor::gpu
} // namespace gpu
} // namespace stream_executor
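Each PTX literal in this file is compiled offline from a small CUDA C++ kernel; the comment above each literal records the source (most of those comments are collapsed in this view), and godbolt.org is suggested as the easiest way to regenerate the PTX. For illustration only — the exact sources live in the collapsed comments, so the signature here is an assumption — a condition-setting kernel of arity 2, matching GetSetIfConditionKernelLoaderSpec, looks roughly like:

// Illustrative sketch, not the committed source: sets a CUDA graph conditional
// from a device-memory predicate. Conditional graph nodes and
// cudaGraphSetConditional require CUDA 12.3+.
#include <cuda_runtime.h>

__global__ void SetIfCondition(cudaGraphConditionalHandle then_handle,
                               bool* predicate) {
  if (*predicate) {
    cudaGraphSetConditional(then_handle, 1);  // enable the "then" branch
  } else {
    cudaGraphSetConditional(then_handle, 0);  // leave it disabled
  }
}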
10 changes: 2 additions & 8 deletions xla/stream_executor/gpu/BUILD
@@ -198,11 +198,6 @@ gpu_only_cc_library(
],
)

gpu_only_cc_library(
name = "gpu_kernels",
hdrs = ["gpu_kernels.h"],
)

gpu_only_cc_library(
name = "gpu_command_buffer",
srcs = ["gpu_command_buffer.cc"],
@@ -214,7 +209,6 @@ gpu_only_cc_library(
":gpu_driver_header",
":gpu_executor_header",
":gpu_kernel_header",
":gpu_kernels",
":gpu_stream",
":gpu_types_header",
"//xla/stream_executor",
@@ -240,9 +234,9 @@ gpu_only_cc_library(
"@tsl//tsl/platform:statusor",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor/cuda:cuda_conditional_kernels",
"//xla/stream_executor/cuda:command_buffer_kernels",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:hip_conditional_kernels",
"//xla/stream_executor/rocm:command_buffer_kernels",
]),
)

43 changes: 21 additions & 22 deletions xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -42,7 +42,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_kernels.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/kernel.h"
@@ -58,6 +57,21 @@ limitations under the License.

namespace stream_executor::gpu {

//===----------------------------------------------------------------------===//
// Implementation details device kernels required by GpuCommandBuffer.
//===----------------------------------------------------------------------===//

// See device specific implementations. These are
// various kernels that update Gpu conditionals based on the device memory
// values, and allow implementing on-device control flow via conditional command
// buffers.
absl::StatusOr<MultiKernelLoaderSpec> GetSetIfConditionKernelLoaderSpec();
absl::StatusOr<MultiKernelLoaderSpec> GetSetIfElseConditionKernelLoaderSpec();
absl::StatusOr<MultiKernelLoaderSpec> GetSetCaseConditionKernelLoaderSpec();
absl::StatusOr<MultiKernelLoaderSpec> GetSetForConditionKernelLoaderSpec();
absl::StatusOr<MultiKernelLoaderSpec> GetSetWhileConditionKernelLoaderSpec();
absl::StatusOr<MultiKernelLoaderSpec> GetNoOpKernelLoaderSpec();

using Mode = CommandBuffer::Mode;
using State = CommandBuffer::State;

@@ -215,8 +229,7 @@ GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier(
absl::StatusOr<GpuCommandBuffer::SetIfConditionKernel*>
GpuCommandBuffer::GetSetIfConditionKernel() {
if (!set_if_condition_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/2);
spec.AddCudaPtxInMemory(gpu::GetSetIfConditionKernel(), "set_if_condition");
TF_ASSIGN_OR_RETURN(auto spec, GetSetIfConditionKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(
set_if_condition_kernel_,
SetIfConditionKernel::FactoryType::Create(parent_, spec));
@@ -227,9 +240,7 @@ GpuCommandBuffer::GetSetIfConditionKernel() {
absl::StatusOr<GpuCommandBuffer::SetIfElseConditionKernel*>
GpuCommandBuffer::GetSetIfElseConditionKernel() {
if (!set_if_else_condition_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/3);
spec.AddCudaPtxInMemory(gpu::GetSetIfElseConditionKernel(),
"set_if_else_condition");
TF_ASSIGN_OR_RETURN(auto spec, GetSetIfElseConditionKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(
set_if_else_condition_kernel_,
SetIfElseConditionKernel::FactoryType::Create(parent_, spec));
@@ -240,9 +251,7 @@ GpuCommandBuffer::GetSetIfElseConditionKernel() {
absl::StatusOr<GpuCommandBuffer::SetCaseConditionKernel*>
GpuCommandBuffer::GetSetCaseConditionKernel() {
if (!set_case_condition_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/10);
spec.AddCudaPtxInMemory(gpu::GetSetCaseConditionKernel(),
"set_case_condition");
TF_ASSIGN_OR_RETURN(auto spec, GetSetCaseConditionKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(
set_case_condition_kernel_,
SetCaseConditionKernel::FactoryType::Create(parent_, spec));
@@ -253,9 +262,7 @@ GpuCommandBuffer::GetSetCaseConditionKernel() {
absl::StatusOr<GpuCommandBuffer::SetForConditionKernel*>
GpuCommandBuffer::GetSetForConditionKernel() {
if (!set_for_condition_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/3);
spec.AddCudaPtxInMemory(gpu::GetSetForConditionKernel(),
"set_for_condition");
TF_ASSIGN_OR_RETURN(auto spec, GetSetForConditionKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(
set_for_condition_kernel_,
SetForConditionKernel::FactoryType::Create(parent_, spec));
@@ -266,9 +273,7 @@ GpuCommandBuffer::GetSetForConditionKernel() {
absl::StatusOr<GpuCommandBuffer::SetWhileConditionKernel*>
GpuCommandBuffer::GetSetWhileConditionKernel() {
if (!set_while_condition_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/2);
spec.AddCudaPtxInMemory(gpu::GetSetWhileConditionKernel(),
"set_while_condition");
TF_ASSIGN_OR_RETURN(auto spec, GetSetWhileConditionKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(
set_while_condition_kernel_,
SetWhileConditionKernel::FactoryType::Create(parent_, spec));
@@ -278,18 +283,12 @@ GpuCommandBuffer::GetSetWhileConditionKernel() {

absl::StatusOr<GpuCommandBuffer::NoOpKernel*>
GpuCommandBuffer::GetNoOpKernel() {
#if !defined(TENSORFLOW_USE_ROCM)
if (!noop_kernel_) {
MultiKernelLoaderSpec spec(/*arity=*/0);
spec.AddCudaPtxInMemory(gpu::kNoOpKernel, "noop");
TF_ASSIGN_OR_RETURN(auto spec, GetNoOpKernelLoaderSpec());
TF_ASSIGN_OR_RETURN(noop_kernel_,
NoOpKernel::FactoryType::Create(parent_, spec));
}
return &noop_kernel_;
#else
return absl::UnimplementedError(
"GpuCommandBuffer::GetNoOpKernel is not implemented.");
#endif // TENSORFLOW_USE_ROCM
}

absl::Status GpuCommandBuffer::DisableBarriersExecution(
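One consequence of the loader-spec functions returning absl::StatusOr, visible in the GetNoOpKernel change above: the TENSORFLOW_USE_ROCM preprocessor guard is gone, and a backend that lacks a particular kernel can instead report that at runtime. A hypothetical minimal implementation for such a backend (not part of this commit) might be:

// Hypothetical sketch: a backend without a no-op kernel surfaces the error
// through the returned status rather than an #if guard in GpuCommandBuffer.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/stream_executor/kernel_spec.h"

namespace stream_executor::gpu {

absl::StatusOr<MultiKernelLoaderSpec> GetNoOpKernelLoaderSpec() {
  return absl::UnimplementedError(
      "No-op kernel is not implemented on this backend.");
}

}  // namespace stream_executor::gpu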
14 changes: 0 additions & 14 deletions xla/stream_executor/gpu/gpu_command_buffer.h
@@ -353,20 +353,6 @@ class GpuCommandBuffer : public CommandBuffer {
NoOpKernel noop_kernel_;
};

//===----------------------------------------------------------------------===//
// Implementation details device kernels required by GpuCommandBuffer.
//===----------------------------------------------------------------------===//

// See `cuda_conditional_kernels.cc` for CUDA implementation. These are
// various kernels that update Gpu conditionals based on the device memory
// values, and allow implementing on-device control flow via conditional command
// buffers.
std::string_view GetSetIfConditionKernel();
std::string_view GetSetIfElseConditionKernel();
std::string_view GetSetCaseConditionKernel();
std::string_view GetSetForConditionKernel();
std::string_view GetSetWhileConditionKernel();

} // namespace stream_executor::gpu

#endif // XLA_STREAM_EXECUTOR_GPU_GPU_COMMAND_BUFFER_H_
47 changes: 0 additions & 47 deletions xla/stream_executor/gpu/gpu_kernels.h

This file was deleted.

