diff --git a/python/shark_turbine/aot/builtins/jittable.py b/python/shark_turbine/aot/builtins/jittable.py
index d2c85b73f..df8fea7e5 100644
--- a/python/shark_turbine/aot/builtins/jittable.py
+++ b/python/shark_turbine/aot/builtins/jittable.py
@@ -45,6 +45,7 @@
 from ..passes import (
     functorch_functionalize,
+    remove_alias
 )
 
 from ..support.utils import (
@@ -115,8 +116,8 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
     return resolver
 
 
-ALL_PASSES: Set[str] = set(["functorch_functionalize"])
-DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",)
+ALL_PASSES: Set[str] = set(["functorch_functionalize", "remove_alias"])
+DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize", "remove_alias")
 
 
 class jittable(CallableIntrinsic):
@@ -205,6 +206,10 @@ def flat_wrapped_f(*args):
         transformed_f = flat_wrapped_f
         if "functorch_functionalize" in self._passes:
             transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
+        if "remove_alias" in self._passes:
+            transformed_f = remove_alias(transformed_f, *flat_pytorch_args)
+        if "functorch_functionalize" in self._passes:
+            transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
 
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
diff --git a/python/shark_turbine/aot/passes/__init__.py b/python/shark_turbine/aot/passes/__init__.py
index 167b8b886..05c4a4e6c 100644
--- a/python/shark_turbine/aot/passes/__init__.py
+++ b/python/shark_turbine/aot/passes/__init__.py
@@ -5,3 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .functorch import functorch_functionalize
+from .remove_alias import remove_alias
diff --git a/python/shark_turbine/aot/passes/remove_alias.py b/python/shark_turbine/aot/passes/remove_alias.py
new file mode 100644
index 000000000..0f9ec1ff8
--- /dev/null
+++ b/python/shark_turbine/aot/passes/remove_alias.py
@@ -0,0 +1,66 @@
+from typing import Callable
+
+import torch
+from torch.fx import (
+    GraphModule,
+    Node,
+)
+from torch.fx.experimental import proxy_tensor
+from torch.utils import _pytree as pytree
+import operator as py_operator
+
+
+def remove_unbind(gm: GraphModule) -> GraphModule:
+    # Find all unbind nodes
+    unbind_nodes = []
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten.unbind.int:
+            unbind_nodes.append(node)
+
+    to_erase = []
+
+    # Replace each unbind -> getitem chain with a torch.select node
+    for unbind in unbind_nodes:
+        only_getitem = True
+        for user in unbind.users:
+            if user.op != "call_function":
+                only_getitem = False
+                continue
+            if user.target != py_operator.getitem:
+                only_getitem = False
+                continue
+        if not only_getitem:
+            continue
+
+        unbind_dim = 0
+        if len(unbind.args) == 2:
+            unbind_dim = unbind.args[1]
+
+        for user in unbind.users:
+            # Get the getitem index
+            index = user.args[1]
+            with gm.graph.inserting_before(user):
+                select = gm.graph.call_function(
+                    torch.select,
+                    (unbind.args[0], unbind_dim, index),
+                )
+            # Replace the getitem node with the select node
+            user.replace_all_uses_with(select)
+
+            # Delete the getitem
+            to_erase.append(user)
+
+        to_erase.append(unbind)
+
+    # Erase the dead getitem and unbind nodes
+    for node in to_erase:
+        gm.graph.erase_node(node)
+    gm.recompile()
+
+    return gm
+
+
+def remove_alias(gm: GraphModule, *args) -> GraphModule:
+    # Replace unbind -> getitem chains with torch.select
+    gm = remove_unbind(gm)
+    return gm
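
For reviewers, a minimal standalone sketch of how the new pass behaves outside of jittable. The helper function f, the input shapes, and the use of make_fx to build the GraphModule are illustrative assumptions, not part of the patch; only the remove_alias import is added by this change, and whether aten.unbind.int survives tracing depends on the PyTorch version and decomposition settings. Within jittable itself no manual call is needed, since "remove_alias" is now part of DEFAULT_PASSES.

# Illustrative sketch (not part of the patch): exercising the remove_alias pass
# on a small FX graph. f, make_fx, and the shapes below are assumptions for
# demonstration; the pass only needs a GraphModule whose graph contains
# aten.unbind.int -> getitem chains.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

from shark_turbine.aot.passes import remove_alias  # export added by this patch


def f(x):
    # torch.unbind returns aliasing views; tracing typically records
    # aten.unbind.int followed by operator.getitem calls for the unpacking.
    a, b, c = torch.unbind(x, dim=0)
    return a + b * c


gm = make_fx(f)(torch.randn(3, 4))
print(gm.graph)  # expected to show torch.ops.aten.unbind.int + getitem nodes

gm = remove_alias(gm)
print(gm.graph)  # unbind/getitem chains rewritten to torch.select

# The rewritten module still computes the same result.
x = torch.randn(3, 4)
assert torch.allclose(gm(x), f(x))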