Skip to content

Commit

Permalink
[Unity] make LazyTransformParam more general (#16088)
Browse files Browse the repository at this point in the history
* lazy transform params

* format

* address comment

* fix

* fix ci

* fix ci
  • Loading branch information
jinhongyii authored Nov 9, 2023
1 parent 276b4ce commit 171ef61
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 23 deletions.
117 changes: 94 additions & 23 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, missing-function-docstring, abstract-method
# pylint: disable=invalid-name, unused-argument, missing-function-docstring, abstract-method, missing-class-docstring
"""Relax LazyTransformParams pass."""
from typing import Optional

import tvm
from tvm import IRModule, relax
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor
Expand Down Expand Up @@ -107,8 +109,7 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None:
self.var_liveness_end[binding.var] = self.last_appear_in_var_binding


@mutator
class LazyTransformParamsMutator(PyExprMutator):
class LazyTransformParamsFuncCreator:
"""
Transform transform_params functions into a lazy version.
Expand All @@ -118,16 +119,25 @@ class LazyTransformParamsMutator(PyExprMutator):
The module to be transformed
"""

def __init__(self, mod: IRModule = None) -> None:
super().__init__(mod)
def __init__(
self,
fget_item,
fset_item,
extra_get_item_params,
extra_set_item_params,
mod: IRModule = None,
) -> None:
self.mod = mod
self.fget_item = fget_item
self.extra_get_item_params = extra_get_item_params
self.fset_item = fset_item
self.extra_set_item_params = extra_set_item_params
# the only input param, which should be a Tuple
self.input_tuple_param = None
self.input_params_set = None
self.out_tuple_map = None
self.out_tuple_var = None
self.memory_free_insertion = None
self.killed_vars = set()

def transform(self, func: relax.Function) -> relax.Function:
self.input_tuple_param = func.params[0]
Expand All @@ -147,11 +157,18 @@ def transform(self, func: relax.Function) -> relax.Function:
self.memory_free_insertion = liveness.var_liveness_end

# Step 3. rewrite get item and set item
new_body = self.visit_expr(func.body)
new_body = func.body
if self.fget_item is not None:
new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)

if self.fset_item is not None:
new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)

# Step 4. Find all shape parameters that should be retained as
# Step 4. Add parameters of get_item and set_item (except index) to the function.
params = [*self.extra_get_item_params, *self.extra_set_item_params]

# Step 5. Find all shape parameters that should be retained as
# parameters.
params = []
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
if symbolic_vars:
# direct iterate over the struct info annotation
Expand All @@ -167,14 +184,22 @@ def transform(self, func: relax.Function) -> relax.Function:
is_pure=False,
).without_attr("relax.force_pure")


@mutator
class LazyInputMutator(PyExprMutator):
def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
self.func_creator = func_creator
super().__init__(mod)

