Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] propagate the algorithm flag of dot op to cublasGemm custom call. #17595

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xla/service/algorithm_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ absl::StatusOr<se::blas::ComputationType> GetBlasComputationType(
switch (algorithm) {
case PrecisionConfig::ALG_DOT_F16_F16_F16:
return se::blas::ComputationType::kF16;
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
return se::blas::ComputationType::kBF16AsF32;
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
return se::blas::ComputationType::kF32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
Expand Down
9 changes: 0 additions & 9 deletions xla/service/gpu/dot_algorithm_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,6 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest,
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32),
Values(BF16), Values(BF16, F32),
Values(CC(8, 0)),
Values(SemanticVersion{6, 0, 0}),
Values(BackendRestriction::kNoRestriction),
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32),
Values(F32), Values(CC(8, 0)),
Expand Down
39 changes: 35 additions & 4 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
Expand Down Expand Up @@ -206,6 +207,7 @@ xla_test(
"no_mac",
],
deps = [
":kernel_name_tracer",
":triton_fusion_emitter",
":triton_test_utils",
"//xla:autotuning_proto_cc",
Expand All @@ -223,21 +225,17 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:filecheck",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:ir_headers",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:path",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
Expand Down Expand Up @@ -286,6 +284,37 @@ xla_test(
],
)

# Test-only CUDA implementation of the KernelNameTracer interface declared in
# kernel_name_tracer.h. The source file is only compiled when if_cuda() selects
# it; the implementation is built on the CUPTI tracer/collector profiler
# targets listed in deps.
cc_library(
name = "kernel_name_tracer_cuda",
testonly = True,
srcs = if_cuda(["kernel_name_tracer_cuda.cc"]),
hdrs = ["kernel_name_tracer.h"],
tags = ["manual"], # Need to exclude this from wildcard builds
deps = [
"//xla/backends/profiler/gpu:cupti_collector",
"//xla/backends/profiler/gpu:cupti_tracer",
"@tsl//tsl/profiler/utils:time_utils",
],
)

# Test-only fallback implementation of the KernelNameTracer interface declared
# in kernel_name_tracer.h. It has no dependencies and is the target chosen by
# :kernel_name_tracer when the if_cuda() condition does not hold.
cc_library(
name = "kernel_name_tracer_noop",
testonly = True,
srcs = ["kernel_name_tracer_noop.cc"],
hdrs = ["kernel_name_tracer.h"],
tags = ["manual"], # Need to exclude this from wildcard builds
)

# Public entry point that tests depend on. It exports kernel_name_tracer.h and
# forwards to exactly one implementation: :kernel_name_tracer_cuda when
# if_cuda() holds, :kernel_name_tracer_noop otherwise.
cc_library(
name = "kernel_name_tracer",
testonly = True,
hdrs = ["kernel_name_tracer.h"],
deps = if_cuda(
[":kernel_name_tracer_cuda"],
[":kernel_name_tracer_noop"],
),
)

cc_library(
name = "triton_test_utils",
testonly = True,
Expand Down Expand Up @@ -321,6 +350,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/utils:time_utils",
],
)

