Skip to content

Commit

Permalink
[XLA:GPU] Add the fallback to F32 for dot algorithms bf16_x3 and bf16…
Browse files Browse the repository at this point in the history
…_x6 and f32 output.

We do not currently have a lowering to the multiply version of these algorithms.
The default f32_f32_f32 version is much faster, so use it as the fallback for now and add the lowering for these algorithms in follow-up CLs.

PiperOrigin-RevId: 695351871
  • Loading branch information
loislo authored and Google-ML-Automation committed Nov 11, 2024
1 parent 57c363e commit b70917b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
PrecisionConfig::PACKED_NIBBLE);
const bool can_rewrite_dot_with_precision_config_algorithm =
dot->precision_config().algorithm() == PrecisionConfig::ALG_UNSET ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_F32_F32_F32;
// If there are no contracting dimensions, a dot can be rewritten as
Expand Down
30 changes: 30 additions & 0 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9499,6 +9499,36 @@ TEST_F(AlgebraicSimplifierTest,
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
}

// Runs AlgebraicSimplifier over a dot of two f32[128] operands producing
// f32[128,128] that is annotated with algorithm=dot_bf16_bf16_f32_x3, and
// requires the pass to report that it changed the module.
// NOTE(review): per the test name the expected change is the dot-to-multiply
// rewrite, but only the "changed" flag is asserted here.
TEST_F(AlgebraicSimplifierTest,
       DotToMultiplyRewriteWith_BF16_BF16_F32_X3_Algorithm) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY dot {
a = f32[128]{0} parameter(0)
b = f32[128]{0} parameter(1)
ROOT dot = f32[128,128]{1,0} dot(a, b),
algorithm=dot_bf16_bf16_f32_x3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).value();
  ASSERT_TRUE(changed);
}

// Runs AlgebraicSimplifier over a dot of two f32[128] operands producing
// f32[128,128] that is annotated with algorithm=dot_bf16_bf16_f32_x6, and
// requires the pass to report that it changed the module.
// NOTE(review): per the test name the expected change is the dot-to-multiply
// rewrite, but only the "changed" flag is asserted here.
TEST_F(AlgebraicSimplifierTest,
       DotToMultiplyRewriteWith_BF16_BF16_F32_X6_Algorithm) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY dot {
a = f32[128]{0} parameter(0)
b = f32[128]{0} parameter(1)
ROOT dot = f32[128,128]{1,0} dot(a, b),
algorithm=dot_bf16_bf16_f32_x6
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).value();
  ASSERT_TRUE(changed);
}

TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
const char* kModuleStr = R"(
HloModule m
Expand Down

0 comments on commit b70917b

Please sign in to comment.