Skip to content

Commit

Permalink
PR #16520: [ROCM] ResetStream function for GemmAlgorithmPicker (BlasS…
Browse files Browse the repository at this point in the history
…upport interface)

Imported from GitHub PR #16520

Here I added **ResetStream** function which sets the underlying stream for cublas/rocblas libraries to default stream 0.

This is useful for GemmAlgorithmPicker which uses a temporary stream object for autotuning. In rocblas, **rocblas_set_stream** function is **persistent**, meaning that once the stream value is set, it will be used in all subsequent computations until new stream value is set.

In case of GemmAlgorithmPicker, we leave a **destroyed** stream object set into the math library. This does not produce any error behaviour but merely just a warning on ROCM side: "Stream Capture Check Failed".

With this new ResetStream function, one can reset the stream value in GemmAlgorithmPicker destructor. Potentially, it can also be useful in other places where temporary stream value is used.

Besides, I have also made some small code restructure for GemmAlgorithmPicker

@xla-rotation: could you have a look please?

Copybara import of the project:

--
2bd0cf2 by Pavel Emeliyanenko <[email protected]>:

set stream to null at the end of rocm_blas gemm function call

--
436d073 by Pavel Emeliyanenko <[email protected]>:

fixing buildbreaks

--
9347c71 by Pavel Emeliyanenko <[email protected]>:

added test for reset_stream

--
bb009b0 by Pavel Emeliyanenko <[email protected]>:

changed IsMainStreamSet interface

Merging this change closes #16520

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16520 from ROCm:ci_blas_reset_stream bb009b0
PiperOrigin-RevId: 679049775
  • Loading branch information
pemeliya authored and Google-ML-Automation committed Sep 30, 2024
1 parent 9e9b500 commit 3cb1e67
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 56 deletions.
29 changes: 13 additions & 16 deletions xla/service/gpu/autotuning/gemm_algorithm_picker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class GemmAutotuner {
explicit GemmAutotuner(const AutotuneConfig& autotune_config)
: autotune_config_(autotune_config) {}

const AutotuneConfig& config() const { return autotune_config_; }

size_t num_algorithms_left() const { return num_algorithms_left_; }

absl::StatusOr<AutotuneResult> operator()(const HloInstruction* gemm,
Expand Down Expand Up @@ -388,34 +390,31 @@ class GemmAutotuner {
<< best.status();
return AutotuneResult{};
} // GetBestAlgorithm
}; // GemmAutotuner
}; // class GemmAutotuner

// Do Gemm Autotune without stream executor. Use results from autotune cache
// only.
absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
const AutotuneConfig& config,
size_t* num_algorithms_left) {
GemmAutotuner& autotuner) {
VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString();

GpuBackendConfig gpu_config =
gemm->backend_config<GpuBackendConfig>().value();
GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();

*num_algorithms_left = 0;
// Degenerate gemms replaced with memzero operation, no need to auto tune it.
if (backend_config.alpha_real() == 0.0 &&
backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) {
VLOG(3) << "Skip degenerate gemm instruction auto tuning";
return false;
}

const AutotuneConfig& config = autotuner.config();
AutotuneCacheKey key(config.GetModelStr(), *gemm);
GemmAutotuner autotuner(config);
TF_ASSIGN_OR_RETURN(AutotuneResult algorithm,
AutotunerUtil::Autotune(
gemm, config, [&] { return autotuner(gemm, key); }));

*num_algorithms_left = autotuner.num_algorithms_left();
auto old_algorithm = backend_config.selected_algorithm();
bool update_algorithm =
IsCublasLtMatmulF8(*gemm) ||
Expand All @@ -442,9 +441,8 @@ absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,

if (new_algorithm == old_algorithm &&
backend_config.has_selected_algorithm()) {
// We don't need to update the backend config if
// the algorithm hasn't changed unless previously
// the algorithm wasn't set explicitly.
// We don't need to update the backend config if the algorithm was not
// changed unless previously the algorithm wasn't set explicitly.
return false;
}

Expand All @@ -457,17 +455,16 @@ absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
}

absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
AutotuneConfig config,
GemmAutotuner& autotuner,
size_t* num_algorithms_left) {
bool changed = false;

for (HloInstruction* instr : computation->instructions()) {
if (IsCublasGemm(*instr)) {
size_t num_left;
TF_ASSIGN_OR_RETURN(bool result,
RunOnInstruction(instr, config, &num_left));
TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, autotuner));
// Gathering statistics on the algorithms left after tuning (for testing)
*num_algorithms_left = std::max(*num_algorithms_left, num_left);
*num_algorithms_left =
std::max(*num_algorithms_left, autotuner.num_algorithms_left());
changed |= result;
}
}
Expand All @@ -487,11 +484,11 @@ absl::StatusOr<bool> GemmAlgorithmPicker::Run(
VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
return false;
}

GemmAutotuner autotuner(config_);
bool changed = false;
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_,
TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, autotuner,
&num_algorithms_left_));
changed |= result;
}
Expand Down
21 changes: 13 additions & 8 deletions xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,14 @@ class GemmAlgorithmPickerTest : public HloTestBase,
return debug_options;
}

const se::DeviceDescription& device_desc() {
return backend().default_stream_executor()->GetDeviceDescription();
}

se::StreamExecutor* stream_exec() {
return backend().default_stream_executor();
}
const se::DeviceDescription& gpu_device_desc() {
const se::DeviceDescription& device_desc() {
return stream_exec()->GetDeviceDescription();
}
const se::GpuComputeCapability& gpu_comp() {
return gpu_device_desc().gpu_compute_capability();
return device_desc().gpu_compute_capability();
}

void SetUp() override {
Expand Down Expand Up @@ -103,7 +99,7 @@ class GemmAlgorithmPickerTest : public HloTestBase,
};

TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
auto* blas = backend().default_stream_executor()->AsBlas();
auto* blas = stream_exec()->AsBlas();
ASSERT_TRUE(blas != nullptr);
std::string version;
ASSERT_TRUE(blas->GetVersion(&version).ok());
Expand Down Expand Up @@ -148,6 +144,15 @@ ENTRY main {
if (num_left1 < 2) {
GTEST_SKIP() << "Too few algorithms left after the first step";
}

// Test that the function to get current stream value works fine:
auto* blas = stream_exec()->AsBlas();
ASSERT_TRUE(blas != nullptr);
TF_ASSERT_OK_AND_ASSIGN(bool is_main_stream, blas->IsMainStreamSet());
// ROCM only: CUDA blas API does not reset stream after each blas call.
if (std::holds_alternative<se::RocmComputeCapability>(gpu_comp())) {
ASSERT_TRUE(is_main_stream);
}
}

// Clear cache before the second run!
Expand Down Expand Up @@ -291,7 +296,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg));
changed = false;