def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
if tuple_get_item.tuple_value == self.input_tuple_param:
if tuple_get_item.tuple_value == self.func_creator.input_tuple_param:
get_item_result = self.builder_.emit(
relax.Call(
relax.ExternFunc("get_item"),
[relax.PrimValue(tuple_get_item.index)],
relax.ExternFunc(self.func_creator.fget_item),
self.func_creator.extra_get_item_params
+ [relax.PrimValue(tuple_get_item.index)],
None,
[relax.ObjectStructInfo()],
)
Expand All @@ -183,36 +208,45 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
else:
return tuple_get_item


@mutator
class LazyOutputMutator(PyExprMutator):
def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
self.func_creator = func_creator
self.killed_vars = set()
super().__init__(mod)

def visit_var_(self, var: relax.Var) -> None:
assert var not in self.killed_vars
return super().visit_var_(var)

def visit_var_binding_(self, binding: relax.VarBinding) -> None:
if binding.var == self.out_tuple_var:
if binding.var == self.func_creator.out_tuple_var:
# The function after rewriting returns an empty tuple.
func_output = self.builder_.emit(relax.Tuple([]))
self.set_var_remap(binding.var.vid, func_output)
return

super().visit_var_binding_(binding)

if binding.var in self.memory_free_insertion:
for var in self.memory_free_insertion[binding.var]:
if var in self.out_tuple_map:
if binding.var in self.func_creator.memory_free_insertion:
for var in self.func_creator.memory_free_insertion[binding.var]:
if var in self.func_creator.out_tuple_map:
self.killed_vars.add(var)
for index in self.out_tuple_map[var]:
for index in self.func_creator.out_tuple_map[var]:
# rewrite set item
self.builder_.emit(
relax.Call(
relax.ExternFunc("set_item"),
[index, super().visit_var_(var)],
relax.ExternFunc(self.func_creator.fset_item),
self.func_creator.extra_set_item_params
+ [index, super().visit_var_(var)],
None,
[relax.ObjectStructInfo()],
),
name_hint="_",
)

if var in self.input_params_set:
if var in self.func_creator.input_params_set:
self.builder_.emit(
relax.op.vm.kill_object(super().visit_var_(var)), name_hint="_"
)
Expand All @@ -225,16 +259,53 @@ class LazyTransformParams:
(Load the input to memory on demand, and immediately free it after the last use.)
Note: ToNonDataflow() and RemovePurityTracking() should be invoked before this pass.
Parameters
----------
fget_item: str
The name of the get_item function.
fset_item: str
The name of the set_item function.
extra_get_item_params: list of relax.Var
The parameters of the get_item function except index.
The given parameters will be placed before index.
For example, if extra_get_item_params is [param1, param2], then the pass will generate
call_packed(fget_item, [param1, param2, index])
extra_set_item_params: list of relax.Var
The parameters of the set_item function except index and value.
The given parameters will be placed before index and value.
For example, if extra_set_item_params is [param1, param2], then the pass will generate
call_packed(fset_item, [param1, param2, index, value])
"""

def __init__(
self,
fget_item="get_item",
fset_item="set_item",
extra_get_item_params=None,
extra_set_item_params=None,
) -> None:
self.fget_item = fget_item
self.extra_get_item_params = [] if extra_get_item_params is None else extra_get_item_params
assert self.fget_item is not None, "transforming set_item only is not supported"
self.fset_item = fset_item
self.extra_set_item_params = [] if extra_set_item_params is None else extra_set_item_params

def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
lazy_mutator = LazyTransformParamsMutator(mod)
lazy_mutator = LazyTransformParamsFuncCreator(
self.fget_item,
self.fset_item,
self.extra_get_item_params,
self.extra_set_item_params,
mod,
)
builder = relax.BlockBuilder(mod)
for gv, _ in mod.functions_items():
if gv.name_hint.endswith("transform_params"):
func = mod[gv]
if not isinstance(func, relax.Function):
continue
func = lazy_mutator.transform(func)
lazy_mutator.builder_.update_func(gv, func)
builder.update_func(gv, func)

return lazy_mutator.builder_.get()
return builder.get()
174 changes: 174 additions & 0 deletions tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,180 @@ def main_transform_params() -> R.Tuple:
tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_get_item_only():
    """Check the pass when only parameter loading is rewritten.

    The pass is constructed with ``fget_item="get_item_0"`` and
    ``fset_item=None``.  In the expected module every element of the input
    tuple is obtained through a packed call to ``get_item_0``, while the
    output tuple is still built and returned by the function.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]

        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            )
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3 = R.add(lv2, R.const(1, "float32"))
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv3)
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]

        # The rewritten function takes no parameters and is marked impure.
        @R.function(pure=False)
        def main_transform_params() -> (
            R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
            )
        ):
            cls = Expected
            # params[1] becomes a packed call to "get_item_0" with index 1,
            # followed by a match_cast to the tensor type used in Before.
            gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), sinfo_args=(R.Object,))
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            # Same rewrite for params[0].
            gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), sinfo_args=(R.Object,))
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
            # No set_item / kill_object calls appear here: the output tuple is
            # still constructed and returned, as in Before.
            gv_1: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
            ) = (lv, lv3)
            return gv_1

    # fset_item=None disables the output rewrite; only get_item is customised.
    after = LazyTransformParams(fget_item="get_item_0", fset_item=None)(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_extra_params():
    """Check that extra ``get_item`` parameters are threaded through the function.

    A ``relax.Var`` named ``loader`` is passed as the only entry of
    ``extra_get_item_params``; the function names are left at their defaults.
    In the expected module ``loader`` is a parameter of the rewritten function
    and is placed before the index in every ``get_item`` call, ``set_item`` is
    called with just ``(index, value)``, and the function returns an empty tuple.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]

        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            )
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3 = R.add(lv2, R.const(1, "float32"))
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv3)
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]

        # The tuple parameter is replaced by the extra get_item parameter.
        @R.function(pure=False)
        def main_transform_params(loader: R.Object) -> R.Tuple:
            cls = Expected
            # params[1] -> get_item(loader, 1): extra parameter first, then the index.
            gv: R.Object = R.call_packed(
                "get_item", loader, R.prim_value(1), sinfo_args=(R.Object,)
            )
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            # lv is element 0 of the output tuple in Before: it is handed to
            # set_item with index 0 and then killed.
            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
            _1: R.Tuple = R.vm.kill_object(lv)
            gv2: R.Object = R.call_packed(
                "get_item", loader, R.prim_value(0), sinfo_args=(R.Object,)
            )
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            # lv1 is killed directly after the call_tir that consumes it.
            _2: R.Tuple = R.vm.kill_object(lv1)
            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
            # lv3 is element 1 of the output tuple in Before.
            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,))
            # The rewritten function returns an empty tuple.
            gv_1: R.Tuple = R.tuple()
            return gv_1

    # Only get_item receives an extra parameter; extra_set_item_params is not given.
    after = LazyTransformParams(
        extra_get_item_params=[relax.Var("loader", relax.ObjectStructInfo())]
    )(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)


def test_lazy_transform_params_with_symbolic_vars():
@I.ir_module
class Before:
Expand Down

0 comments on commit 171ef61

Please sign in to comment.