Skip to content

Commit

Permalink
Add Fx Pass to remove unbind -> getitem chains
Browse files Browse the repository at this point in the history
RNN lowerings provided by pytorch use unbind to convert hidden layer
outputs into a list, and then index into that list later. This doesn't fit
well with FxImporter, which doesn't really have a concept of a list.
This pass takes these unbind -> getitem chains and converts them into
torch.select calls, so that no lists remain in the graph.
  • Loading branch information
Groverkss committed Jan 20, 2024
1 parent 86653a4 commit 3af36b8
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/shark_turbine/aot/builtins/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from ..passes import (
functorch_functionalize,
remove_alias
)

from ..support.utils import (
Expand Down Expand Up @@ -115,8 +116,8 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
return resolver


ALL_PASSES: Set[str] = set(["functorch_functionalize"])
DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",)
ALL_PASSES: Set[str] = set(["functorch_functionalize", "remove_alias"])
DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize", "remove_alias")


class jittable(CallableIntrinsic):
Expand Down Expand Up @@ -205,6 +206,10 @@ def flat_wrapped_f(*args):
transformed_f = flat_wrapped_f
if "functorch_functionalize" in self._passes:
transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
if "remove_alias" in self._passes:
transformed_f = remove_alias(transformed_f, *flat_pytorch_args)
if "functorch_functionalize" in self._passes:
transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

# Ask dynamo to give us an aten graph.
# TODO: Cache this for repeated calls.
Expand Down
1 change: 1 addition & 0 deletions python/shark_turbine/aot/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .functorch import functorch_functionalize
from .remove_alias import remove_alias
78 changes: 78 additions & 0 deletions python/shark_turbine/aot/passes/remove_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Callable

import torch
from torch.fx import (
GraphModule,
Node,
)
from torch.fx.experimental import proxy_tensor
from torch.utils import _pytree as pytree
import operator as py_operator


class Test(torch.nn.Module):
    """Tiny example module: unbind along dim 1, then index each piece.

    NOTE(review): this looks like leftover debug scaffolding for exercising
    the unbind -> getitem rewrite below; it is not referenced by the pass
    itself — consider removing. TODO confirm with the author.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pieces = torch.unbind(x, 1)
        first, second, third, fourth = pieces[0], pieces[1], pieces[2], pieces[3]
        return first, second, third, fourth


# NOTE(review): this runs symbolic_trace at import time of the pass module,
# and `trace` is never used anywhere in this file — it looks like leftover
# debug scaffolding for the Test module above; consider removing. TODO confirm.
trace: GraphModule = torch.fx.symbolic_trace(Test())


def remove_unbind(gm: GraphModule) -> GraphModule:
    """Replace ``unbind -> getitem`` chains in *gm* with ``torch.select`` nodes.

    ``aten.unbind.int`` produces a Python list of tensors, which the FX
    importer cannot represent. Whenever every consumer of an unbind node is a
    plain ``operator.getitem``, each ``getitem(unbind(x, dim), i)`` is
    rewritten to ``torch.select(x, dim, i)`` and the original nodes are
    erased. Unbind nodes with any non-getitem consumer are left untouched,
    since their list result is genuinely needed.

    Args:
        gm: The graph module to rewrite (mutated in place).

    Returns:
        The same ``GraphModule``, recompiled after the rewrite.
    """
    # Collect unbind nodes up front so graph mutation below cannot disturb
    # the node iteration.
    unbind_nodes = [
        node for node in gm.graph.nodes if node.target == torch.ops.aten.unbind.int
    ]

    to_erase = []

    for unbind in unbind_nodes:
        # Only rewrite when every consumer is a getitem call; otherwise skip
        # this unbind entirely.
        only_getitem = all(
            user.op == "call_function" and user.target == py_operator.getitem
            for user in unbind.users
        )
        if not only_getitem:
            continue

        # The dim argument may be positional or a keyword; it defaults to 0.
        # (The original code ignored a keyword `dim`, silently selecting
        # along the wrong dimension for unbind(x, dim=k) call sites.)
        if len(unbind.args) > 1:
            unbind_dim = unbind.args[1]
        else:
            unbind_dim = unbind.kwargs.get("dim", 0)

        # Snapshot the users: replace_all_uses_with / erase bookkeeping must
        # not race with iteration over the live users mapping.
        for user in list(unbind.users):
            # getitem(unbind_result, index)  ->  select(input, dim, index)
            index = user.args[1]
            with gm.graph.inserting_before(user):
                select = gm.graph.call_function(
                    torch.select,
                    (unbind.args[0], unbind_dim, index),
                )
            # Route every consumer of the getitem to the new select node.
            user.replace_all_uses_with(select)

            # The getitem is now dead; schedule it for deletion.
            to_erase.append(user)

        to_erase.append(unbind)

    # Getitems are appended before their unbind, so erasing in order is safe:
    # a node is only erased once it has no remaining users.
    for node in to_erase:
        gm.graph.erase_node(node)
    gm.recompile()

    return gm


def remove_alias(gm: GraphModule, *args) -> GraphModule:
    """Run the alias-removal pipeline on *gm*.

    Currently the only rewrite is turning ``unbind -> getitem`` chains into
    ``torch.select`` nodes. The extra ``*args`` (example inputs) are accepted
    for signature compatibility with other passes but are unused.
    """
    return remove_unbind(gm)
1 change: 1 addition & 0 deletions python/shark_turbine/dynamo/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
torch.ops.aten._log_softmax_backward_data,
torch.ops.aten.lift_fresh_copy.default,
torch.ops.aten._unsafe_index.Tensor,
torch.ops.aten.sinc,
# decompositions added manually in this file
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
Expand Down

0 comments on commit 3af36b8

Please sign in to comment.