DevicelessConfig deviceless_config{gpu_device_desc()};
DevicelessConfig deviceless_config{device_desc()};
AutotuneConfig deviceless_cfg{deviceless_config, opts};
TF_ASSERT_OK_AND_ASSIGN(
changed,
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ cc_library(
"//xla/tsl/protobuf:dnn_proto_cc",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
Expand Down
6 changes: 6 additions & 0 deletions xla/stream_executor/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/stream_executor/data_type.h"
#include "xla/stream_executor/device_memory.h"
Expand Down Expand Up @@ -222,6 +223,10 @@ class BlasSupport {

virtual gpu::BlasLt *GetBlasLt() = 0;

// For tests only: sets *is_main_stream to true if the underlying Blas library
// has stream 0 set as its current stream.
virtual absl::StatusOr<bool> IsMainStreamSet() const = 0;

// Computes the product of a vector by a scalar: x <- a*x.
virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
DeviceMemory<float> *x, int incx) = 0;
Expand Down Expand Up @@ -727,6 +732,7 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
absl::StatusOr<bool> IsMainStreamSet() const override; \
bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \
DeviceMemory<float> *x, int incx) override; \
bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \
Expand Down
21 changes: 11 additions & 10 deletions xla/stream_executor/cuda/cuda_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,25 +229,26 @@ CUDABlas::~CUDABlas() {
}

bool CUDABlas::SetStream(Stream *stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
CHECK(blas_ != nullptr);
gpu::ScopedActivateContext sac{parent_};

cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
if (ret != CUBLAS_STATUS_SUCCESS) {
auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : 0;
if (auto ret = cublasSetStream(blas_, handle); ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
return false;
}

return true;
}

cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
gpu::ScopedActivateContext sac{parent_};
return AsGpuStreamValue(stream);
absl::StatusOr<bool> CUDABlas::IsMainStreamSet() const {
CHECK(blas_ != nullptr);
absl::MutexLock lock{&mu_};
GpuStreamHandle handle{};
if (auto ret = cublasGetStream(blas_, &handle);
ret != CUBLAS_STATUS_SUCCESS) {
return absl::InternalError("failed to get the current stream value");
}
return (handle == 0);
}

namespace {
Expand Down
5 changes: 1 addition & 4 deletions xla/stream_executor/cuda/cuda_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ class CUDABlas : public blas::BlasSupport {
// invoked before calling into cuBLAS.
bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Returns the underlying CUDA stream.
cudaStream_t CUDAStream(Stream *stream);

// A helper function that calls the real cuBLAS function together with error
// handling.
//
Expand Down Expand Up @@ -114,7 +111,7 @@ class CUDABlas : public blas::BlasSupport {
ScratchAllocator *scratch_allocator);

// Guards the cuBLAS handle for this device.
absl::Mutex mu_;
mutable absl::Mutex mu_;

// GpuExecutor which instantiated this CUDABlas.
// Immutable post-initialization.
Expand Down
29 changes: 15 additions & 14 deletions xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,26 +154,25 @@ ROCMBlas::~ROCMBlas() {
}

bool ROCMBlas::SetStream(Stream *stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
CHECK(blas_ != nullptr);
ScopedActivateContext sac{parent_};

rocblas_status ret =
wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream));
if (ret != rocblas_status_success) {
auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : 0;
if (auto ret = wrap::rocblas_set_stream(blas_, handle);
ret != rocblas_status_success) {
LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
return false;
}

return true;
}

hipStream_t ROCMBlas::ROCMStream(Stream *stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
ScopedActivateContext sac{parent_};
return AsGpuStreamValue(stream);
absl::StatusOr<bool> ROCMBlas::IsMainStreamSet() const {
CHECK(blas_ != nullptr);
absl::MutexLock lock{&mu_};
GpuStreamHandle handle{};
if (auto ret = wrap::rocblas_get_stream(blas_, &handle);
ret != rocblas_status_success) {
return absl::InternalError("failed to get the current stream value");
}
return (handle == 0);
}

namespace {
Expand Down Expand Up @@ -351,11 +350,11 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
absl::MutexLock lock{&mu_};

CHECK(blas_ != nullptr);
ScopedActivateContext sac{parent_};
if (!SetStream(stream)) {
return absl::InternalError("Setting stream failed");
}

ScopedActivateContext sac{parent_};
rocblas_status ret;
// set the atomics mode, leaving default to library
bool allow_atomics = !OpDeterminismRequired();
Expand Down Expand Up @@ -383,6 +382,8 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
#endif

ret = rocblas_func(blas_, std::forward<Args>(args)...);
SetStream(nullptr); // Resetting stream after the function call

if (ret != rocblas_status_success) {
auto err_str =
absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret));
Expand Down
5 changes: 1 addition & 4 deletions xla/stream_executor/rocm/rocm_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ class ROCMBlas : public blas::BlasSupport {
// invoked before calling into rocBLAS.
bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Returns the underlying ROCm stream
hipStream_t ROCMStream(Stream *stream);

// A helper function that calls the real rocBLAS function together with error
// handling.
//
Expand Down Expand Up @@ -188,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
ScratchAllocator *scratch_allocator);

// mutex that guards the rocBLAS handle for this device.
absl::Mutex mu_;
mutable absl::Mutex mu_;

// GpuExecutor which instantiated this ROCMBlas.
// Immutable post-initialization.
Expand Down

0 comments on commit 3cb1e67

Please sign in to comment.