Fix conversion for aten.expand #146

Open · wants to merge 5 commits into main
Conversation

jdh8 (Collaborator) commented Sep 4, 2024

Ticket

Subproblem of tenstorrent/tt-metal#12853

Problem description

Try fixing the aten.expand → ttnn.repeat conversion to unblock aten.repeat.
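
For context, a minimal sketch of the lowering idea (hypothetical code, not taken from this PR; the helper name expand_multipliers is an assumption): aten.expand broadcasts size-1 dimensions, so the equivalent ttnn.repeat call needs one repeat factor per dimension.

def expand_multipliers(input_shape, new_shape):
    # Broadcast dimensions (size 1) are repeated up to the requested size;
    # dimensions that already match keep a repeat factor of 1.
    return [new if old == 1 else 1 for old, new in zip(input_shape, new_shape)]

# Example: expanding a (1, 4) tensor to (32, 4) repeats dim 0 thirty-two times.
assert expand_multipliers([1, 4], [32, 4]) == [32, 1]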

What's changed

  • Re-enable the tests for aten.expand

Error messages

FAILED tests/lowering/tensor_manipulation/test_expand.py::test_expand[input_shape0-new_shape0] - AssertionError: assert 0 == 1
FAILED tests/lowering/tensor_manipulation/test_expand.py::test_expand_after_op[input_shape0-new_shape0] - TypeError: __call__(): incompatible function arguments. The following argument types are supported:
FAILED tests/lowering/tensor_manipulation/test_expand.py::test_expand_before_op[input_shape0-new_shape0] - AssertionError: assert 0 == 1
FAILED tests/lowering/tensor_manipulation/test_expand.py::test_expand_between_ops[input_shape0-new_shape0] - TypeError: __call__(): incompatible function arguments. The following argument types are supported:
E       TypeError: __call__(): incompatible function arguments. The following argument types are supported:
E           1. (self: ttnn._ttnn.operations.data_movement.clone_t, input_tensor: ttnn._ttnn.deprecated.tensor.Tensor, *, memory_config: Optional[ttnn._ttnn.deprecated.tensor.MemoryConfig] = None, dtype: Optional[ttnn._ttnn.deprecated.tensor.DataType] = None, queue_id: int = 0) -> ttnn._ttnn.deprecated.tensor.Tensor
E       
E       Invoked with: <ttnn._ttnn.operations.data_movement.clone_t object at 0x7f58bb5e23b0>, ttnn.Tensor([[ 0.12109,  0.38672,  ...,  0.00000,  0.00000],
E                    [ 0.00000,  0.00000,  ...,  0.00000,  0.00000],
E                    ...,
E                    [ 0.00000,  0.00000,  ...,  0.00000,  0.00000],
E                    [ 0.00000,  0.00000,  ...,  0.00000,  0.00000]], shape=Shape([1[32], 4[32]]), dtype=DataType::BFLOAT16, layout=Layout::TILE), MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt), <DataType.BFLOAT16: 0>
E       
E       Did you forget to `#include <pybind11/stl.h>`? Or <pybind11/complex.h>,
E       <pybind11/functional.h>, <pybind11/chrono.h>, etc. Some automatic
E       conversions are optional and require extra headers to be included
E       when compiling your pybind11 module.

../../tt-metal/ttnn/ttnn/decorators.py:328: TypeError

As target_wrappers.repeat is already registered for data movement, I wonder why there is still an argument mismatch.

TTNN_TARGET_WRAPPERS = [target_wrappers.clone, target_wrappers.repeat]

jdh8 added the bug label on Sep 4, 2024
ayerofieiev-tt (Member)

@jdh8 what is the status of this PR?

jdh8 (Collaborator, Author) commented Sep 10, 2024

I was trying to break down tenstorrent/tt-metal#12853 into subproblems per op. Should they be PRs here or tickets at tenstorrent/tt-metal?

boris-drazic (Contributor)

The four FAILED asserts in the first message are the same problem as tenstorrent/tt-metal#12853, so let's focus on the problem specific to expand in the second error message, TypeError: __call__(): incompatible function arguments.

@jdh8 The incompatible function arguments error is not actually for expand but for the clone op used in two tests, test_expand_after_op and test_expand_between_ops. clone expects a ttnn._ttnn.deprecated.tensor.Tensor as the input tensor, but a ttnn.Tensor is being passed. If the clone line a = torch.clone(x) is replaced with something like a = x + 0, the error no longer shows up.
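
For illustration, a hypothetical version of such a test module with that workaround applied (the module name and shapes here are assumptions, not code from this repo):

import torch

class ExpandAfterOp(torch.nn.Module):
    # Hypothetical stand-in for the module used in test_expand_after_op.
    def forward(self, x, new_shape):
        a = x + 0  # workaround: replaces a = torch.clone(x) to avoid the clone argument mismatch
        return a.expand(new_shape)

m = ExpandAfterOp()
print(m(torch.rand(1, 4, dtype=torch.bfloat16), (4, 4)).shape)  # torch.Size([4, 4])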

boris-drazic (Contributor)

@jdh8 Looking at the assert issue, it actually has a different root cause than the failures for concat, repeat, and reshape.

For this op (expand), the issue is not that the op remains as aten, but rather that the pass torch_ttnn/passes/lowering/add_data_move_pass.py replaces the repeat op with target_wrappers.repeat instead of ttnn.repeat.
As a result, the node-target check assert target.count(ttnn.repeat) == 1 does not pass, since the actual target is target_wrappers.repeat. It could be replaced with:

from torch_ttnn.passes.lowering import target_wrappers
assert target.count(target_wrappers.repeat) == 1

The next assert will not work with a simple update to look for the wrapper repeat (assert nodes[target.index(target_wrappers.repeat)].args[1] == ttnn.Shape), because the wrapper repeat function does not take a ttnn.Shape as its second argument; it just takes a plain list of sizes.

The last assert is already passing.
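
Put together, a sketch of how the updated checks could look (assuming nodes is the list of FX graph nodes the test already collects and expected_factors is the repeat vector we expect; both names are assumptions):

from torch_ttnn.passes.lowering import target_wrappers

def check_repeat_lowering(nodes, expected_factors):
    targets = [node.target for node in nodes]
    # The lowering pass rewrites repeat to the wrapper, so count the wrapper target.
    assert targets.count(target_wrappers.repeat) == 1
    # The wrapper takes a plain list of repeat factors as its second argument,
    # so compare the sizes directly instead of expecting a ttnn.Shape.
    repeat_node = nodes[targets.index(target_wrappers.repeat)]
    assert list(repeat_node.args[1]) == list(expected_factors)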

jdh8 changed the title from "Try fixing conversion for aten.expand" to "Fix conversion for aten.expand" on Sep 13, 2024
jdh8 removed the blocked label on Sep 13, 2024