Add Fx Pass to remove unbind -> getitem chains #361

Open · wants to merge 2 commits into base: main
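Background for the change: in an aten-level FX graph, torch.unbind lowers to a single aten.unbind.int node whose tuple result is consumed by operator.getitem nodes, and each such getitem is equivalent to a torch.select on the source tensor. That equivalence is what the new pass exploits. A minimal sketch of the pattern (tensor shape and index are illustrative):

    import torch

    x = torch.randn(3, 4)
    parts = torch.unbind(x, 0)  # traces to aten.unbind.int in an aten graph
    # Each getitem on the unbind result equals a select on the source tensor,
    # which is exactly the node-for-node rewrite this pass performs.
    assert torch.equal(parts[1], torch.select(x, 0, 1))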
9 changes: 7 additions & 2 deletions python/shark_turbine/aot/builtins/jittable.py

@@ -45,6 +45,7 @@
 
 from ..passes import (
     functorch_functionalize,
+    remove_alias
 )
 
 from ..support.utils import (
@@ -115,8 +116,8 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
     return resolver
 
 
-ALL_PASSES: Set[str] = set(["functorch_functionalize"])
-DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",)
+ALL_PASSES: Set[str] = set(["functorch_functionalize", "remove_alias"])
+DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize", "remove_alias")
 
 
 class jittable(CallableIntrinsic):
@@ -205,6 +206,10 @@ def flat_wrapped_f(*args):
         transformed_f = flat_wrapped_f
         if "functorch_functionalize" in self._passes:
             transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
+        if "remove_alias" in self._passes:
+            transformed_f = remove_alias(transformed_f, *flat_pytorch_args)
+        if "functorch_functionalize" in self._passes:
+            transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
 
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
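Because "remove_alias" is added to DEFAULT_PASSES, the new pass runs by default. Assuming jittable keeps exposing the pass list through a passes argument (the self._passes checks above suggest it does, though the constructor signature is not part of this diff), a caller could opt out with something like the following, where my_fn is a placeholder:

    # Hypothetical opt-out: run functionalization only, skipping remove_alias
    compiled = jittable(my_fn, passes=("functorch_functionalize",))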
1 change: 1 addition & 0 deletions python/shark_turbine/aot/passes/__init__.py

@@ -5,3 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .functorch import functorch_functionalize
+from .remove_alias import remove_alias
66 changes: 66 additions & 0 deletions python/shark_turbine/aot/passes/remove_alias.py
@@ -0,0 +1,66 @@
from typing import Callable

import torch
from torch.fx import (
    GraphModule,
    Node,
)
from torch.fx.experimental import proxy_tensor
from torch.utils import _pytree as pytree
import operator as py_operator


def remove_unbind(gm: GraphModule) -> GraphModule:
    # Find all unbind nodes
    unbind_nodes = []
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.unbind.int:
            unbind_nodes.append(node)

    to_erase = []

    # Replace each unbind -> getitem chain with a select node
    for unbind in unbind_nodes:
        # Only rewrite unbinds whose users are all getitem calls; any other
        # user consumes the full tuple, so the unbind must be kept.
        only_getitem = True
        for user in unbind.users:
            if user.op != "call_function":
                only_getitem = False
                continue
            if user.target != py_operator.getitem:
                only_getitem = False
                continue
        if not only_getitem:
            continue

        unbind_dim = 0
        if len(unbind.args) == 2:
            unbind_dim = unbind.args[1]

        for user in unbind.users:
            # Get the getitem index
            index = user.args[1]
            with gm.graph.inserting_before(user):
                select = gm.graph.call_function(
                    torch.select,
                    (unbind.args[0], unbind_dim, index),
                )
            # Replace the getitem node with the select node
            user.replace_all_uses_with(select)

            # Delete the getitem
            to_erase.append(user)

        to_erase.append(unbind)

    # Erase the dead getitem nodes first, then their unbind producers
    for node in to_erase:
        gm.graph.erase_node(node)
    gm.recompile()

    return gm


def remove_alias(gm: GraphModule, *args) -> GraphModule:
    # Replace unbind -> getitem chains with select nodes
    gm = remove_unbind(gm)
    return gm
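A minimal usage sketch of the new pass in isolation; the toy function and the make_fx tracing below are illustrative assumptions, not part of this PR:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    from shark_turbine.aot.passes import remove_alias

    def rows(x):
        # Tuple unpacking of unbind traces to aten.unbind.int followed by
        # one operator.getitem node per element.
        top, bottom = torch.unbind(x, 0)
        return top * bottom

    gm = make_fx(rows)(torch.randn(2, 3))
    gm = remove_alias(gm)
    # The getitem users are now torch.select calls and the unbind is erased.
    print(gm.graph)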