Skip to content

Commit

Permalink
PR #17580: Algebraic simplifier: optimize comparisons of all non-nega…
Browse files Browse the repository at this point in the history
…tive instructions to zero.

Imported from GitHub PR #17580

PR stacked with #17579
Copybara import of the project:

--
02c09a8 by Ilia Sergachev <[email protected]>:

Algebraic simplifier: mark iota non-negative.

--
4735edc by Ilia Sergachev <[email protected]>:

Fix unrelated clang-format issues to make CI happy

--
9494797 by Ilia Sergachev <[email protected]>:

Algebraic simplifier: optimize comparisons of all non-negative instructions to zero.

Merging this change closes #17580

COPYBARA_INTEGRATE_REVIEW=#17580 from openxla:non_neg_compare_zero 9494797
PiperOrigin-RevId: 679130659
  • Loading branch information
sergachev authored and Google-ML-Automation committed Sep 26, 2024
1 parent b4a8fad commit 5123bd4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 4 additions & 4 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5096,16 +5096,16 @@ absl::Status AlgebraicSimplifierVisitor::HandleCompare(
}

if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
IsNonNegative(lhs, options_) && IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
} else if (compare->comparison_direction() == ComparisonDirection::kGt &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
IsAll(lhs, 0) && IsNonNegative(rhs, options_)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
} else if (compare->comparison_direction() == ComparisonDirection::kGe &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
IsNonNegative(lhs, options_) && IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
} else if (compare->comparison_direction() == ComparisonDirection::kLe &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
IsAll(lhs, 0) && IsNonNegative(rhs, options_)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
if (lhs == rhs &&
Expand Down
15 changes: 15 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8985,6 +8985,21 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) {
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}

TEST_F(AlgebraicSimplifierTest, CompareAbsLtZeroBecomesFalse) {
// |x| < 0 -> false
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"(
m {
p = s32[5] parameter(0)
a = s32[5] abs(p)
z = s32[] constant(0)
b = s32[5] broadcast(z)
ROOT r = pred[5] compare(a, b), direction=LT
})"));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}

TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
const char* kModuleStr = R"(
HloModule m
Expand Down

0 comments on commit 5123bd4

Please sign in to comment.