Skip to content

Commit

Permalink
PR #14073: Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp
Browse files Browse the repository at this point in the history
Imported from GitHub PR #14073

In one of the customer's HLOs (in the reduceOp computation function), I found the following pattern:
```c
p0 = pred[]{0} parameter(0)
p1 = pred[]{0} parameter(1)
compare = pred[]{0} compare(p0, p1), direction=GT
select = pred[]{0} select(compare, p0, p1)
```
It can be simplified to `logical_or`.

This PR adds the following patterns to algsimp
```c
select(compare(a, b, GT/GE), a, b) => or(a, b)
select(compare(a, b, LT/LE), a, b) => and(a, b)
select(compare(a, b, EQ), a, b) => b
select(compare(a, b, NE), a, b) => a

a,b ∈ PRED
```
Copybara import of the project:

--
6fe68d7 by Alexander Pivovarov <[email protected]>:

Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp

Merging this change closes #14073

COPYBARA_INTEGRATE_REVIEW=#14073 from apivovarov:select_compare_algsimp 6fe68d7
PiperOrigin-RevId: 646667024
  • Loading branch information
apivovarov authored and copybara-github committed Jun 26, 2024
1 parent 6fdb791 commit 6a009c8
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
27 changes: 27 additions & 0 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
select->mutable_operand(0)->shape(), HloOpcode::kNot,
select->mutable_operand(0)));
}
// select(compare(a, b, GT/GE), a, b) => or(a, b)
// select(compare(a, b, LT/LE), a, b) => and(a, b)
// select(compare(a, b, EQ), a, b) => b
// select(compare(a, b, NE), a, b) => a
HloInstruction *compare, *lhs, *rhs;
if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) &&
Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) {
auto cmp_dir = compare->comparison_direction();
if (cmp_dir == ComparisonDirection::kGt ||
cmp_dir == ComparisonDirection::kGe) {
return ReplaceWithNewInstruction(
select, HloInstruction::CreateBinary(select->shape(),
HloOpcode::kOr, lhs, rhs));
}
if (cmp_dir == ComparisonDirection::kLt ||
cmp_dir == ComparisonDirection::kLe) {
return ReplaceWithNewInstruction(
select, HloInstruction::CreateBinary(select->shape(),
HloOpcode::kAnd, lhs, rhs));
}
if (cmp_dir == ComparisonDirection::kEq) {
return ReplaceInstruction(select, rhs);
}
if (cmp_dir == ComparisonDirection::kNe) {
return ReplaceInstruction(select, lhs);
}
}
}

// select(pred, xs, dynamic_update_slice(xs, x, i))
Expand Down
89 changes: 89 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) {
GmockMatch(m::Not(m::Parameter(0))));
}

// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectGtCompare) {
for (const auto cmp_dir : {"GT", "GE"}) {
const auto kModuleStr = absl::StrFormat(R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=%s
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)",
cmp_dir);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Or(m::Parameter(0), m::Parameter(1))));
}
}

// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectLtCompare) {
for (const auto cmp_dir : {"LT", "LE"}) {
const auto kModuleStr = absl::StrFormat(R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=%s
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)",
cmp_dir);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::And(m::Parameter(0), m::Parameter(1))));
}
}

// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectEqCompare) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=EQ
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(1)));
}

// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED
TEST_F(AlgebraicSimplifierTest, SelectNeCompare) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=NE
ROOT select = pred[8]{0} select(compare, p0, p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Parameter(0)));
}

// select(compare(a, b, NE), b, a) ≠> a - wrong operands order
TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = pred[8]{0} parameter(0)
p1 = pred[8]{0} parameter(1)
compare = pred[8]{0} compare(p0, p1), direction=NE
ROOT select = pred[8]{0} select(compare, p1, p0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value());
}

// Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified
// to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i)
TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) {
Expand Down

0 comments on commit 6a009c8

Please sign in to comment.