Skip to content

Commit

Permalink
Move GpuDriver::GetDeviceCount functionality into the appropriate Pla…
Browse files Browse the repository at this point in the history
…tform.

PiperOrigin-RevId: 695492725
  • Loading branch information
klucke authored and Google-ML-Automation committed Nov 11, 2024
1 parent 1067d1a commit 0fd7aac
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 71 deletions.
1 change: 0 additions & 1 deletion xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ cc_library(
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_diagnostics_header",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/platform:initialize",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
Expand Down
11 changes: 0 additions & 11 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

int GpuDriver::GetDeviceCount() {
int device_count = 0;
auto status = cuda::ToStatus(cuDeviceGetCount(&device_count));
if (!status.ok()) {
LOG(ERROR) << "could not retrieve CUDA device count: " << status;
return 0;
}

return device_count;
}

absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
int32_t version;
TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&version),
Expand Down
10 changes: 8 additions & 2 deletions xla/stream_executor/cuda/cuda_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_diagnostics.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
Expand Down Expand Up @@ -74,7 +73,14 @@ int CudaPlatform::VisibleDeviceCount() const {
// Initialized in a thread-safe manner the first time this is run.
static const int num_devices = [] {
if (!PlatformInitialize().ok()) return -1;
return GpuDriver::GetDeviceCount();
int device_count = 0;
auto status = cuda::ToStatus(cuDeviceGetCount(&device_count));
if (!status.ok()) {
LOG(ERROR) << "could not retrieve CUDA device count: " << status;
return 0;
}

return device_count;
}();
return num_devices;
}
Expand Down
42 changes: 0 additions & 42 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,9 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
#define XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_

#include <stddef.h>

#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor {
namespace gpu {
Expand All @@ -57,35 +44,6 @@ namespace gpu {
// Thread safety: these functions should not be used from signal handlers.
class GpuDriver {
public:
// The CUDA stream callback type signature.
// The data passed to AddStreamCallback is subsequently passed to this
// callback when it fires.
//
// Some notable things:
// * Callbacks must not make any CUDA API calls.
// * Callbacks from independent streams execute in an undefined order and may
// be serialized.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gab95a78143bae7f21eebb978f91e7f3f
typedef void (*StreamCallback)(void* data);

// Blocks the calling thread until the operations enqueued onto stream have
// been completed, via cuStreamSynchronize.
//
// TODO(leary) if a pathological thread enqueues operations onto the stream
// while another thread blocks like this, can you wind up waiting an unbounded
// amount of time?
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
static absl::Status SynchronizeStream(Context* context,
GpuStreamHandle stream);

// -- Context- and device-independent calls.

// Returns the number of visible CUDA device via cuDeviceGetCount.
// This should correspond to the set of device ordinals available.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74
static int GetDeviceCount();

// Returns the driver version number via cuDriverGetVersion.
// This is, surprisingly, NOT the actual driver version (e.g. 331.79) but,
// instead, the CUDA toolkit release number that this driver is compatible
Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,6 @@ cc_library(
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_diagnostics_header",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/platform:initialize",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
Expand Down
11 changes: 0 additions & 11 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,6 @@ limitations under the License.

namespace stream_executor::gpu {

int GpuDriver::GetDeviceCount() {
int device_count = 0;
hipError_t res = wrap::hipGetDeviceCount(&device_count);
if (res != hipSuccess) {
LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
return 0;
}

return device_count;
}

absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
int32_t version;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version),
Expand Down
10 changes: 8 additions & 2 deletions xla/stream_executor/rocm/rocm_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_diagnostics.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/platform_manager.h"
Expand Down Expand Up @@ -76,7 +75,14 @@ int ROCmPlatform::VisibleDeviceCount() const {
return -1;
}

return GpuDriver::GetDeviceCount();
int device_count = 0;
hipError_t res = wrap::hipGetDeviceCount(&device_count);
if (res != hipSuccess) {
LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
return 0;
}

return device_count;
}

const std::string& ROCmPlatform::Name() const { return name_; }
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/sycl/sycl_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; }

int SyclPlatform::VisibleDeviceCount() const {
// Initialized in a thread-safe manner the first time this is run.
static const int num_devices = [] { return GpuDriver::GetDeviceCount(); }();
static const int num_devices = [] { return 0; }();
return num_devices;
}

Expand Down

0 comments on commit 0fd7aac

Please sign in to comment.