Expand Down Expand Up @@ -479,6 +509,7 @@ xla_test(
],
tags = ["no_mac"],
deps = [
":kernel_name_tracer",
":triton_fusion_emitter",
":triton_support",
":triton_test_utils",
Expand Down
39 changes: 39 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_
#define XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_

#include <memory>
#include <string>

namespace xla::gpu {

// Test helper that reports which GPU kernel was actually launched.
//
// The HLO alone does not always reveal the kernel: e.g. a custom call into
// cuBLAS (or another third-party library) picks its kernel internally. A
// tracer is bracketed around the execution of interest:
//
//   auto tracer = KernelNameTracer::Create();
//   tracer->start();
//   ... run the computation ...
//   std::string kernel_name = tracer->stop();
class KernelNameTracer {
 public:
  virtual ~KernelNameTracer() = default;

  // Builds the tracer implementation that was linked into the binary. The
  // definition lives in the implementation's source file, not in this header.
  static std::unique_ptr<KernelNameTracer> Create();

  // Begins recording kernel launches.
  virtual void start() = 0;

  // Ends recording and returns a kernel name. The exact value returned is
  // defined by the implementation.
  virtual std::string stop() = 0;
};

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_
72 changes: 72 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "xla/backends/profiler/gpu/cupti_collector.h"
#include "xla/backends/profiler/gpu/cupti_tracer.h"
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"
#include "tsl/profiler/utils/time_utils.h"

namespace xla::gpu {

// This class allows to get the name of the kernel that was used.
// It works only on CUDA. It uses CuptiTracer to get the kernel name.
class KernelNameTracerCuda : public KernelNameTracer {
public:
KernelNameTracerCuda()
: cupti_tracer_(profiler::CuptiTracer::GetCuptiTracerSingleton()) {}

void start() override;

// As of now it returns the name of the first kernel that was executed on
// GPU:0.
std::string stop() override;

private:
std::unique_ptr<profiler::CuptiTracer> cupti_tracer_;
std::unique_ptr<profiler::CuptiTraceCollector> cupti_collector_;
};

// Factory for builds that link the CUDA implementation: hands back a
// KernelNameTracerCuda behind the KernelNameTracer interface.
std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() {
  std::unique_ptr<KernelNameTracer> tracer =
      std::make_unique<KernelNameTracerCuda>();
  return tracer;
}

void KernelNameTracerCuda::start() {
profiler::CuptiTracerCollectorOptions collector_options;
collector_options.num_gpus = profiler::CuptiTracer::NumGpus();
auto start_gputime_ns = profiler::CuptiTracer::GetTimestamp();
auto start_walltime_ns = tsl::profiler::GetCurrentTimeNanos();
cupti_collector_ = profiler::CreateCuptiCollector(
collector_options, start_walltime_ns, start_gputime_ns);
profiler::CuptiTracerOptions options;
options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL};
cupti_tracer_->Enable(options, cupti_collector_.get());
}

// Ends the tracing window and returns the name stored in the event metadata
// with id 1 of the "/device:GPU:0" plane of the exported trace.
//
// Returns an empty string when:
//   * start() was never called (there is no collector to export from),
//   * the trace has no "/device:GPU:0" plane, or
//   * that plane has no event metadata with id 1 (e.g. no kernel was
//     recorded). Previously this case aborted inside Map::at().
std::string KernelNameTracerCuda::stop() {
  if (cupti_collector_ == nullptr) {
    return "";
  }

  cupti_tracer_->Disable();
  const uint64_t end_gpu_ns = cupti_collector_->GetTracingEndTimeNs();
  auto space = std::make_unique<tensorflow::profiler::XSpace>();
  cupti_collector_->Export(space.get(), end_gpu_ns);

  for (const auto& plane : space->planes()) {
    if (plane.name() != "/device:GPU:0") {
      continue;
    }
    // Look the entry up instead of using at(): a missing key must fall back
    // to "" rather than crash the test binary.
    const auto& metadata_by_id = plane.event_metadata();
    const auto first_metadata = metadata_by_id.find(1);
    if (first_metadata == metadata_by_id.end()) {
      return "";
    }
    return first_metadata->second.name();
  }
  return "";
}

} // namespace xla::gpu
33 changes: 33 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"

namespace xla::gpu {

// Fallback tracer that records nothing. stop() returns a fixed sentinel
// string so callers can detect that no real kernel name is available.
class KernelNameTracerNoop : public KernelNameTracer {
 public:
  // Nothing to start.
  void start() override {}

  // Always yields the sentinel value.
  std::string stop() override {
    return std::string("kernel_name_tracer_not_implemented");
  }
};

// Factory for builds that link the no-op implementation: hands back a
// KernelNameTracerNoop behind the KernelNameTracer interface.
std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() {
  std::unique_ptr<KernelNameTracer> tracer =
      std::make_unique<KernelNameTracerNoop>();
  return tracer;
}

} // namespace xla::gpu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <variant>
Expand All @@ -37,6 +38,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"
#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h"
#include "xla/service/gpu/fusions/triton/triton_test_utils.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
Expand All @@ -46,7 +48,6 @@ limitations under the License.
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla.pb.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -147,6 +148,74 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

// Fixture for tests of the dot_bf16_bf16_f32 algorithm with Triton GEMM
// switched off. Tests are skipped on devices that do not support BF16.
class TritonBF16BF16F32BlasTest : public TritonTest {
 public:
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions options = TritonTest::GetDebugOptionsForTest();
    // Triton GEMM is turned off for this fixture.
    // NOTE(review): presumably so the dot is lowered to a cuBLAS call instead
    // of a Triton fusion - confirm.
    options.set_xla_gpu_enable_triton_gemm(false);
    // Split-k autotuning stays off as well: with it enabled the optimized HLO
    // cannot be matched deterministically.
    options.set_xla_gpu_enable_split_k_autotuning(false);
    return options;
  }

 protected:
  void SetUp() override {
    if (SupportsBF16(GpuComputeComp())) {
      return;
    }
    GTEST_SKIP() << "BF16 not supported.";
  }
};

TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) {
  // Two things are verified here:
  //  1. The dot's algorithm survives optimization and shows up in the
  //     optimized module (FileCheck on the module text).
  //  2. Where a kernel tracer is available, the kernel that actually ran
  //     matches expectations for the device generation: a bf16 GEMM on
  //     Ampere, and a TF32 GEMM on Hopper, which has no bf16 kernel for this
  //     algorithm.

  constexpr std::string_view kHloText = R"(
HloModule t

ENTRY main {
lhs = f32[8512,256]{1,0} parameter(0)
rhs = f32[256,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_bf16_bf16_f32,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
  const std::string pattern = R"(CHECK: "algorithm":"ALG_DOT_BF16_BF16_F32")";
  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));

  // Part 1: the pattern used to be declared but never checked, so the
  // propagation itself was not tested at all.
  TF_ASSERT_OK_AND_ASSIGN(bool algorithm_propagated,
                          RunFileCheck(module->ToString(), pattern));
  EXPECT_TRUE(algorithm_propagated);

  // Part 2: run the module under the kernel tracer.
  auto tracer = KernelNameTracer::Create();
  tracer->start();
  EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
  auto kernel_name = tracer->stop();

  // The no-op tracer returns this sentinel; there is no kernel name to
  // inspect in that build configuration.
  if (kernel_name == "kernel_name_tracer_not_implemented") return;

  auto cc = GetCudaComputeCapability();
  using CudaComputeCapabilities =
      stream_executor::CudaComputeCapability::CudaComputeCapabilities;
  switch (cc.major) {
    case CudaComputeCapabilities::BLACKWELL:
      GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: "
                   << kernel_name;
      break;
    case CudaComputeCapabilities::AMPERE:
      EXPECT_THAT(kernel_name, ::testing::HasSubstr("bf16gemm_"));
      break;
    case CudaComputeCapabilities::HOPPER:
      // Hopper does not have bf16 kernels for ALG_DOT_BF16_BF16_F32 algorithm.
      // As a result it uses TF32.
      EXPECT_THAT(kernel_name, ::testing::HasSubstr("gemm_f32f32_tf32f32_f32"));
      break;
    default:
      GTEST_SKIP() << "Unsupported compute capability: " << cc.major
                   << " has the kernel name: " << kernel_name;
  }
}

TEST_F(TritonGemmTest, RejectDotInt4HLO) {
constexpr std::string_view kHloText = R"(
HloModule t
Expand Down Expand Up @@ -200,6 +269,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) {
rhs_batch_dims={0}
}
)";

const std::string pattern =
R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/fusions/triton/triton_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) {
CHECK(false);
}

absl::Status CreateTritonIrAndFileCheck(
HloTestBase* test, absl::string_view hlo_text,
absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) {
absl::Status CreateTritonIrAndFileCheck(HloTestBase* test,
absl::string_view hlo_text,
absl::string_view triton_fusion_name,
absl::string_view filecheck_pattern) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
test->ParseAndReturnVerifiedModule(hlo_text));
auto* comp = verified_module->GetComputationWithName(triton_fusion_name);
Expand Down
Loading
Loading