Commit 6c0cb9d

#0: Reduce padding on UNet Shallow input tensor
esmalTT committed Sep 20, 2024
1 parent 2484df6 commit 6c0cb9d
Showing 5 changed files with 26 additions and 24 deletions.
models/experimental/functional_unet/tests/common.py (15 changes: 13 additions & 2 deletions)
@@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import ttnn
+from loguru import logger

from tests.ttnn.utils_for_testing import assert_with_pcc


@@ -20,15 +22,24 @@ def is_t3k_with_eth_dispatch_cores(mesh_device) -> bool:
return all_devices_using_full_grid and (len(mesh_device.get_devices()) == 8)


+def verify_with_pcc(torch_tensor, ttnn_tensor, pcc):
+    _, computed_pcc = assert_with_pcc(torch_tensor, ttnn_tensor, pcc)
+    logger.info(f"PCC check was successful ({computed_pcc:.4f} > {pcc:.4f})")
+    if (computed_pcc - pcc) / pcc > 0.0025:
+        logger.warning(
+            f"Computed PCC ({computed_pcc:.4f}) was higher than the expected PCC ({pcc:.4f}) - consider updating the expected PCC value"
+        )


def check_pcc_conv(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None):
B, C, H, W = torch_tensor.shape
ttnn_tensor = ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, C).permute(0, 3, 1, 2)
-    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)
+    verify_with_pcc(torch_tensor, ttnn_tensor, pcc)


def check_pcc_pool(torch_tensor, ttnn_tensor, pcc=0.999, mesh_composer=None):
B, C, H, W = torch_tensor.shape
ttnn_tensor = (
ttnn.to_torch(ttnn_tensor, mesh_composer=mesh_composer).reshape(B, H, W, -1).permute(0, 3, 1, 2)[:, :C, :, :]
)
-    assert_with_pcc(torch_tensor, ttnn_tensor, pcc)
+    verify_with_pcc(torch_tensor, ttnn_tensor, pcc)
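
The 0.0025 factor gives the expected PCC a 0.25% relative headroom band: a measured PCC that clears the expectation by more than that suggests the threshold is stale. A minimal self-contained sketch of the same rule, with an illustrative compute_pcc helper standing in for the computed PCC that assert_with_pcc returns:

import torch

def compute_pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient over the flattened tensors
    x, y = expected.flatten().float(), actual.flatten().float()
    vx, vy = x - x.mean(), y - y.mean()
    return float((vx * vy).sum() / (vx.norm() * vy.norm()))

torch.manual_seed(0)
reference = torch.randn(2, 4, 66, 10)
candidate = reference + 1e-3 * torch.randn_like(reference)

expected_pcc = 0.999
computed_pcc = compute_pcc(reference, candidate)
assert computed_pcc >= expected_pcc

# Same headroom rule as verify_with_pcc above: flag expectations that
# sit comfortably below what the model actually achieves.
if (computed_pcc - expected_pcc) / expected_pcc > 0.0025:
    print(f"Consider raising the expected PCC to ~{computed_pcc:.4f}")
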
models/experimental/functional_unet/tests/test_unet_model.py (7 changes: 2 additions & 5 deletions)
@@ -5,14 +5,13 @@
import pytest
import ttnn

-from tests.ttnn.utils_for_testing import assert_with_pcc

from models.experimental.functional_unet.tt.model_preprocessing import (
create_unet_input_tensors,
create_unet_model_parameters,
)
from models.experimental.functional_unet.tt import unet_shallow_torch
from models.experimental.functional_unet.tt import unet_shallow_ttnn
+from models.experimental.functional_unet.tests.common import check_pcc_conv


@pytest.mark.parametrize("batch", [2])
@@ -28,6 +27,4 @@ def test_unet_model(batch, groups, device, use_program_cache, reset_seeds):
torch_output_tensor = model(torch_input)
output_tensor = ttnn_model(ttnn_input)

-    B, C, H, W = torch_output_tensor.shape
-    ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
-    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.97)
+    check_pcc_conv(torch_output_tensor, output_tensor, 0.97)
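
check_pcc_conv hides the layout round-trip that the deleted lines did by hand: ttnn returns convolution activations channels-last, so the host tensor is reshaped as NHWC and permuted back to NCHW, slicing off any channel padding first. A self-contained sketch of that conversion, with arbitrary shapes:

import torch

B, C, H, W = 2, 4, 6, 8  # illustrative sizes only
nchw = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# Round-trip through the channels-last layout that ttnn.to_torch yields
nhwc = nchw.permute(0, 2, 3, 1).contiguous()
restored = nhwc.reshape(B, H, W, C).permute(0, 3, 1, 2)
assert torch.equal(nchw, restored)

# When the device pads channels (e.g. up to 16), the padded tail is
# dropped before comparing, as the old inline code did with
# .reshape(B, H, W, -1)[:, :, :, :C]
padded = torch.nn.functional.pad(nhwc, (0, 16 - C))
sliced = padded.reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
assert torch.equal(nchw, sliced)
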
models/experimental/functional_unet/tests/test_unet_trace.py (16 changes: 3 additions & 13 deletions)
@@ -75,9 +75,7 @@ def test_unet_trace(
logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")

logger.info(f"Running sanity check against reference model output")
-    B, C, H, W = torch_output_tensor.shape
-    ttnn_tensor = ttnn.to_torch(outputs[-1]).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
-    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.97)
+    check_pcc_conv(torch_output_tensor, outputs[-1], 0.97)


@skip_for_grayskull("UNet not currently supported on GS")
@@ -184,9 +182,7 @@ def test_unet_trace_2cq(
ttnn.DumpDeviceProfiler(device)

logger.info(f"Running sanity check against reference model output")
-    B, C, H, W = torch_output_tensor.shape
-    ttnn_tensor = ttnn.to_torch(outputs[-1]).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
-    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.97)
+    check_pcc_conv(torch_output_tensor, outputs[-1], 0.97)

ttnn.release_trace(device, tid)

@@ -317,12 +313,6 @@ def test_unet_trace_2cq_multi_device(
logger.info(f"Average model performance={iterations * total_batch / (end-start) : .2f} fps")

logger.info(f"Running sanity check against reference model output")
-    B, C, H, W = torch_output_tensor.shape
-    ttnn_tensor = (
-        ttnn.to_torch(outputs[-1], mesh_composer=output_mesh_composer)
-        .reshape(B, H, W, -1)[:, :, :, :C]
-        .permute(0, 3, 1, 2)
-    )
-    assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.97)
+    check_pcc_conv(torch_output_tensor, outputs[-1], 0.97, mesh_composer=output_mesh_composer)

ttnn.release_trace(mesh_device, tid)
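
The fps figures these trace tests log are plain wall-clock accounting over all enqueued iterations. A trivial sketch of the same arithmetic, with a sleep standing in for one trace execution:

import time

iterations, batch = 32, 2  # illustrative values
start = time.time()
for _ in range(iterations):
    time.sleep(0.01)  # stand-in for executing the captured trace once
end = time.time()
print(f"Average model performance={iterations * batch / (end - start) : .2f} fps")
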
models/experimental/functional_unet/tt/model_preprocessing.py (6 changes: 3 additions & 3 deletions)
@@ -22,9 +22,7 @@ def create_unet_input_tensors(
ttnn_input_tensor.shape[3],
)
if pad_input:
-        # Pad to 16 if grayskull run and 32 for wormhole
-        pad = 32 if device.arch() == ttnn.device.Arch.WORMHOLE_B0 else 16
-        hpad = 0  # 96*32*64
+        pad, hpad = 16, 0
if ttnn_input_tensor.shape[-1] < pad or ttnn_input_tensor.shape[-2] < hpad:
ttnn_input_tensor = torch.nn.functional.pad(
ttnn_input_tensor,
@@ -68,6 +66,8 @@ def create_unet_model_parameters(model: unet_shallow_torch.UNet, input_tensor: t
parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1["use_split_reader"] = True
parameters.c1["use_activation_double_buffer"] = True
parameters.c1["input_channels_alignment"] = 16

parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1_2["use_split_reader"] = True
parameters.c1_2["use_activation_double_buffer"] = True
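
The first hunk is the core of the commit: inputs are now padded to 16 channels on every architecture rather than 32 on Wormhole, matching the input_channels_alignment override of 16 added to the first convolution. A hypothetical sketch of that channels-last padding (pad_input_channels is illustrative, not the repo's API):

import torch

def pad_input_channels(nhwc: torch.Tensor, pad_to: int = 16) -> torch.Tensor:
    # Zero-pad the trailing (channel) dimension up to the alignment target
    channels = nhwc.shape[-1]
    if channels < pad_to:
        nhwc = torch.nn.functional.pad(nhwc, (0, pad_to - channels))
    return nhwc

sample = torch.randn(2, 1056, 160, 4)  # B, H, W, C; shape is illustrative
assert pad_input_channels(sample).shape[-1] == 16
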
models/experimental/functional_unet/tt/unet_shallow_ttnn.py (6 changes: 5 additions & 1 deletion)
@@ -14,6 +14,10 @@
from ttnn.model_preprocessing import fold_batch_norm2d_into_conv2d, ParameterDict


+def nearest_16(x):
+    return math.ceil(x / 16) * 16


def determine_num_cores_for_upsample(nhw: int, width: int, max_cores=64) -> int:
gcd_nhw_width = math.gcd(nhw, width)
cores = nhw // gcd_nhw_width
@@ -514,7 +518,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
self.downblock1.conv1.batch_size
* self.downblock1.conv1.input_height
* self.downblock1.conv1.input_width,
-                nearest_32(self.downblock1.conv1.in_channels),
+                nearest_16(self.downblock1.conv1.in_channels),
],
ttnn.CoreGrid(x=8, y=8),
ttnn.ShardStrategy.HEIGHT,
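
nearest_16 mirrors the module's existing nearest_32 helper, and the height-sharded input memory config now aligns its channel dimension to 16 to match the reduced input padding. For example:

import math

def nearest_16(x):
    return math.ceil(x / 16) * 16

assert nearest_16(4) == 16   # a small channel count aligns up to one 16-wide unit
assert nearest_16(16) == 16
assert nearest_16(17) == 32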
