Skip to content

Commit

Permalink
This is not correct in general.
Browse files Browse the repository at this point in the history
(((s1 + s1) floordiv 3) + (s0 floordiv 3)) floordiv 6)

is simplified to

(s1 * 2 + s0) floordiv 18

But this is not equivalent: the results for [2, 8] are 1 and 0, respectively.

Reverts 2736624

PiperOrigin-RevId: 646381435
  • Loading branch information
jreiffers authored and copybara-github committed Jun 25, 2024
1 parent 7db2902 commit ca0f4b3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 51 deletions.
47 changes: 12 additions & 35 deletions xla/service/gpu/model/indexing_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,23 +271,13 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend,
// The gcd of all multipliers and the divisor.
int64_t multiplier_divisor_gcd = divisor;
Interval no_multiplier_range{0, 0};
std::optional<int64_t> min_inner_divisor = std::nullopt;
std::optional<int64_t> inner_divisor_gcd = std::nullopt;
VisitSummands(new_dividend, [&](AffineExpr summand) {
if (auto multiplier = GetConstantRhs(summand, AffineExprKind::Mul)) {
multiplier_divisor_gcd = std::gcd(multiplier_divisor_gcd, *multiplier);
} else {
no_multiplier_range = no_multiplier_range +
range_evaluator_->ComputeExpressionRange(summand);
}

if (auto inner_divisor =
GetConstantRhs(summand, AffineExprKind::FloorDiv)) {
min_inner_divisor =
std::min(min_inner_divisor.value_or(*inner_divisor), *inner_divisor);
inner_divisor_gcd =
std::gcd(inner_divisor_gcd.value_or(*inner_divisor), *inner_divisor);
}
});

// Consider an expression like: `(x * 6 + y) / 9`. if the range of `y` is at
Expand All @@ -306,24 +296,6 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend,
divisor /= multiplier_divisor_gcd;
}

// If we have an inner divisor whose value is equal to the GCD of all the
// divisors, we can remove a division:
// `(a0 / c + a1 / cd + ...) / e` -> `(a0 + a1 / d + (...) * c) / ce`
// This potentially increases the number of multiplications, but it's
// generally a win. It also matches what the MLIR simplifier does better, so
// we can get more simplifications.
if (min_inner_divisor && *min_inner_divisor > 0 &&
min_inner_divisor == inner_divisor_gcd) {
new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) {
if (auto inner_divisor =
GetConstantRhs(summand, AffineExprKind::FloorDiv)) {
return GetLhs(summand).floorDiv(*inner_divisor / *inner_divisor_gcd);
}
return summand * *inner_divisor_gcd;
});
divisor *= *inner_divisor_gcd;
}

return new_dividend.floorDiv(divisor) + extracted;
}

Expand Down Expand Up @@ -509,13 +481,18 @@ AffineExpr AffineExprSimplifier::SimplifyOnce(AffineExpr expr) {
if (!div) continue; // Already erased.
if ((div_mul % mod_mul) || (div_mul / mod_mul) != mod_c) continue;

// In many cases, we could just compare the LHSes of the mod and the
// div, but if x is a floorDiv itself, we need to check a bit more
// carefully:
// ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)`
// `x // (c0 * c1)` will be simplified, so we we may not even have
// `c0 * c1` in the expression, if `x` contains a multiplier.
if (Simplify(GetLhs(mod).floorDiv(*mod_c)) != Simplify(div)) continue;
auto mod_lhs = GetLhs(mod);
if (GetConstantRhs(mod_lhs, AffineExprKind::FloorDiv)) {
// If x is a floorDiv itself, we need to check a bit more carefully:
// ((x // c0) % c1) * d + (x // (c0 * c1)) * (c1 * d)`
// `x // (c0 * c1)` will be simplified, so we we may not even have
// `c0 * c1` in the expression, if `x` contains a multiplier.
if (Simplify(mod_lhs.floorDiv(*mod_c)) != Simplify(div)) continue;
} else {
if (mod_lhs != GetLhs(div)) continue;
auto div_c = GetConstantRhs(div, AffineExprKind::FloorDiv);
if (mod_c != div_c) continue;
}

others.push_back(GetLhs(mod) * mod_mul);
divs[div_i].first = nullptr;
Expand Down
16 changes: 0 additions & 16 deletions xla/service/gpu/model/indexing_map_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -715,22 +715,6 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) {
)"));
}

TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) {
auto serialized_map =
"()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 - ((s0 * 2 + s1 floordiv 64) "
"floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)";
IndexingMap indexing_map = IndexingMap::FromTensorSizes(
ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4});
EXPECT_TRUE(indexing_map.Simplify());
EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"(
()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2)
domain:
s0 in [0, 1233]
s1 in [0, 127]
s2 in [0, 3]
)"));
}

TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) {
// (s0 floordiv 2) floordiv -7 is not s0 floordiv -14:
// 15 // 2 // -7 = -1
Expand Down

0 comments on commit ca0f4b3

Please sign in to comment.