From 5d077c5a0900a6b98934b9e9d813da14ba0fc24b Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Fri, 7 Jun 2024 15:49:51 +0100 Subject: [PATCH] [Arith][SVE] Add rewrite rules for indices split by scalable expressions (#17046) This commit introduces rewrite rules for indices which can arise from splitting axes by scalable factors (e.g. `xo, xi = sch.split(x, factors = [None, 8 * T.vscale()])`): ``` (v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) // (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_o (v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) % (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_i ``` The rewrites help prove checks needed by `sch.tensorize()` (e.g. CompareBufferRegion). --- src/arith/rewrite_simplify.cc | 15 +++++++++++++++ src/arith/rewrite_simplify.h | 2 ++ tests/python/arith/test_arith_rewrite_simplify.py | 8 ++++++++ 3 files changed, 25 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 42447ef2f8f2..f4d4a9048ced 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1136,8 +1136,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(matches_one_of(floordiv(y + x * z, z), floordiv(y + z * x, z)), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * z * c1 + y, z * c1), x + floordiv(y, z * c1), + CanProveGreaterEqual(z.Eval() * c1.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); + + // Scalable divisor + TVM_TRY_REWRITE_IF(floordiv(x, y), ZeroWithTypeLike(x), + ContainsVscaleCall(y.Eval()) && CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && CanProve(x.Eval() < y.Eval())); } return ret; } @@ -1230,6 +1237,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { ZeroWithTypeLike(x), CanProveEqual(y.Eval() - z.Eval(), 0) || CanProveEqual(y.Eval() + z.Eval(), 0)); + TVM_TRY_REWRITE_IF(floormod(x * z * c1 + y, z * c1), floormod(y, z * c1), + CanProveGreaterEqual(z.Eval() * c1.Eval(), 0)); + + // Scalable divisor + TVM_TRY_REWRITE_IF(floormod(x, y), x, + ContainsVscaleCall(y.Eval()) && CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0) && CanProve(x.Eval() < y.Eval())); + if (floormod(x, c1).Match(ret)) { int64_t c1val = c1.Eval()->value; if (c1val > 0) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 26dee062c4d2..1a53bef45002 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -229,6 +229,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // TODO(tqchen) refer back to super-analyzer. return TryCompare(x, val) == CompareResult::kEQ; } + // Whether x is true + bool CanProve(const PrimExpr& x) { return analyzer_->CanProve(x); } // Recursive rewrite x // we limit maximum depth of recursive rewrite allowed to diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index fcb6aa572910..1ebaab53af2d 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -559,6 +559,7 @@ class TestFloordivIndex(BaseCompare): TestCase(fld(x * y, y), x, y >= 0), TestCase(fld(y * x, y), x, y >= 0), TestCase(fld(x * z + y, z), x + fld(y, z), z >= 0), + TestCase(fld(x * z * 2 + y, z * 2), x + fld(y, z * 2), z * 2 >= 0), TestCase(fld(z * x + y, z), x + fld(y, z), z >= 0), TestCase(fld(y + x * z, z), fld(y, z) + x, z >= 0), TestCase(fld(y + z * x, z), fld(y, z) + x, z >= 0), @@ -616,6 +617,7 @@ class TestFloormodIndex(BaseCompare): TestCase(flm(x + y * (-10), 2), flm(x, 2)), TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]), TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]), + TestCase(flm(x * z * 2 + y, z * 2), flm(y, z * 2), z * 2 >= 0), # NOTE: the followng case is covered by canonical simplify # long range simplifcation in general can be covered by canonical simplify # TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1), @@ -832,6 +834,12 @@ class TestScalableIndex(BaseCompare): x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), ), TestCase(tvm.te.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x > y), + # FloorDiv + TestCase(fld(x * tir.vscale() * 4 + y, tir.vscale() * 4), x + fld(y, tir.vscale() * 4)), + TestCase(fld(x, tir.vscale() * 4), 0, [x >= 0, x < tir.vscale() * 4]), + # FloorMod + TestCase(flm(x * tir.vscale() * 4 + y, tir.vscale() * 4), flm(y, tir.vscale() * 4)), + TestCase(flm(x, tir.vscale() * 4), x, [x >= 0, x < tir.vscale() * 4]), ) def test_simplify(self, test_case):