Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] make LazyTransformParam more general #16088

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, missing-function-docstring, abstract-method
"""Relax LazyTransformParams pass."""
import tvm
import itertools
from tvm import IRModule, relax
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor

Expand Down Expand Up @@ -118,9 +119,15 @@ class LazyTransformParamsMutator(PyExprMutator):
The module to be transformed
"""

def __init__(self, mod: IRModule = None) -> None:
def __init__(
self, fget_item, fset_item, get_item_param, set_item_param, mod: IRModule = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the usage, the get_item_param and set_item_param are lists of elements, so they should be plural get_item_params and set_item_params. (If they are required overall. See other comment asking if they are.)

This would also be an advantage of breaking the functionality apart into two separate mutators: A LazyInput mutator could have an extra_params argument, and a LazyOutput mutator could have an extra_params argument, and both would be clear from the context. With both mutations implemented in the same class, if we wanted the name to indicate that these are additional arguments, we'd need to make it be extra_get_item_params and extra_set_item_params, which is a bit long for ease of use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot fully separate these mutators because they need to share input_tuple_param and func.params[0] will change after one mutator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so what I can do is create a class that does all the analysis, calls into the two mutators, and create function. In this case, we have to use extra_get_item_params and extra_set_item_params

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, looks like the analysis would need to be split apart as well. In general, I think having multiple smaller changes is better than a single mutator, but this seems like a refactor than would make sense to ask in a feature addition PR. The current revision looks good, and any additional changes can be in a follow-up PR if desired.

) -> None:
super().__init__(mod)
self.mod = mod
self.fget_item = fget_item
self.get_item_param = get_item_param
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of the get_item_param and set_item_param? I don't see any unit test with their intended use. Since the same parameters are forwarded to all get_item and set_item calls, it seems the same effect could be achieved by having a stateful callback, rather than passing an argument which is then passed right back to you.

Copy link
Contributor Author

@jinhongyii jinhongyii Nov 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad. I forget to add a test on this. The case is that I need to call get_item(loader, index). Since shard loader is created at runtime and is a DRef object, it cannot be a stateful callback and I have to modify the interface of get_item. Does this make sense to you?

self.fset_item = fset_item
self.set_item_param = set_item_param
# the only input param, which should be a Tuple
self.input_tuple_param = None
self.input_params_set = None
Expand Down Expand Up @@ -149,9 +156,13 @@ def transform(self, func: relax.Function) -> relax.Function:
# Step 3. rewrite get item and set item
new_body = self.visit_expr(func.body)

# Step 4. Find all shape parameters that should be retained as
# parameters.
# Step 4. Add parameters of get_item and set_item (except index) to the function.
params = []
for param in itertools.chain(self.get_item_param, self.set_item_param):
jinhongyii marked this conversation as resolved.
Show resolved Hide resolved
params.append(param)

# Step 5. Find all shape parameters that should be retained as
# parameters.
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
if symbolic_vars:
# direct iterate over the struct info annotation
Expand All @@ -173,8 +184,8 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
if tuple_get_item.tuple_value == self.input_tuple_param:
get_item_result = self.builder_.emit(
relax.Call(
relax.ExternFunc("get_item"),
[relax.PrimValue(tuple_get_item.index)],
relax.ExternFunc(self.fget_item),
self.get_item_param + [relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
)
Expand All @@ -188,24 +199,24 @@ def visit_var_(self, var: relax.Var) -> None:
return super().visit_var_(var)

def visit_var_binding_(self, binding: relax.VarBinding) -> None:
if binding.var == self.out_tuple_var:
if self.fset_item is not None and binding.var == self.out_tuple_var:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the handling of set_item and of get_item is largely independent. Rather than having them as two features of a single mutator, which requires the self.fset_item is not None check to be repeated throughout, can we split it into two distinct mutators? One mutator replaces access of parameters with get_item. The other mutator replaces binding of output with set_item. If both features are desired, then both mutators are applied.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. let me try

# The function after rewriting returns a empty tuple.
func_output = self.builder_.emit(relax.Tuple([]))
self.set_var_remap(binding.var.vid, func_output)
return

super().visit_var_binding_(binding)

if binding.var in self.memory_free_insertion:
if self.fset_item is not None and binding.var in self.memory_free_insertion:
for var in self.memory_free_insertion[binding.var]:
if var in self.out_tuple_map:
self.killed_vars.add(var)
for index in self.out_tuple_map[var]:
# rewrite set item
self.builder_.emit(
relax.Call(
relax.ExternFunc("set_item"),
[index, super().visit_var_(var)],
relax.ExternFunc(self.fset_item),
self.set_item_param + [index, super().visit_var_(var)],
None,
[relax.ObjectStructInfo()],
),
Expand All @@ -225,10 +236,38 @@ class LazyTransformParams:
(Load the input to memory on demand, and immediately free it after the last use.)

Note: ToNonDataflow() and RemovePurityTracking() should be invoked before this pass.

Parameters
----------
fget_item: str
The name of the get_item function.
fset_item: str
The name of the set_item function.
get_item_param: list of relax.Var
The parameters of the get_item function except index.
The given parameters will be placed before index.
For example, if get_item_param is [param1, param2], then the pass will generate
call_packed(fget_item, [param1, param2, index])
set_item_param: list of relax.Var
The parameters of the set_item function except index and value.
The given parameters will be placed before index and value.
For example, if set_item_param is [param1, param2], then the pass will generate
call_packed(fset_item, [param1, param2, index, value])
"""

def __init__(
self, fget_item="get_item", fset_item="set_item", get_item_param=[], set_item_param=[]
jinhongyii marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.fget_item = fget_item
self.get_item_param = get_item_param
assert self.fget_item is not None, "transforming set_item only is not supported"
self.fset_item = fset_item
self.set_item_param = set_item_param

def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
lazy_mutator = LazyTransformParamsMutator(mod)
lazy_mutator = LazyTransformParamsMutator(
self.fget_item, self.fset_item, self.get_item_param, self.set_item_param, mod
)
for gv, _ in mod.functions_items():
if gv.name_hint.endswith("transform_params"):
func = mod[gv]
Expand Down
85 changes: 85 additions & 0 deletions tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,91 @@ def main_transform_params() -> R.Tuple:
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_get_item_only():
@I.ir_module
class Before:
@T.prim_func
def transform_layout_IOHW_to_OIHW(
w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
):
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]

@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
)
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
):
# we expect ToNonDataflow and RemovePurityTracking to be invoked first
R.func_attr({"relax.force_pure": True})
cls = Before
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
lv3 = R.add(lv2, R.const(1, "float32"))
gv: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 3, 3, 3), dtype="float32"),
) = (lv, lv3)
return gv

@I.ir_module
class Expected:
@T.prim_func
def transform_layout_IOHW_to_OIHW(
w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
):
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]

@R.function(pure=False)
def main_transform_params() -> (
R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
)
):
cls = Expected
gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), sinfo_args=(R.Object,))
gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
gv, R.Tensor((16, 16, 3, 3), dtype="float32")
)
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), sinfo_args=(R.Object,))
gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
)
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
gv_1: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
) = (lv, lv3)
return gv_1

after = LazyTransformParams(fget_item="get_item_0", fset_item=None)(Before)
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_lazy_transform_params_with_symbolic_vars():
@I.ir_module
class Before:
Expand Down
Loading