From 0fb5365cd42b1ebaa97be6ed168f5e741f7c66a3 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 12 Jun 2024 21:29:15 -0500
Subject: [PATCH] [Relax] Ignore dynamic parameters in RewriteDataflowReshape
 (#17086)

The Relax transform `RewriteDataflowReshape` identifies TIR functions
that are equivalent to `relax.op.reshape`, and replaces them with calls
to `relax.op.reshape`.  This is used as a precursor for simplifications
that rely on the high-level knowledge that an operator is a reshape,
but that also require low-level knowledge of the adjacent TIR
PrimFuncs.

Prior to this commit, the `RewriteDataflowReshape` pass would only
recognize static shapes, or dynamic shapes that could be inferred from
the shapes of tensor arguments.  This commit updates
`RewriteDataflowReshape` to recognize cases where an extra symbolic
variable has been provided.
---
 src/relax/analysis/tir_op_pattern_kind.cc     |  16 +-
 .../transform/rewrite_dataflow_reshape.cc     |  17 +-
 ...test_transform_rewrite_dataflow_reshape.py | 183 +++++++++++++++++-
 3 files changed, 202 insertions(+), 14 deletions(-)

diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index c56f019e6bd4..44a888d7e6c9 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -517,19 +517,23 @@ bool HasReshapePattern(const PrimFunc& func) {
     arith::Analyzer ana_;
   };
 
-  if (func->params.size() < 2) {
-    return false;
+  Array<Buffer> buffer_args;
+  for (const auto& param : func->params) {
+    if (auto buffer = func->buffer_map.Get(param)) {
+      buffer_args.push_back(buffer.value());
+    }
   }
-  Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
-  Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
-  if (!(src_buffer.defined() && dst_buffer.defined())) {
+
+  if (buffer_args.size() < 2) {
     return false;
   }
+  Buffer src_buffer = buffer_args.front();
+  Buffer dst_buffer = buffer_args.back();
 
   // To detect the reshape pattern, we require each For to have
   // either another For or a BlockRealize as body.
   ICHECK(func->body->IsInstance<BlockRealizeNode>());
-  return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
+  return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
 }
 
 TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
index 8345f3e0b745..5403b7090c53 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -34,12 +34,15 @@
 namespace tvm {
 namespace relax {
 
-std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
+std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) {
   std::vector<size_t> indices;
   for (size_t i = 0; i < num_args; ++i) {
-    auto buffer_var = fn->buffer_map[fn->params[i]]->data;
-    if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
-      indices.push_back(i);
+    if (auto buffer = fn->buffer_map.Get(fn->params[i])) {
+      auto buffer_var = buffer.value()->data;
+      if (tir::UsesVar(fn->body,
+                       [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
+        indices.push_back(i);
+      }
     }
   }
   return indices;
@@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator {
     auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
     auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
 
-    auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
+    auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());
 
     // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
     // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
     // then flattens the tuple input so that the fused TIR reshape function ends up having
     // multiple input buffers. But only one of them should be accessed and reshaped.
-    if (used_arg_indices.size() != 1) {
+    if (used_tensor_arg_indices.size() != 1) {
       return GetRef<Expr>(call);
     }
 
-    auto arg = arg_tuple[used_arg_indices[0]];
+    auto arg = arg_tuple[used_tensor_arg_indices[0]];
 
     if (!IsCallingTIRReshape(call, arg)) {
       return GetRef<Expr>(call);
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index 26578393fe5e..f7befd3b886a 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 
 
 def test_reshape_expand_dims():
@@ -581,5 +581,186 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
     tvm.ir.assert_structural_equal(rewritten, Expected)
 
 
+def test_rewrite_static_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor([256], dtype="float32")):
+            with R.dataflow():
+                y = R.reshape(x, [64, 4])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((256,), dtype="float32")):
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([64, 4]))
+                z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32"))
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            y2: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+            z: T.Buffer((T.int64(64), T.int64(4)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(T.int64(64), T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+# def test_rewrite_dynamic_reshape():
+#     @I.ir_module
+#     class Before:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             with R.dataflow():
+#                 y = R.reshape(x, [N // 4, 4])
+#                 z = R.add(y, y)
+#                 R.output(z)
+#             return z
+
+#     @I.ir_module
+#     class Expected:
+#         @R.function
+#         def main(x: R.Tensor(["N"], dtype="float32")):
+#             N = T.int64()
+#             cls = Expected
+
+#             with R.dataflow():
+#                 y = R.reshape(x, R.shape([N // 4, 4]))
+#                 z = R.call_tir(
+#                     cls.add,
+#                     (y, y),
+#                     tir_vars=[N],
+#                     out_sinfo=R.Tensor((N // 4, 4), dtype="float32"),
+#                 )
+#                 R.output(z)
+#             return z
+
+#         @T.prim_func(private=True)
+#         def add(
+#             y1_handle: T.handle,
+#             y2_handle: T.handle,
+#             z_handle: T.handle,
+#             N: T.int64,
+#         ):
+
+#             y1 = T.match_buffer(y1_handle, [N // 4, 4], "float32")
+#             y2 = T.match_buffer(y2_handle, [N // 4, 4], "float32")
+#             z = T.match_buffer(z_handle, [N // 4, 4], "float32")
+
+#             T.func_attr({"tir.noalias": T.bool(True)})
+
+#             for iters in T.grid(T.int64(64), T.int64(4)):
+#                 with T.block("T_add"):
+#                     i, j = T.axis.remap("SS", iters)
+#                     z[i, j] = y1[i, j] + y2[i, j]
+
+#     After = tvm.ir.transform.Sequential(
+#         [
+#             # Lower both R.reshape and R.add from Relax to TIR
+#             relax.transform.LegalizeOps(),
+#             # Identify reshapes, raise calls to cls.reshape from TIR
+#             # to Relax
+#             relax.transform.RewriteDataflowReshape(),
+#             # Clean up afterwards, removing the no-longer-required
+#             # PrimFunc "reshape"
+#             relax.transform.DeadCodeElimination(),
+#         ]
+#     )(Before)
+#     After.show()
+#     tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_rewrite_dynamic_reshape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            with R.dataflow():
+                y = R.reshape(x, [N * 4, T.int64(4)])
+                z = R.add(y, y)
+                R.output(z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
+            N = T.int64()
+            cls = Expected
+
+            with R.dataflow():
+                y = R.reshape(x, R.shape([N * 4, T.int64(4)]))
+                z = R.call_tir(
+                    cls.add,
+                    (y, y),
+                    tir_vars=[N],
+                    out_sinfo=R.Tensor((N * 4, 4), dtype="float32"),
+                )
+                R.output(z)
+            return z
+
+        @T.prim_func(private=True)
+        def add(
+            y1_handle: T.handle,
+            y2_handle: T.handle,
+            z_handle: T.handle,
+            N: T.int64,
+        ):
+
+            y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
+            y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
+            z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
+
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            for iters in T.grid(N * 4, T.int64(4)):
+                with T.block("T_add"):
+                    i, j = T.axis.remap("SS", iters)
+                    z[i, j] = y1[i, j] + y2[i, j]
+
+    After = tvm.ir.transform.Sequential(
+        [
+            # Lower both R.reshape and R.add from Relax to TIR
+            relax.transform.LegalizeOps(),
+            # Identify reshapes, raise calls to cls.reshape from TIR
+            # to Relax
+            relax.transform.RewriteDataflowReshape(),
+            # Clean up afterwards, removing the no-longer-required
+            # PrimFunc "reshape"
+            relax.transform.DeadCodeElimination(),
+        ]
+    )(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
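
Usage note (not part of the patch): the pipeline exercised by the new tests can
also be driven standalone, as in the minimal sketch below.  The module reuses
the `Before` definition from `test_rewrite_dynamic_reshape` above and is
illustrative only; `LegalizeOps` first lowers `R.reshape` to a `call_tir` into
a generated PrimFunc, `RewriteDataflowReshape` raises that call back to
`relax.op.reshape` (now recognizing the extra symbolic variable `N`), and
`DeadCodeElimination` drops the then-unused reshape PrimFunc.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
        N = T.int64()
        with R.dataflow():
            # Lowered to call_tir by LegalizeOps, then raised back to
            # relax.op.reshape by RewriteDataflowReshape.
            y = R.reshape(x, [N * 4, T.int64(4)])
            z = R.add(y, y)
            R.output(z)
        return z


pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.RewriteDataflowReshape(),
        relax.transform.DeadCodeElimination(),
    ]
)
pipeline(Module).show()  # "z" is now computed from the raised R.reshape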