Introduce derived classes CudaKernel and RocmKernel
This change makes `GpuKernel` an abstract base class
and moves its implementation into the derived classes
`CudaKernel` and `RocmKernel`.

This avoids having two implementations of the same functions
and also reduces the exposure of gpu_types.h, which we want to
get rid of eventually.

I'm also adding some basic tests for the new classes.

PiperOrigin-RevId: 679003532
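
For illustration (not part of the commit), here is a minimal, self-contained C++ sketch of the pattern the change introduces: an abstract, platform-agnostic base class with the platform-specific state and overrides moved into a derived class. All type names with a "-Like" suffix are placeholders standing in for the real StreamExecutor classes shown in the diffs below.

// Toy illustration only; the "-Like" types are placeholders, not the real
// StreamExecutor classes.
#include <iostream>

using FunctionHandle = void*;  // stands in for CUfunction / hipFunction_t

class KernelLike {  // stands in for stream_executor::Kernel
 public:
  virtual ~KernelLike() = default;
  virtual unsigned Arity() const = 0;
};

class GpuKernelLike : public KernelLike {  // abstract, platform-agnostic base
 public:
  virtual FunctionHandle gpu_function() const = 0;
};

class CudaKernelLike : public GpuKernelLike {  // platform-specific derived class
 public:
  void set_arity(unsigned arity) { arity_ = arity; }
  unsigned Arity() const override { return arity_; }

  FunctionHandle gpu_function() const override { return function_; }
  void set_gpu_function(FunctionHandle function) { function_ = function; }

 private:
  FunctionHandle function_ = nullptr;  // wrapped platform kernel handle
  unsigned arity_ = 0;                 // number of formal kernel parameters
};

int main() {
  CudaKernelLike kernel;
  kernel.set_arity(3);
  std::cout << "arity = " << kernel.Arity() << "\n";  // prints "arity = 3"
  return 0;
}

A RocmKernel-style class would mirror CudaKernelLike, wrapping a hipFunction_t instead of a CUfunction.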
beckerhe authored and Google-ML-Automation committed Sep 26, 2024
1 parent 88648d3 commit 7e2ff2b
Showing 11 changed files with 334 additions and 51 deletions.
24 changes: 24 additions & 0 deletions xla/stream_executor/cuda/BUILD
@@ -527,12 +527,36 @@ cuda_only_cc_library(
cuda_only_cc_library(
name = "cuda_kernel",
srcs = ["cuda_kernel.cc"],
hdrs = ["cuda_kernel.h"],
deps = [
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_kernel_header",
"//xla/stream_executor/gpu:gpu_types_header",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:logging",
],
)

xla_test(
name = "cuda_kernel_test",
srcs = ["cuda_kernel_test.cc"],
backends = ["gpu_any"],
deps = [
":cuda_kernel",
":cuda_runtime",
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_test_kernels_cuda",
"@com_google_googletest//:gtest_main",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

3 changes: 2 additions & 1 deletion xla/stream_executor/cuda/cuda_executor.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_collectives.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_kernel.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/cuda/cuda_runtime.h"
#include "xla/stream_executor/cuda/cuda_status.h"
@@ -191,7 +192,7 @@ absl::Status CudaExecutor::LoadModuleFromHsaco(const char* hsaco,

absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto cuda_kernel = std::make_unique<GpuKernel>(this);
auto cuda_kernel = std::make_unique<CudaKernel>(this);
CUmodule module;
const std::string* kernel_name;

11 changes: 6 additions & 5 deletions xla/stream_executor/cuda/cuda_kernel.cc
@@ -13,28 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_kernel.h"

#include <cstddef>
#include <cstdint>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/launch_dim.h"

namespace stream_executor {
namespace gpu {

absl::StatusOr<int32_t> GpuKernel::GetMaxOccupiedBlocksPerCore(
absl::StatusOr<int32_t> CudaKernel::GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const {
int32_t threads_per_block = threads.x * threads.y * threads.z;
VLOG(3) << "Get kernel block occupancy: " << name()
<< "; threads_per_block: " << threads_per_block
<< "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes;

return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_,
threads_per_block,
dynamic_shared_memory_bytes);
return GpuDriver::GetMaxOccupiedBlocksPerCore(
gpu_executor_->gpu_context(), gpu_function_, threads_per_block,
dynamic_shared_memory_bytes);
}

} // namespace gpu
69 changes: 69 additions & 0 deletions xla/stream_executor/cuda/cuda_kernel.h
@@ -0,0 +1,69 @@
/* Copyright 2019 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA implementation of the StreamExecutor functionality.
// CUDA inclusions are ideally confined to this implementation file.
//
// The notions from the StreamExecutor basically correspond to the CUDA streams
// programming model provided by the libcuda.so driver APIs, so we don't have
// to do much more than wrap the calls to the libraries appropriately.
#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_

#include <cstddef>
#include <cstdint>

#include "absl/status/statusor.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/launch_dim.h"
#include "tsl/platform/logging.h"

namespace stream_executor::gpu {

class CudaKernel : public GpuKernel {
public:
explicit CudaKernel(GpuExecutor* gpu_executor)
: gpu_executor_(gpu_executor) {}

// Note that the function is unloaded when the module is unloaded, and the
// module that the function is contained in is owned by the GpuExecutor.
~CudaKernel() override { gpu_executor_->UnloadKernel(this); }

// As arity cannot be reflected upon using the CUDA API, the arity is
// explicitly set during the GpuExecutor::GetKernel initialization process.
void set_arity(unsigned arity) { arity_ = arity; }
unsigned Arity() const override { return arity_; }

absl::StatusOr<int32_t> GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const override;

// Simple accessor methods.
GpuFunctionHandle gpu_function() const override { return gpu_function_; }
void set_gpu_function(GpuFunctionHandle gpu_function) {
gpu_function_ = gpu_function;
}

private:
GpuExecutor* gpu_executor_ = nullptr;

CUfunction gpu_function_ = nullptr; // wrapped CUDA kernel handle
unsigned arity_ = 0; // number of formal parameters the kernel takes
};

} // namespace stream_executor::gpu

#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
60 changes: 60 additions & 0 deletions xla/stream_executor/cuda/cuda_kernel_test.cc
@@ -0,0 +1,60 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_kernel.h"

#include <gtest/gtest.h>
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_runtime.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace stream_executor::gpu {
namespace {
using testing::Ge;
using tsl::testing::IsOkAndHolds;

TEST(CudaKernelTest, GetMaxOccupiedBlocksPerCore) {
TF_ASSERT_OK_AND_ASSIGN(Platform * platform,
PlatformManager::PlatformWithName("CUDA"));
TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor,
platform->ExecutorForDevice(0));
GpuExecutor* gpu_executor = ExtractGpuExecutor(executor);

CudaKernel cuda_kernel(gpu_executor);
cuda_kernel.set_arity(3);

TF_ASSERT_OK_AND_ASSIGN(
CUfunction function,
CudaRuntime::GetFuncBySymbol(internal::GetAddI32Kernel()));

cuda_kernel.set_gpu_function(function);

EXPECT_EQ(cuda_kernel.Arity(), 3);
EXPECT_EQ(cuda_kernel.gpu_function(), function);

EXPECT_THAT(cuda_kernel.GetMaxOccupiedBlocksPerCore(
ThreadDim(1, 1, 1), /*dynamic_shared_memory_bytes=*/0),
IsOkAndHolds(Ge(1)));
}

} // namespace
} // namespace stream_executor::gpu
41 changes: 3 additions & 38 deletions xla/stream_executor/gpu/gpu_kernel.h
@@ -22,51 +22,16 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_
#define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "absl/status/statusor.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "tsl/platform/logging.h"

namespace stream_executor::gpu {

// A GpuKernel is a `Kernel` that can be launched on a GPU. It allows
// access to the underlying GPU function through `gpu_function()`.
class GpuKernel : public Kernel {
public:
explicit GpuKernel(GpuExecutor* gpu_executor)
: gpu_executor_(gpu_executor),
gpu_context_(gpu_executor->gpu_context()) {}

// Note that the function is unloaded when the module is unloaded, and the
// module that the function is contained in is owned by the GpuExecutor.
~GpuKernel() override { gpu_executor_->UnloadKernel(this); }

// As arity cannot be reflected upon using the CUDA API, the arity is
// explicitly set during the GpuExecutor::GetKernel initialization process.
void set_arity(unsigned arity) { arity_ = arity; }
unsigned Arity() const override { return arity_; }

absl::StatusOr<int32_t> GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const override;

// Simple accessor methods.
GpuFunctionHandle gpu_function() const { return gpu_function_; }
void set_gpu_function(GpuFunctionHandle gpu_function) {
gpu_function_ = gpu_function;
}

private:
GpuExecutor* gpu_executor_ = nullptr;
Context* gpu_context_ = nullptr; // context where kernel is loaded

GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle
unsigned arity_ = 0; // number of formal parameters the kernel takes
virtual GpuFunctionHandle gpu_function() const = 0;
};

inline const GpuKernel* AsGpuKernel(const Kernel* kernel) {
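
A hedged usage sketch (not part of the commit) of why `gpu_function()` stays on the abstract base: a call site holding only a `Kernel*` can still recover the platform handle through the existing `AsGpuKernel()` downcast helper, regardless of which derived class backs it. The helper `PlatformHandleOf` below is a hypothetical name introduced for illustration.

// Sketch only: recovering the platform handle via the abstract GpuKernel base.
// `PlatformHandleOf` is a hypothetical helper, not part of the commit.
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/kernel.h"

namespace stream_executor::gpu {

GpuFunctionHandle PlatformHandleOf(const Kernel* kernel) {
  // AsGpuKernel() remains in gpu_kernel.h; it downcasts Kernel* to GpuKernel*.
  const GpuKernel* gpu_kernel = AsGpuKernel(kernel);
  // Resolves to a CUfunction on CUDA and a hipFunction_t on ROCm.
  return gpu_kernel->gpu_function();
}

}  // namespace stream_executor::gpu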
29 changes: 28 additions & 1 deletion xla/stream_executor/rocm/BUILD
@@ -12,6 +12,7 @@ load(
"//xla/stream_executor:build_defs.bzl",
"stream_executor_friends",
)
load("//xla/tests:build_defs.bzl", "xla_test")
load(
"//xla/tsl:tsl.bzl",
"if_google",
@@ -219,6 +220,7 @@ cc_library(
cc_library(
name = "rocm_kernel",
srcs = ["rocm_kernel.cc"],
hdrs = ["rocm_kernel.h"],
tags = [
"gpu",
"rocm-only",
@@ -228,10 +230,35 @@
]),
visibility = ["//visibility:public"],
deps = [
"//xla/stream_executor:launch_dim",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_kernel_header",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:logging",
],
)

xla_test(
name = "rocm_kernel_test",
srcs = ["rocm_kernel_test.cc"],
backends = ["gpu_amd_any"],
deps = [
":rocm_kernel",
":rocm_runtime",
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_test_kernels",
"@com_google_googletest//:gtest_main",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
alwayslink = True,
)

cc_library(
4 changes: 3 additions & 1 deletion xla/stream_executor/rocm/rocm_executor.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_version.h"
#include "rocm/rocm_config.h"
#include "xla/stream_executor/blas.h"
@@ -71,6 +72,7 @@ limitations under the License.
#include "xla/stream_executor/rocm/rocm_driver.h"
#include "xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "xla/stream_executor/rocm/rocm_event.h"
#include "xla/stream_executor/rocm/rocm_kernel.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/rocm/rocm_runtime.h"
#include "xla/stream_executor/rocm/rocm_version_parser.h"
@@ -273,7 +275,7 @@ absl::Status RocmExecutor::Init() {

absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto rocm_kernel = std::make_unique<GpuKernel>(this);
auto rocm_kernel = std::make_unique<RocmKernel>(this);
hipModule_t module = nullptr;
const std::string* kernel_name;

15 changes: 10 additions & 5 deletions xla/stream_executor/rocm/rocm_kernel.cc
@@ -13,24 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/rocm/rocm_kernel.h"

#include <cstddef>
#include <cstdint>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/launch_dim.h"

namespace stream_executor {
namespace gpu {

absl::StatusOr<int32_t> GpuKernel::GetMaxOccupiedBlocksPerCore(
absl::StatusOr<int32_t> RocmKernel::GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const {
int32_t threads_per_block = threads.x * threads.y * threads.z;
VLOG(0) << "Get kernel block occupancy: " << name()
<< "; threads_per_block: " << threads_per_block
<< "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes;

return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_,
threads_per_block,
dynamic_shared_memory_bytes);
return GpuDriver::GetMaxOccupiedBlocksPerCore(
gpu_executor_->gpu_context(), rocm_function_, threads_per_block,
dynamic_shared_memory_bytes);
}

} // namespace gpu