[Unity] Update relax_dynamo #15962

Merged · 2 commits · Oct 24, 2023
29 changes: 26 additions & 3 deletions python/tvm/relax/frontend/torch/dynamo.py
@@ -71,14 +71,37 @@ def to_tvm_tensor(torch_tensor):
             real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype)
             return tvm.nd.array(real_tensor.numpy())

+        graph_module.graph.eliminate_dead_code()
+
         device = device_from_inputs(example_inputs)

+        assert len(example_inputs)
+
+        fake_inputs = []
+        if isinstance(example_inputs[0], torch._subclasses.fake_tensor.FakeTensor):
+            # Fake tensors
+            fake_inputs = example_inputs
+        else:
+            # Real tensors
+            for node in graph_module.graph.nodes:
+                if node.op != "placeholder":
+                    continue
+                if "grapharg" not in node.meta:
+                    continue
+                fake_tensor = node.meta["grapharg"].fake_tensor
+                if fake_tensor is None:
+                    continue
+                fake_inputs.append(fake_tensor)
+
         input_info = []
-        for tensor in example_inputs:
+        shape_vars = {}
+        for tensor in fake_inputs:
             shape = []
             for s in tensor.shape:
                 if isinstance(s, torch.SymInt):
-                    shape.append(tvm.tir.Var(str(s), "int64"))
+                    if str(s) not in shape_vars:
+                        shape_vars[str(s)] = tvm.tir.Var(str(s), "int64")
+                    shape.append(shape_vars[str(s)])
                 else:
                     shape.append(s)
             input_info.append((shape, tensor.dtype))
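The shape_vars map here makes every occurrence of the same symbolic dimension name (e.g. PyTorch's s0) translate to one shared tvm.tir.Var, so tensors that share a dynamic dimension also share the Relax shape variable. A minimal standalone sketch of that idea; the helper name and the "s0" dimension are illustrative, not part of the PR:

    import tvm

    shape_vars = {}

    def sym_dim_to_var(name):
        # stands in for str(torch.SymInt) handled in the hunk above
        if name not in shape_vars:
            shape_vars[name] = tvm.tir.Var(name, "int64")
        return shape_vars[name]

    # x and y both have a leading dynamic dimension named "s0", so they
    # end up referring to the very same tir.Var object.
    x_shape = [sym_dim_to_var("s0"), 100]
    y_shape = [sym_dim_to_var("s0"), 100]
    assert x_shape[0] is y_shape[0]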
@@ -110,7 +133,7 @@ def to_tvm_tensor(torch_tensor):
         vm = tvm.relax.VirtualMachine(ex.mod, device=dev)

         def exec_tvm(*i_args):
-            args = [a.contiguous() for a in i_args]
+            args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
             vm_args = list()
             for arg in args:
                 if arg.dim() != 0:
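With dynamic=True, TorchDynamo can pass scalar size arguments (SymInt values materialized as plain ints) to the compiled callable alongside the tensors, so the isinstance filter keeps only the tensors; the Relax VM reads the dynamic sizes from the tensor shapes themselves. A hypothetical illustration of such a call-time argument mix (the tuple below is made up for the example):

    import torch

    i_args = (4, torch.randn(4, 100), torch.randn(4, 100))  # e.g. (s0, x, y)
    args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
    assert len(args) == 2  # the leading size argument is dropped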
4 changes: 4 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -1484,6 +1484,10 @@ def from_fx(
         for node in graph.nodes:
             if node.op == "placeholder":
                 assert len(inputs) > 0, "Provided inputs is less than actual inputs"
+                if "grapharg" in node.meta and node.meta["grapharg"].fake_tensor is None:
+                    # Ignore sym input
+                    continue
+
                 self.env[node] = inputs.pop(0)
             elif node.op == "output":
                 args = self.retrieve_args(node)
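Placeholder nodes whose grapharg.fake_tensor is None correspond to symbolic size inputs rather than tensors, so they must not consume an entry from inputs. A rough, self-contained sketch of that pairing rule; the Placeholder namedtuple and the example values are made up, while in the real translator the nodes are torch.fx placeholders and inputs holds relax.Var objects:

    from collections import namedtuple

    Placeholder = namedtuple("Placeholder", ["name", "fake_tensor"])

    nodes = [Placeholder("s0", None), Placeholder("x", "fake_x"), Placeholder("y", "fake_y")]
    inputs = ["relax_var_x", "relax_var_y"]  # one relax input per *tensor* placeholder

    env = {}
    for node in nodes:
        if node.fake_tensor is None:
            continue  # symbolic size input: consumes no relax input
        env[node.name] = inputs.pop(0)

    assert env == {"x": "relax_var_x", "y": "relax_var_y"}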
15 changes: 15 additions & 0 deletions tests/python/relax/test_frontend_dynamo.py
@@ -137,6 +137,21 @@ def forward(self, x):
         opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5
     )

+    def Func1(x, y):
+        z = torch.cat([x, y])
+        if z.size(0) > 5:
+            return z.mul(2)
+        else:
+            return z.add(2)
+
+    opt_func = torch.compile(Func1, backend=relax_dynamo(), dynamic=True)
+
+    for s in (2, 4):
+        x = torch.randn(s, 100)
+        y = torch.randn(s, 100)
+        with torch.no_grad():
+            tvm.testing.assert_allclose(opt_func(x, y), opt_func(x, y))
+


 def test_subgraph_capture():
     import torch