Skip to content

Commit

Permalink
[Relax] Ignore dynamic parameters in RewriteDataflowReshape (#17086)
Browse files Browse the repository at this point in the history
The Relax transform `RewriteDataflowReshape` identifies TIR functions
that are equivalent to `relax.op.reshape`, and replaces them with
calls to `relax.op.reshape`.  This is used as a precursor for
simplifications that rely on the high-level knowledge that an operator
is a reshape, but also require the low-level knowledge of the adjacent
TIR PrimFuncs.

Prior to this commit, the `RewriteDataflowReshape` pass would only
recognize static shapes, or dynamic shapes that could be inferred from
the shapes of tensor arguments.  This commit updates
`RewriteDataflowReshape` to recognize cases where an extra symbolic
variable has been provided.
  • Loading branch information
Lunderberg authored Jun 13, 2024
1 parent 0984e97 commit 0fb5365
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 14 deletions.
16 changes: 10 additions & 6 deletions src/relax/analysis/tir_op_pattern_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,19 +517,23 @@ bool HasReshapePattern(const PrimFunc& func) {
arith::Analyzer ana_;
};

if (func->params.size() < 2) {
return false;
Array<Buffer> buffer_args;
for (const auto& param : func->params) {
if (auto buffer = func->buffer_map.Get(param)) {
buffer_args.push_back(buffer.value());
}
}
Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front());
Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back());
if (!(src_buffer.defined() && dst_buffer.defined())) {

if (buffer_args.size() < 2) {
return false;
}
Buffer src_buffer = buffer_args.front();
Buffer dst_buffer = buffer_args.back();

// To detect the reshape pattern, we require each For to have
// either another For or a BlockRealize as body.
ICHECK(func->body->IsInstance<BlockRealizeNode>());
return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body);
return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
}

TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern);
Expand Down
17 changes: 10 additions & 7 deletions src/relax/transform/rewrite_dataflow_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
namespace tvm {
namespace relax {

std::vector<size_t> GetUsedArgsIndices(const tir::PrimFunc& fn, size_t num_args) {
std::vector<size_t> GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) {
std::vector<size_t> indices;
for (size_t i = 0; i < num_args; ++i) {
auto buffer_var = fn->buffer_map[fn->params[i]]->data;
if (tir::UsesVar(fn->body, [=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
indices.push_back(i);
if (auto buffer = fn->buffer_map.Get(fn->params[i])) {
auto buffer_var = buffer.value()->data;
if (tir::UsesVar(fn->body,
[=](const tir::VarNode* var) { return var == buffer_var.get(); })) {
indices.push_back(i);
}
}
}
return indices;
Expand Down Expand Up @@ -83,17 +86,17 @@ class DataflowReshapeRewriter : public ExprMutator {

auto prim_fn = Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size());
auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());

// The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
// can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
// then flattens the tuple input so that the fused TIR reshape function ends up having
// multiple input buffers. But only one of them should be accessed and reshaped.
if (used_arg_indices.size() != 1) {
if (used_tensor_arg_indices.size() != 1) {
return GetRef<Call>(call);
}

auto arg = arg_tuple[used_arg_indices[0]];
auto arg = arg_tuple[used_tensor_arg_indices[0]];

if (!IsCallingTIRReshape(call, arg)) {
return GetRef<Call>(call);
Expand Down
183 changes: 182 additions & 1 deletion tests/python/relax/test_transform_rewrite_dataflow_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R, tir as T
from tvm.script import relax as R, tir as T, ir as I


def test_reshape_expand_dims():
Expand Down Expand Up @@ -581,5 +581,186 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"):
tvm.ir.assert_structural_equal(rewritten, Expected)


def test_rewrite_static_reshape():
    """A legalized static-shape reshape is raised back to `R.reshape`.

    `Before` is run through LegalizeOps, RewriteDataflowReshape and
    DeadCodeElimination, and the result must be structurally equal to
    `Expected`: the reshape is expressed as `R.reshape` again, while the
    addition stays lowered as a `call_tir` into the `add` PrimFunc.
    """

    # Input module: high-level Relax reshape followed by an add.
    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor([256], dtype="float32")):
            with R.dataflow():
                y = R.reshape(x, [64, 4])
                z = R.add(y, y)
                R.output(z)
            return z

    # Reference module: only the `add` PrimFunc remains after the pipeline.
    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((256,), dtype="float32")):
            cls = Expected

            with R.dataflow():
                y = R.reshape(x, R.shape([64, 4]))
                z = R.call_tir(cls.add, (y, y), out_sinfo=R.Tensor((64, 4), dtype="float32"))
                R.output(z)
            return z

        @T.prim_func(private=True)
        def add(
            y1: T.Buffer((T.int64(64), T.int64(4)), "float32"),
            y2: T.Buffer((T.int64(64), T.int64(4)), "float32"),
            z: T.Buffer((T.int64(64), T.int64(4)), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})

            for iters in T.grid(T.int64(64), T.int64(4)):
                with T.block("T_add"):
                    i, j = T.axis.remap("SS", iters)
                    z[i, j] = y1[i, j] + y2[i, j]

    # Pass order matters:
    #   1. LegalizeOps lowers both R.reshape and R.add from Relax to TIR.
    #   2. RewriteDataflowReshape identifies the reshape and raises the call
    #      to the reshape PrimFunc from TIR back to Relax.
    #   3. DeadCodeElimination removes the no-longer-required "reshape"
    #      PrimFunc.
    pipeline = tvm.ir.transform.Sequential(
        [
            relax.transform.LegalizeOps(),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.DeadCodeElimination(),
        ]
    )
    After = pipeline(Before)

    tvm.ir.assert_structural_equal(Expected, After)


# def test_rewrite_dynamic_reshape():
# @I.ir_module
# class Before:
# @R.function
# def main(x: R.Tensor(["N"], dtype="float32")):
# N = T.int64()
# with R.dataflow():
# y = R.reshape(x, [N // 4, 4])
# z = R.add(y, y)
# R.output(z)
# return z

# @I.ir_module
# class Expected:
# @R.function
# def main(x: R.Tensor(["N"], dtype="float32")):
# N = T.int64()
# cls = Expected

# with R.dataflow():
# y = R.reshape(x, R.shape([N // 4, 4]))
# z = R.call_tir(
# cls.add,
# (y, y),
# tir_vars=[N],
# out_sinfo=R.Tensor((N // 4, 4), dtype="float32"),
# )
# R.output(z)
# return z

# @T.prim_func(private=True)
# def add(
# y1_handle: T.handle,
# y2_handle: T.handle,
# z_handle: T.handle,
# N: T.int64,
# ):

# y1 = T.match_buffer(y1_handle, [N // 4, 4], "float32")
# y2 = T.match_buffer(y2_handle, [N // 4, 4], "float32")
# z = T.match_buffer(z_handle, [N // 4, 4], "float32")

# T.func_attr({"tir.noalias": T.bool(True)})

# for iters in T.grid(T.int64(64), T.int64(4)):
# with T.block("T_add"):
# i, j = T.axis.remap("SS", iters)
# z[i, j] = y1[i, j] + y2[i, j]

# After = tvm.ir.transform.Sequential(
# [
# # Lower both R.reshape and R.add from Relax to TIR
# relax.transform.LegalizeOps(),
# # Identify reshapes, raise calls to cls.reshape from TIR
# # to Relax
# relax.transform.RewriteDataflowReshape(),
# # Clean up afterwards, removing the no-longer-required
# # PrimFunc "reshape"
# relax.transform.DeadCodeElimination(),
# ]
# )(Before)
# After.show()
# tvm.ir.assert_structural_equal(Expected, After)


def test_rewrite_dynamic_reshape():
    """A legalized reshape with a symbolic shape is raised back to `R.reshape`.

    The symbolic variable `N` is supplied through an extra `R.Prim`
    parameter of `main`. After LegalizeOps, RewriteDataflowReshape and
    DeadCodeElimination, the module must be structurally equal to
    `Expected`: the reshape is expressed as `R.reshape` again, while the
    addition stays lowered as a `call_tir` that forwards `N` via `tir_vars`.
    """

    # Input module: reshape a flat [N*16] tensor to [N*4, 4], then add.
    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
            N = T.int64()
            with R.dataflow():
                y = R.reshape(x, [N * 4, T.int64(4)])
                z = R.add(y, y)
                R.output(z)
            return z

    # Reference module: only the `add` PrimFunc remains after the pipeline,
    # and it takes `N` as a trailing scalar parameter.
    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(["N*16"], dtype="float32"), _: R.Prim(value="N")):
            N = T.int64()
            cls = Expected

            with R.dataflow():
                y = R.reshape(x, R.shape([N * 4, T.int64(4)]))
                z = R.call_tir(
                    cls.add,
                    (y, y),
                    tir_vars=[N],
                    out_sinfo=R.Tensor((N * 4, 4), dtype="float32"),
                )
                R.output(z)
            return z

        @T.prim_func(private=True)
        def add(
            y1_handle: T.handle,
            y2_handle: T.handle,
            z_handle: T.handle,
            N: T.int64,
        ):

            y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
            y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
            z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")

            T.func_attr({"tir.noalias": T.bool(True)})

            for iters in T.grid(N * 4, T.int64(4)):
                with T.block("T_add"):
                    i, j = T.axis.remap("SS", iters)
                    z[i, j] = y1[i, j] + y2[i, j]

    # Pass order matters:
    #   1. LegalizeOps lowers both R.reshape and R.add from Relax to TIR.
    #   2. RewriteDataflowReshape identifies the reshape and raises the call
    #      to the reshape PrimFunc from TIR back to Relax.
    #   3. DeadCodeElimination removes the no-longer-required "reshape"
    #      PrimFunc.
    pipeline = tvm.ir.transform.Sequential(
        [
            relax.transform.LegalizeOps(),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.DeadCodeElimination(),
        ]
    )
    After = pipeline(Before)
    tvm.ir.assert_structural_equal(Expected, After)


# Script entry point: when this file is executed directly (rather than
# imported), hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 0fb5365

Please sign in to comment.