From 3af36b8e34fcf179cf79b3cdbd7abc0e56adf21e Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Sun, 21 Jan 2024 05:16:39 +0530 Subject: [PATCH 1/2] Add Fx Pass to remove unbind -> getitem chains RNN lowerings provided by pytorch use unbind to convert hidden layer outputs into a list, and then index it later. This doesn't fit well with FxImporter, because it doesn't really have a concept of a list. This pass takes these chains, and converts them into torch.select, so that we don't have lists again. --- python/shark_turbine/aot/builtins/jittable.py | 9 ++- python/shark_turbine/aot/passes/__init__.py | 1 + .../shark_turbine/aot/passes/remove_alias.py | 78 +++++++++++++++++++ python/shark_turbine/dynamo/passes.py | 1 + 4 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 python/shark_turbine/aot/passes/remove_alias.py diff --git a/python/shark_turbine/aot/builtins/jittable.py b/python/shark_turbine/aot/builtins/jittable.py index d2c85b73f..df8fea7e5 100644 --- a/python/shark_turbine/aot/builtins/jittable.py +++ b/python/shark_turbine/aot/builtins/jittable.py @@ -45,6 +45,7 @@ from ..passes import ( functorch_functionalize, + remove_alias ) from ..support.utils import ( @@ -115,8 +116,8 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]: return resolver -ALL_PASSES: Set[str] = set(["functorch_functionalize"]) -DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",) +ALL_PASSES: Set[str] = set(["functorch_functionalize", "remove_alias"]) +DEFAULT_PASSES: Tuple[str, ...] 
= ("functorch_functionalize", "remove_alias") class jittable(CallableIntrinsic): @@ -205,6 +206,10 @@ def flat_wrapped_f(*args): transformed_f = flat_wrapped_f if "functorch_functionalize" in self._passes: transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args) + if "remove_alias" in self._passes: + transformed_f = remove_alias(transformed_f, *flat_pytorch_args) + if "functorch_functionalize" in self._passes: + transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args) # Ask dynamo to give us an aten graph. # TODO: Cache this for repeated calls. diff --git a/python/shark_turbine/aot/passes/__init__.py b/python/shark_turbine/aot/passes/__init__.py index 167b8b886..05c4a4e6c 100644 --- a/python/shark_turbine/aot/passes/__init__.py +++ b/python/shark_turbine/aot/passes/__init__.py @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .functorch import functorch_functionalize +from .remove_alias import remove_alias diff --git a/python/shark_turbine/aot/passes/remove_alias.py b/python/shark_turbine/aot/passes/remove_alias.py new file mode 100644 index 000000000..f96cd9849 --- /dev/null +++ b/python/shark_turbine/aot/passes/remove_alias.py @@ -0,0 +1,78 @@ +from typing import Callable + +import torch +from torch.fx import ( + GraphModule, + Node, +) +from torch.fx.experimental import proxy_tensor +from torch.utils import _pytree as pytree +import operator as py_operator + + +class Test(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + unrolled = torch.unbind(x, 1) + return unrolled[0], unrolled[1], unrolled[2], unrolled[3] + + +trace: GraphModule = torch.fx.symbolic_trace(Test()) + + +def remove_unbind(gm: GraphModule) -> GraphModule: + # Find all unbind nodes + unbind_nodes = [] + for node in gm.graph.nodes: + if node.target == torch.ops.aten.unbind.int: + unbind_nodes.append(node) + + to_erase = [] + + # Replace all unbind -> getitem chains with a torch.select node + for 
unbind in unbind_nodes: + only_getitem = True + for user in unbind.users: + if user.op != "call_function": + only_getitem = False + continue + if user.target != py_operator.getitem: + only_getitem = False + continue + if not only_getitem: + continue + + unbind_dim = 0 + if len(unbind.args) == 2: + unbind_dim = unbind.args[1] + + for user in unbind.users: + # Get the getitem indices + index = user.args[1] + with gm.graph.inserting_before(user): + select = gm.graph.call_function( + torch.select, + (unbind.args[0], unbind_dim, index), + ) + # Replace the getitem node with the torch.select node + user.replace_all_uses_with(select) + + # Delete the getitem + to_erase.append(user) + + to_erase.append(unbind) + + # Erase the replaced getitem nodes and their unbind nodes + for node in to_erase: + gm.graph.erase_node(node) + gm.recompile() + + return gm + + +def remove_alias(gm: GraphModule, *args) -> GraphModule: + # Replace unbind -> getitem chains with torch.select + gm = remove_unbind(gm) + return gm diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 88c08f6ad..26a0bc3eb 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -48,6 +48,7 @@ torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten.sinc, # decompositions added manually in this file torch.ops.aten._scaled_dot_product_flash_attention.default, ] From f595a0102dd9e82a44cf65d9cf993e625b0438ac Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Sun, 21 Jan 2024 05:22:18 +0530 Subject: [PATCH 2/2] Remove unrelated changes --- python/shark_turbine/aot/passes/remove_alias.py | 12 ------------ python/shark_turbine/dynamo/passes.py | 1 - 2 files changed, 13 deletions(-) diff --git a/python/shark_turbine/aot/passes/remove_alias.py b/python/shark_turbine/aot/passes/remove_alias.py index f96cd9849..0f9ec1ff8 100644 --- a/python/shark_turbine/aot/passes/remove_alias.py +++ 
b/python/shark_turbine/aot/passes/remove_alias.py @@ -10,18 +10,6 @@ import operator as py_operator -class Test(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - unrolled = torch.unbind(x, 1) - return unrolled[0], unrolled[1], unrolled[2], unrolled[3] - - -trace: GraphModule = torch.fx.symbolic_trace(Test()) - - def remove_unbind(gm: GraphModule) -> GraphModule: # Find all unbind nodes unbind_nodes = [] diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 26a0bc3eb..88c08f6ad 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -48,7 +48,6 @@ torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, torch.ops.aten._unsafe_index.Tensor, - torch.ops.aten.sinc, # decompositions added manually in this file torch.ops.aten._scaled_dot_product_flash_attention.default, ]