Merge M0 ops (#54)
* Change tracer type to a decorator; its API changes, too

* torch_ttnn.backend applies the tracer config

* Move JSON dump out of parse_fx_graph

* Wrap leaf functions in function form instead of a decorator

* Rename TorchTtnnOption to TenstorrentBackendOption

* Revert "Use subgraph_rewriter"

This reverts commit fc09080.

* Extract mock_ttnn into a standalone package

* Register ttnn backend

* Add setup.py and pyproject.toml

* Update README.md

* Move torch_ttnn/tracer.py to tracer/tracer.py

* Add model_sow2_list1 in tools/stat_models

* Fix counter bug

* Fix try/except

* Use retain_graph=True for detr_resnet50

* Update README.md for package building

* Update test cases for the ttnn interface change (usage sketch after this list)

- ttnn.open(0) -> ttnn.open_device(device_id=0)
- ttnn.close(d) -> ttnn.close_device(d)
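
A minimal before/after usage sketch of the renamed device API (the surrounding code is illustrative):

```python
import ttnn

# Before: device = ttnn.open(0)
device = ttnn.open_device(device_id=0)

# ... run ops on the device ...

# Before: ttnn.close(device)
ttnn.close_device(device)
```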

* Convert 14 pointwise binary operations

Add conversions and unit test cases (a sketch of the rewrite follows this list)

- add
- eq
- gt
- logical_and
- logical_or
- logical_xor
- lt
- maximum
- minimum
- mul
- ne
- pow
- sub
- xlogy
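
A minimal sketch of what such a conversion pass can look like, assuming a direct aten-to-ttnn target swap on the fx graph (the mapping shown is an illustrative subset, and the real pass also inserts to_device/from_device and to_layout nodes):

```python
import torch
import ttnn

# Illustrative subset of the aten -> ttnn mapping, not the repo's table.
ATEN_TO_TTNN = {
    torch.ops.aten.add.Tensor: ttnn.add,
    torch.ops.aten.mul.Tensor: ttnn.mul,
    torch.ops.aten.gt.Tensor: ttnn.gt,
}

def lower_binary_ops(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in ATEN_TO_TTNN:
            node.target = ATEN_TO_TTNN[node.target]  # args/kwargs stay as-is
    gm.graph.lint()
    gm.recompile()
    return gm
```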

* Convert 35 pointwise unary operations

- aten.abs
- aten.acos
- aten.acosh
- aten.asin
- aten.asinh
- aten.atan
- aten.atan2  # binary
- aten.atanh
- aten.clone
- aten.cos
- aten.cosh
- aten.erf
- aten.exp
- aten.expm1
- aten.gelu
- aten.hardtanh
- aten.isinf
- aten.isnan
- aten.leaky_relu
- aten.log
- aten.log10
- aten.log1p
- aten.log2
- aten.logical_not
- aten.neg
- aten.reciprocal
- aten.relu
- aten.rsqrt
- aten.sigmoid
- aten.sign
- aten.sin
- aten.sinh
- aten.sqrt
- aten.tan
- aten.tanh

* Convert 3 pointwise trinary operations

- addcdiv
- addcmul
- where

* Convert 2 matmul operations

Also use fx.subgraph_rewriter (a sketch follows the list)

- matmul
- linear
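
A rough sketch of the fx.subgraph_rewriter approach for linear; the toy module is illustrative, and in the repo the ttnn call goes through a traceable wrapper:

```python
import torch
import ttnn
from torch.fx import subgraph_rewriter, symbolic_trace

class ToyLinear(torch.nn.Module):
    def forward(self, x, w, b):
        return torch.nn.functional.linear(x, w, b)

def pattern(x, w, b):
    return torch.nn.functional.linear(x, w, b)

def replacement(x, w, b):
    # The real pass routes through a wrapper around the ttnn op so fx can trace it.
    return ttnn.linear(x, w, b)

gm = symbolic_trace(ToyLinear())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
```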

* Simplify op conversion

* Fix wrap for ttnn & update test

- ttnn.add (and other ops) don't have `__name__`, so torch.compile will fail; we hard-patch the ops with a `__name__` (see the sketch below)
- ttnn now needs a `to_layout` before computation
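
A sketch of the hard patch, assuming the ttnn operation objects accept attribute assignment (the op list here is illustrative):

```python
import ttnn

# torch.compile stringifies graph targets via __name__, which ttnn
# operation objects don't define, so patch one on.
for name in ("add", "mul", "matmul"):
    op = getattr(ttnn, name)
    if not hasattr(op, "__name__"):
        op.__name__ = name
```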

* Fix more ops unit test

* Simplify pass insertion for gen_graphviz

* Update test cases for to_layout

* Fix add_data_move to support kwargs

* Support linear without bias

* Don't gen graphviz for pointwise unary test

* Fix three ops, and verify some ops

Three ops are fixed:
- clone
- hardtanh
- leaky_relu

The following ops were verified and won't be fixed:
- atan2: ttnn bug when x or y is 0
- pow: the exponent doesn't support tensor type
- xlogy: y == 1 should yield 0

* Simplify binary tests; add scalar support

* Support addcmul & addcdiv

* Fix support of addcdiv and where

- addcdiv with its optional argument is a special case in torch and needs special pattern matching
- in ttnn.where(c, x, y), c cannot be bool and needs a cast (see the sketch below)
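
A Python-level sketch of the cast for where; in the actual pass the cast is inserted as a graph node, and bfloat16 is an illustrative target dtype:

```python
import torch
import ttnn

def where_with_cast(condition, x, y):
    # ttnn.where cannot take a bool condition, so cast it first.
    condition = condition.to(torch.bfloat16)
    return ttnn.where(condition, x, y)
```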

* Convert and test repeat op

* Add silu conversion

* Update new test cases according to the torch.compile interface change

* Simplify unary test cases, reuse impl code

* Update test_stat import statement

* Add group_norm conversion

* Try converting layer_norm, but the result differs

* Support repeat

* Update trinary tests

* Support concat

* Support split

* Update format

* [wip] group_norm

* group_norm uses a customized replacement

The transformer method doesn't work because it replaces native_group_norm->getitem with ttnn.group_norm.
Pattern replacement doesn't work because it uses symbolic trace, which works on proxies, not values;
however, we need a value-dependent conversion (a sketch follows).
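
A minimal sketch of that value-dependent replacement: walking the graph by hand exposes concrete argument values that symbolic trace hides behind proxies (the argument position below is an assumption):

```python
import torch

def replace_group_norm(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.native_group_norm.default:
            # Concrete values such as num_groups are readable here, unlike
            # under symbolic trace where args are proxies.
            num_groups = node.args[6]  # position is an assumption
            ...  # build the ttnn.group_norm subgraph from these values
    gm.recompile()
```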

* Add more test_group_norm

* Add more test_layer_norm

* Remove unused patterns/norm.py

* Refactor, rename, add comments

* Support x.t()

* Move layer_norm impl to customized_replace

* Update for ttnn version 3023ec0f7

* Refactor test case

I don't know why, but if I reuse some code in a test case by wrapping it in a function, it fails when VS Code runs more than 8 cases at the same time. It is really strange.

I currently have no idea how this happens, so for now I refactor the cases to duplicate the code.

* Merge .gitignore

* Resolve conflicts in README.md

* Use incoming test_fall_back.py

* Use our tests/tools/test_stats.py

* Resolve generate_report.py

* Preserve intermediate node meta for composite ops and add checks in respective tests (#50)

Co-authored-by: Artem Yerofieiev <[email protected]>

* Resolve torch_ttnn/__init__.py

* Resolve conflict

* Remove duplicate entries from .gitignore

* Update metadata in setup.py

* Correct the name of the test module in test_datamove.py

* Remove duplicate test for softmax

* Fix merge errors; the binary tests now pass

* Remove test_if.py as `if` is already tested by lowering/misc/test_if.py

* Remove test_only_add_matmul.py, superseded by lowering/matmul/test_only_add_matmul.py

* Test group_norm

* Test torch.matmul -> ttnn.matmul

* Test compiling torch.nn.functional.linear

* Refactor test for CSE

* Remove test_norm.py as we've done its migration

* This test no longer tests falling back to a torch op but division, which should be handled in lowering/eltwise/binary/test_div.py instead

* Convert tests for unary eltwise ops

* Fix loading pytest arguments

* Convert tests for binary eltwise ops

* Fix precision test for bivariate ops

* Fix precision test for univariate ops
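
These precision tests compare eager and compiled results with assert_with_pcc from tests/utils (visible in the diffs below), which checks a Pearson correlation coefficient. A rough sketch of such a helper, not the repo's actual implementation:

```python
import torch

def assert_with_pcc(expected: torch.Tensor, actual: torch.Tensor, pcc: float = 0.99):
    # Flatten both tensors and compare their Pearson correlation coefficient
    # against the required threshold.
    a = expected.flatten().to(torch.float32)
    b = actual.flatten().to(torch.float32)
    corr = torch.corrcoef(torch.stack([a, b]))[0, 1].item()
    assert corr >= pcc, f"PCC {corr:.4f} is below the threshold {pcc}"
```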

* Remove test_pointwise_trinary.py for flexibility

Trivariate functions differ too much to share the same testing protocol

* Test torch.reshape -> ttnn.reshape

* Test compiling aten.repeat

* Test compiling torch.cat

* Remove test_datamove.py because all its tests have been moved to lowering/tensor_manipulation/

* Remove test already covered by test_resnet.py

* Use more descriptive names in torch_ttnn/patterns/add.py

Co-authored-by: Artem Yerofieiev <[email protected]>

* Use more descriptive names in torch_ttnn/patterns/addcdiv.py

Co-authored-by: Artem Yerofieiev <[email protected]>

* Simpler path to resolve ttnn.addcdiv

Co-authored-by: Artem Yerofieiev <[email protected]>

* Make test names unique for easier batch testing with pytest

* Fix import target_wrappers

* Move code in add_coreops_pass.py to add_data_move_pass.py first to help refactoring

* Refactor lowering univariate functions

* Simplify control flow in lowering

* Refactor lowering bivariate functions

* Sort ops to lower

* Lower to ttnn.atan2

* Lower to ttnn.leaky_relu

* Lower default hardtanh to ttnn.clip for now
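
torch's hardtanh defaults to min_val=-1.0 and max_val=1.0, so the default case maps directly onto a clip. A sketch of the node rewrite; the real pass also handles layout and device moves:

```python
import ttnn

def lower_default_hardtanh(node):
    # aten.hardtanh.default(x) clamps to [-1, 1]; express it with ttnn.clip.
    node.target = ttnn.clip
    node.args = (node.args[0], -1.0, 1.0)
```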

* Lower addcdiv and addcmul

* Remove migrated lowering code

* Fix names of ttnn ops

* Remove duplicate entries in the op list

* Remove unused pattern file

* Remove the file that is already merged into add_data_move_pass.py

* Test broadcasting for bivariate ops

* Test broadcasting for all bivariate ops

* Remove intermediate test for `div` so we can test broadcasting

* Regroup ops based on API docs

https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api.html

* Remove ops not working:

- Comparison ops
- Logical ops
- aten.pow.Tensor_Tensor

* Apply @kevinwuTT's patch on tt.repeat at model teardown

* Format the code with `black` for a consistent style

* Reformat with the specified config

* Mark tests xfail based on #64

* Remove test for unsupported pow(tensor, tensor)

* Mark broadcasting issues (#64) with atan2 and xlogy

* Reformat the code

* Mark broadcasting issues (#64) with (min|max)imum

* Mark broadcasting issues (#64) with subtraction

* Mark numerical issues with atan2

* Try setting realistic tolerance for low precision math ops

* Tolerate more error in pointwise unary math ops
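
bfloat16 carries only about three decimal digits of precision, so exact comparison is hopeless for these ops. A sketch of a relaxed check; the rtol/atol values are illustrative, not the suite's:

```python
import torch

def close_enough(expected, actual, rtol=0.05, atol=0.05):
    # Compare in float32 with loose tolerances suited to bfloat16 math ops.
    return torch.allclose(expected.float(), actual.float(), rtol=rtol, atol=atol)
```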

* Reflect that we convert torch.hardtanh to ttnn.clip for now

* Remove test for unsupported group norm

* Mark conversion failure with `linear`

* Fix test_clone.py for the patch

* Mark argument mismatch in `arange` (#65)

* Link #66 to test_linear.py

* Mark lowering issues with tensor manipulation (#67)

* Reciprocal needs an offset because it has a pole at 0
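
A sketch of the offset, assuming the test draws inputs from [0, 1): shifting them away from 0 keeps 1/x bounded (the shapes and offset value are illustrative):

```python
import torch

input_shapes = [(4, 4)]
# Shift inputs away from reciprocal's pole at 0.
inputs = [torch.rand(shape, dtype=torch.bfloat16) + 0.1 for shape in input_shapes]
```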

* More tolerance for matmul due to its accumulated error

* Symmetrically mark xfail for (min|max)imum

* Merge human-made docs from README.md to docs/README.md.in

* Do not use braces for shell variables to avoid clashing with .format

* Generate README.md with new metrics

* Mark xfail for xlogy involving broadcasting

xlogy asserts that its inputs have the same size for now

---------

Co-authored-by: swimdi <[email protected]>
Co-authored-by: yoco <[email protected]>
Co-authored-by: Zahid Wakeel <[email protected]>
Co-authored-by: Artem Yerofieiev <[email protected]>
Co-authored-by: yoco <[email protected]>
6 people authored Aug 22, 2024
1 parent 68a590d commit f87d7b3
Showing 73 changed files with 2,750 additions and 279 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,9 +1,12 @@
__pycache__
dist
torch_ttnn.egg-info
venv
.vscode
stat
*.dot
*.svg
*.csv
*.pt
metrics
data
data
59 changes: 41 additions & 18 deletions README.md
@@ -10,15 +10,15 @@ The table below summarizes the results of running various ML models through our

| Model | Run Success | Torch Ops Before (Unique Ops) | Torch Ops Remain (Unique Ops) | To/From Device Ops | Original Run Time (ms) | Compiled Run Time (ms) | Accuracy (%) |
|:------------------------------------|:--------------|:--------------------------------|:--------------------------------|---------------------:|-------------------------:|:-------------------------|:---------------|
| [Mnist (Eval)](tests/models/mnist) || 14 (8) | 5 (4) | 16 | 38.64 | 501.5 | 99.85 |
| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 136.38 | 2709.01 | 66.84 |
| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 47 | 2131.05 | 9985.44 | 99.99 |
| [Bloom](tests/models/bloom) || 1407 (29) | 626 (11) | 1379 | 28892.3 | 68470.67 | 45.77 |
| [YOLOS](tests/models/yolos) || 964 (28) | 409 (11) | 919 | 1410.28 | 45328.58 | 71.71 |
| [Llama](tests/models/llama) || 5 (4) | 3 (2) | 3 | 206771 | 187910.29 | 45.46 |
| [BERT](tests/models/bert) || 1393 (21) | 539 (5) | 1513 | 67347.3 | 60024.8 | 98.64 |
| [Falcon](tests/models/falcon) || 3 (3) | 2 (2) | 5 | 51366.6 | N/A | N/A |
| [GPT-2](tests/models/gpt2) || 748 (31) | 316 (12) | 569 | 5711.32 | N/A | N/A |
| [Mnist (Eval)](tests/models/mnist) || 14 (8) | 5 (4) | 16 | 35.53 | 556.63 | 99.72 |
| [Mnist (Train)](tests/models/mnist) || 14 (8) | 7 (5) | 14 | 114.16 | 3076.17 | 76.59 |
| [ResNet18](tests/models/resnet) || 70 (9) | 42 (4) | 44 | 2023.95 | 10673.42 | 99.99 |
| [Bloom](tests/models/bloom) || 1407 (29) | 626 (11) | 1378 | 28504 | 68025.6 | 45.77 |
| [YOLOS](tests/models/yolos) || 964 (28) | 320 (11) | 825 | 1340.21 | 46101.1 | 71.71 |
| [Llama](tests/models/llama) || 3 (2) | 2 (2) | 2 | 164063 | 166348.21 | 100.0 |
| [BERT](tests/models/bert) || 1393 (21) | 491 (5) | 1465 | 63591.6 | 55096.44 | 98.64 |
| [Falcon](tests/models/falcon) || 3 (3) | 2 (2) | 5 | 46268.6 | N/A | N/A |
| [GPT-2](tests/models/gpt2) || 748 (31) | 307 (12) | 644 | 1793.52 | N/A | N/A |

### Explanation of Metrics

@@ -135,12 +135,10 @@ The table below summarizes the results of running various ML models through our
| aten.unsqueeze.default || 1 |
| aten.view.default || 283 |
#### Llama
| aten ops | status | count |
|:----------------------|:---------|--------:|
| aten._to_copy.default || 1 |
| aten.mm.default || 1 |
| aten.t.default || 1 |
| aten.view.default || 2 |
| aten ops | status | count |
|:-----------------------|:---------|--------:|
| aten.slice.Tensor || 1 |
| aten.unsqueeze.default || 2 |
#### BERT
| aten ops | status | count |
|:-------------------------------|:---------|--------:|
Expand Down Expand Up @@ -216,7 +214,7 @@ import torch
import torch_ttnn

# A torch Module
class FooModule(torch.Module):
class FooModule(torch.nn.Module):
...
# Create a module
module = FooModule()
@@ -235,7 +233,7 @@ The tracer dump the information of fx graph such as node's op_name and shape.

For example, you can run this script to parse the information
```
PYTHONPATH=$(pwd) python3 tools/run_torchvision.py --backend torch_stat --backward --profile
PYTHONPATH=$(pwd) python3 tools/stat_models.py --trace_orig --backward --profile
ls stat/raw
```

@@ -263,10 +261,35 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribu

[The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing


# For developers

## Install torch-ttnn with editable mode

During development, you may want to use the torch-ttnn package for testing.
In order to do that, you can install the torch-ttnn package in "editable"
mode with

```shell
pip install -e .
```

Now, you can utilize `torch_ttnn` in your Python code. Any modifications you make to the `torch_ttnn` package will take effect immediately, eliminating the need for constant reinstallation via pip.

## Build wheel file

For developers who want to deploy the wheel, you can build the wheel file with

```shell
python -m build
```

Then you can upload the `.whl` file to the PyPI (Python Package Index).

## Run transformer models
To run a transformer model with the ttnn backend, run:
```
PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
PYTHONPATH="$TT_METAL_HOME:$(pwd)" python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
```

You can also substitute the backend with `torch_stat` to run a reference comparison.
31 changes: 28 additions & 3 deletions docs/README.md.in
@@ -27,7 +27,7 @@ import torch
import torch_ttnn

# A torch Module
class FooModule(torch.Module):
class FooModule(torch.nn.Module):
...
# Create a module
module = FooModule()
@@ -46,7 +46,7 @@ The tracer dump the information of fx graph such as node's op_name and shape.

For example, you can run this script to parse the information
```
PYTHONPATH=$(pwd) python3 tools/run_torchvision.py --backend torch_stat --backward --profile
PYTHONPATH=$(pwd) python3 tools/stat_models.py --trace_orig --backward --profile
ls stat/raw
```

@@ -74,10 +74,35 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribu

[The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing


# For developers

## Install torch-ttnn with editable mode

During development, you may want to use the torch-ttnn package for testing.
In order to do that, you can install the torch-ttnn package in "editable"
mode with

```shell
pip install -e .
```

Now, you can utilize `torch_ttnn` in your Python code. Any modifications you make to the `torch_ttnn` package will take effect immediately, eliminating the need for constant reinstallation via pip.

## Build wheel file

For developers who want to deploy the wheel, you can build the wheel file with

```shell
python -m build
```

Then you can upload the `.whl` file to the PyPI (Python Package Index).

## Run transformer models
To run a transformer model with the ttnn backend, run:
```
PYTHONPATH=${{TT_METAL_HOME}}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
PYTHONPATH="$TT_METAL_HOME:$(pwd)" python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
```

You can also substitute the backend with `torch_stat` to run a reference comparison.
13 changes: 13 additions & 0 deletions pyproject.toml
@@ -0,0 +1,13 @@
[build-system]
requires = ["setuptools==68.0.0", "setuptools-scm==7.1.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "torch-ttnn"
authors = [{ name = "Tenstorrent", email = "[email protected]" }]
dependencies = ["torch"]
version = "0.1.0"
description = "PyTorch dynamo backend for Tenstorrent TT-NN framework"
readme = "README.md"
keywords = ["torch", "ttnn"]
requires-python = ">=3.8"
11 changes: 11 additions & 0 deletions setup.py
@@ -0,0 +1,11 @@
from setuptools import setup

setup(
name="torch-ttnn",
version="0.1.0",
description="PyTorch dynamo backend for Tenstorrent TT-NN framework",
license="MIT",
url="https://github.com/tenstorrent/pytorch2.0_ttnn",
packages=["torch_ttnn"],
install_requires=[],
)
2 changes: 2 additions & 0 deletions tests/lowering/creation/test_arange.py
@@ -53,6 +53,7 @@ def test_arange(device, input_shapes):
assert torch.allclose(result_before, result_after)


@pytest.mark.xfail(reason="argument mismatch (#65)")
@pytest.mark.parametrize(
"input_shapes",
[[2, 100]],
@@ -74,6 +75,7 @@ def test_arange_start(device, input_shapes):
assert torch.allclose(result_before, result_after)


@pytest.mark.xfail(reason="argument mismatch (#65)")
@pytest.mark.parametrize(
"input_shapes",
[[4, 100, 3]],
6 changes: 3 additions & 3 deletions tests/lowering/creation/test_clone.py
@@ -40,7 +40,7 @@ def test_clone_from_arg(device, input_shapes):

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.clone) == 1
assert [node.target for node in nodes].count(torch_ttnn.target_wrappers.clone) == 1
# Check inference result
assert torch.allclose(result_before, result_after)

@@ -63,8 +63,8 @@ def test_clone_from_node(device, input_shapes):
# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(ttnn.clone) == 1
clone_arg_0 = nodes[target.index(ttnn.clone)].args[0].target
assert target.count(torch_ttnn.target_wrappers.clone) == 1
clone_arg_0 = nodes[target.index(torch_ttnn.target_wrappers.clone)].args[0].target
assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(
clone_arg_0, ttnn.decorators.Operation
)
22 changes: 15 additions & 7 deletions tests/lowering/eltwise/binary/test_add.py
@@ -8,28 +8,36 @@ class AddModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x + x
def forward(self, x, y):
return x + y


@pytest.mark.parametrize(
"input_shapes",
[
[(4, 4)],
],
(
((32, 32), (32, 32)),
((64,), (32, 64)),
((64, 32), (64, 1)),
pytest.param(
((64, 1), (1, 64)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
),
)
def test_add(device, input_shapes):
m = AddModule()
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
inputs = [torch.randint(1, 5, shape).type(torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
assert 1 == len(option._out_fx_graphs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.add) == 1

# Check inference result
assert torch.allclose(result_before, result_after)
44 changes: 44 additions & 0 deletions tests/lowering/eltwise/binary/test_atan2.py
@@ -0,0 +1,44 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class Atan2Module(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.atan2(x, y)


@pytest.mark.xfail(reason="numerical issues: AssertionError: -0.999[...]")
@pytest.mark.parametrize(
"input_shapes",
(
((4, 4), (4, 4)),
((8, 1), (8, 8)),
pytest.param(
((1, 8), (8, 1)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
),
)
def test_atan2(device, input_shapes):
m = Atan2Module()
inputs = [torch.randint(1, 5, shape).type(torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.atan2) == 1

# Check inference result
assert_with_pcc(result_before, result_after, 0.99)
15 changes: 10 additions & 5 deletions tests/lowering/eltwise/binary/test_div.py
@@ -15,7 +15,15 @@ def forward(self, numerator, denominator):

@pytest.mark.parametrize(
"input_shapes",
[[(4, 4), (4, 4)], [(64, 128), (64, 128)]],
(
((32, 32), (32, 32)),
((64,), (32, 64)),
((64, 32), (64, 1)),
pytest.param(
((64, 1), (1, 64)),
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
),
),
)
def test_div(device, input_shapes):
m = DivModule()
@@ -35,10 +43,7 @@ def test_div(device, input_shapes):
assert target.count(ttnn.mul) == 1
assert target.index(ttnn.reciprocal) < target.index(ttnn.mul)
assert nodes[target.index(ttnn.mul)].args[1].target == ttnn.reciprocal
# Intermediate node meta check if preserved
for node in nodes:
if node.target == ttnn.full or node.target == ttnn.reciprocal:
assert node.meta["val"].size() == input_shapes[0]

# Check inference result
assert_with_pcc(result_before, result_after)

3 changes: 2 additions & 1 deletion tests/lowering/eltwise/binary/test_eq.py
@@ -12,9 +12,10 @@ def forward(self, input1, input2):
return torch.eq(input1, input2)


@pytest.mark.xfail(reason="broadcasting issues (#64)")
@pytest.mark.parametrize(
"input_shapes",
[[(4, 4), (4, 4)]],
(((32, 32), (32, 32)), ((64,), (32, 64)), ((64, 32), (64, 1)), ((64, 1), (1, 64))),
)
def test_eq_tensor(device, input_shapes):
m = EqModule()
37 changes: 37 additions & 0 deletions tests/lowering/eltwise/binary/test_gt.py
@@ -0,0 +1,37 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class GtModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.gt(x, y)


@pytest.mark.xfail(reason="broadcasting issues (#64)")
@pytest.mark.parametrize(
"input_shapes",
(((32, 32), (32, 32)), ((64,), (32, 64)), ((64, 32), (64, 1)), ((64, 1), (1, 64))),
)
def test_gt(device, input_shapes):
m = GtModule()
inputs = [torch.randint(1, 5, shape).type(torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.gt) == 1

# Check inference result
assert torch.allclose(result_before, result_after)