-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity] make LazyTransformParam more general #16088
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
# pylint: disable=invalid-name, unused-argument, missing-function-docstring, abstract-method | ||
"""Relax LazyTransformParams pass.""" | ||
import tvm | ||
import itertools | ||
from tvm import IRModule, relax | ||
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor | ||
|
||
|
@@ -118,9 +119,15 @@ class LazyTransformParamsMutator(PyExprMutator): | |
The module to be transformed | ||
""" | ||
|
||
def __init__(self, mod: IRModule = None) -> None: | ||
def __init__( | ||
self, fget_item, fset_item, get_item_param, set_item_param, mod: IRModule = None | ||
) -> None: | ||
super().__init__(mod) | ||
self.mod = mod | ||
self.fget_item = fget_item | ||
self.get_item_param = get_item_param | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the purpose of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad. I forget to add a test on this. The case is that I need to call |
||
self.fset_item = fset_item | ||
self.set_item_param = set_item_param | ||
# the only input param, which should be a Tuple | ||
self.input_tuple_param = None | ||
self.input_params_set = None | ||
|
@@ -149,9 +156,13 @@ def transform(self, func: relax.Function) -> relax.Function: | |
# Step 3. rewrite get item and set item | ||
new_body = self.visit_expr(func.body) | ||
|
||
# Step 4. Find all shape parameters that should be retained as | ||
# parameters. | ||
# Step 4. Add parameters of get_item and set_item (except index) to the function. | ||
params = [] | ||
for param in itertools.chain(self.get_item_param, self.set_item_param): | ||
jinhongyii marked this conversation as resolved.
Show resolved
Hide resolved
|
||
params.append(param) | ||
|
||
# Step 5. Find all shape parameters that should be retained as | ||
# parameters. | ||
symbolic_vars = relax.analysis.defined_symbolic_vars(func) | ||
if symbolic_vars: | ||
# direct iterate over the struct info annotation | ||
|
@@ -173,8 +184,8 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr: | |
if tuple_get_item.tuple_value == self.input_tuple_param: | ||
get_item_result = self.builder_.emit( | ||
relax.Call( | ||
relax.ExternFunc("get_item"), | ||
[relax.PrimValue(tuple_get_item.index)], | ||
relax.ExternFunc(self.fget_item), | ||
self.get_item_param + [relax.PrimValue(tuple_get_item.index)], | ||
None, | ||
[relax.ObjectStructInfo()], | ||
) | ||
|
@@ -188,24 +199,24 @@ def visit_var_(self, var: relax.Var) -> None: | |
return super().visit_var_(var) | ||
|
||
def visit_var_binding_(self, binding: relax.VarBinding) -> None: | ||
if binding.var == self.out_tuple_var: | ||
if self.fset_item is not None and binding.var == self.out_tuple_var: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the handling of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good. let me try |
||
# The function after rewriting returns a empty tuple. | ||
func_output = self.builder_.emit(relax.Tuple([])) | ||
self.set_var_remap(binding.var.vid, func_output) | ||
return | ||
|
||
super().visit_var_binding_(binding) | ||
|
||
if binding.var in self.memory_free_insertion: | ||
if self.fset_item is not None and binding.var in self.memory_free_insertion: | ||
for var in self.memory_free_insertion[binding.var]: | ||
if var in self.out_tuple_map: | ||
self.killed_vars.add(var) | ||
for index in self.out_tuple_map[var]: | ||
# rewrite set item | ||
self.builder_.emit( | ||
relax.Call( | ||
relax.ExternFunc("set_item"), | ||
[index, super().visit_var_(var)], | ||
relax.ExternFunc(self.fset_item), | ||
self.set_item_param + [index, super().visit_var_(var)], | ||
None, | ||
[relax.ObjectStructInfo()], | ||
), | ||
|
@@ -225,10 +236,38 @@ class LazyTransformParams: | |
(Load the input to memory on demand, and immediately free it after the last use.) | ||
|
||
Note: ToNonDataflow() and RemovePurityTracking() should be invoked before this pass. | ||
|
||
Parameters | ||
---------- | ||
fget_item: str | ||
The name of the get_item function. | ||
fset_item: str | ||
The name of the set_item function. | ||
get_item_param: list of relax.Var | ||
The parameters of the get_item function except index. | ||
The given parameters will be placed before index. | ||
For example, if get_item_param is [param1, param2], then the pass will generate | ||
call_packed(fget_item, [param1, param2, index]) | ||
set_item_param: list of relax.Var | ||
The parameters of the set_item function except index and value. | ||
The given parameters will be placed before index and value. | ||
For example, if set_item_param is [param1, param2], then the pass will generate | ||
call_packed(fset_item, [param1, param2, index, value]) | ||
""" | ||
|
||
def __init__( | ||
self, fget_item="get_item", fset_item="set_item", get_item_param=[], set_item_param=[] | ||
jinhongyii marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> None: | ||
self.fget_item = fget_item | ||
self.get_item_param = get_item_param | ||
assert self.fget_item is not None, "transforming set_item only is not supported" | ||
self.fset_item = fset_item | ||
self.set_item_param = set_item_param | ||
|
||
def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: | ||
lazy_mutator = LazyTransformParamsMutator(mod) | ||
lazy_mutator = LazyTransformParamsMutator( | ||
self.fget_item, self.fset_item, self.get_item_param, self.set_item_param, mod | ||
) | ||
for gv, _ in mod.functions_items(): | ||
if gv.name_hint.endswith("transform_params"): | ||
func = mod[gv] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the usage, the
get_item_param
andset_item_param
are lists of elements, so they should be pluralget_item_params
andset_item_params
. (If they are required overall. See other comment asking if they are.)This would also be an advantage of breaking the functionality apart into two separate mutators: A
LazyInput
mutator could have anextra_params
argument, and aLazyOutput
mutator could have anextra_params
argument, and both would be clear from the context. With both mutations implemented in the same class, if we wanted the name to indicate that these are additional arguments, we'd need to make it beextra_get_item_params
andextra_set_item_params
, which is a bit long for ease of use.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot fully separate these mutators because they need to share
input_tuple_param
and func.params[0] will change after one mutator.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so what I can do is create a class that does all the analysis, calls into the two mutators, and create function. In this case, we have to use
extra_get_item_params
andextra_set_item_params
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, looks like the analysis would need to be split apart as well. In general, I think having multiple smaller changes is better than a single mutator, but this seems like a refactor than would make sense to ask in a feature addition PR. The current revision looks good, and any additional changes can be in a follow-up PR if desired.