Skip to content

Commit

Permalink
[Unity] update input_info with real tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
liquanfeng committed Oct 22, 2023
1 parent 1ba11f6 commit 3037396
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
27 changes: 24 additions & 3 deletions python/tvm/relax/frontend/torch/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,33 @@ def to_tvm_tensor(torch_tensor):

device = device_from_inputs(example_inputs)

assert len(example_inputs)

fake_inputs = []
if isinstance(example_inputs[0], torch._subclasses.fake_tensor.FakeTensor):
# Fake tensors
fake_inputs = example_inputs
else:
# Real tensors
for node in graph_module.graph.nodes:
if node.op != "placeholder":
continue
if "grapharg" not in node.meta:
continue
fake_tensor = node.meta["grapharg"].fake_tensor
if fake_tensor is None:
continue
fake_inputs.append(fake_tensor)

input_info = []
for tensor in example_inputs:
shape_vars = {}
for tensor in fake_inputs:
shape = []
for s in tensor.shape:
if isinstance(s, torch.SymInt):
shape.append(tvm.tir.Var(str(s), "int64"))
if str(s) not in shape_vars:
shape_vars[str(s)] = tvm.tir.Var(str(s), "int64")
shape.append(shape_vars[str(s)])
else:
shape.append(s)
input_info.append((shape, tensor.dtype))
Expand Down Expand Up @@ -110,7 +131,7 @@ def to_tvm_tensor(torch_tensor):
vm = tvm.relax.VirtualMachine(ex.mod, device=dev)

def exec_tvm(*i_args):
args = [a.contiguous() for a in i_args]
args = [a.contiguous() for a in i_args if isinstance(a, torch.Tensor)]
vm_args = list()
for arg in args:
if arg.dim() != 0:
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,10 @@ def from_fx(
# Translate the model.
for node in graph.nodes:
if node.op == "placeholder":
assert len(inputs) > 0, "Provided inputs is less than actual inputs"
if "grapharg" in node.meta and node.meta["grapharg"].fake_tensor is None:
# Ignore sym input
continue

self.env[node] = inputs.pop(0)
elif node.op == "output":
args = self.retrieve_args(node)
Expand Down

0 comments on commit 3037396

Please sign in to comment.