#12857: rm padding
sjameelTT committed Oct 3, 2024
1 parent 65aa05a commit 32607a3
Showing 5 changed files with 76 additions and 10 deletions.
66 changes: 66 additions & 0 deletions tests/ttnn/unit_tests/test_to_layout.py
@@ -10,6 +10,7 @@
 import ttnn
 
 from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
+import random
 
 
 @pytest.mark.parametrize("height", [32, 30])
@@ -125,3 +126,68 @@ def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_
     passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
     logger.info(pcc_msg)
     assert passing
+
+
+# sharded fails
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_to_layout_rm(device, use_program_cache):
+    torch.manual_seed(0)
+    torch_input = torch.rand([1, 1, 337920, 1], dtype=torch.bfloat16)
+    print("torch input", torch_input)
+    x = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
+    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
+    x = ttnn.to_device(x, device)
+
+    sharded_memory_config = ttnn.create_sharded_memory_config(
+        [1, 1, 337920, 32], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
+    )
+    x = ttnn.to_memory_config(x, sharded_memory_config)
+    print("input_shape:", x.shape)
+    print()
+    print(x)
+
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)  # This fails
+    print("output shape:", x.shape)
+
+    torch_output = ttnn.to_torch(x)
+    assert_with_pcc(torch_input, torch_output)
+
+
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_to_layout_rm_int(device, use_program_cache):
+    torch.manual_seed(0)
+    torch_input = torch.rand([1, 1, 337920, 1], dtype=torch.bfloat16)
+    print("torch input", torch_input)
+    x = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
+    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
+    x = ttnn.to_device(x, device)
+
+    print("input_shape:", x.shape)
+    print()
+    print(x)
+
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)  # This fails
+    print("output shape:", x.shape)
+
+    torch_output = ttnn.to_torch(x)
+    assert_with_pcc(torch_input, torch_output)
+
+
+# native creation of RM tensor
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_to_layout_rm_creation_int(device, use_program_cache):
+    torch.manual_seed(0)
+    torch_input = torch.rand([1, 1, 337920, 1], dtype=torch.bfloat16)
+    print("torch input", torch_input)
+    x = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
+    x = ttnn.to_device(x, device)
+
+    print("input_shape:", x.shape)
+    print()
+    print(x)
+
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)  # This fails
+    print("output shape:", x.shape)
+
+    torch_output = ttnn.to_torch(x)
+    assert_with_pcc(torch_input, torch_output)
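
The [1, 1, 337920, 1] shape used throughout these tests looks arbitrary but appears to hit both sides of the fix: width 1 is odd (previously rejected for row-major bfloat16 on device), and the height divides evenly into tile-height shards on an 8x8 grid. A quick arithmetic check (a sketch, not part of the commit):

# One bfloat16 element is 2 bytes, under the 4-byte uint32_t packing
# granularity that device buffers use, so a width-1 row needs padding.
assert 1 * 2 < 4

# 337920 rows split evenly into tile-height (32) chunks across 8x8 = 64
# cores, so the HEIGHT-sharded config in test_to_layout_rm divides cleanly.
assert 337920 % (8 * 8 * 32) == 0  # 165 tile rows per core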
10 changes: 8 additions & 2 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -143,13 +143,19 @@ Tensor to_layout_impl(
                 tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type};
         }
         std::vector<uint32_t> output_tensor_end;
+        std::vector<uint32_t> padded_shape;
         for (auto index = 0; index < tensor.get_shape().rank(); ++index) {
-            output_tensor_end.push_back(tensor.get_shape()[index] - 1);
+            if (index == tensor.get_shape().rank() - 1) {
+                output_tensor_end.push_back(tensor.get_shape()[index] + tensor.get_shape()[index] % 2 - 1);
+            } else {
+                output_tensor_end.push_back(tensor.get_shape()[index] - 1);
+            }
+            padded_shape.push_back(output_tensor_end[index] + 1);
         }
 
         tensor =
             ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize);
-        return reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape}));
+        return reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_shape}));
 
     } else if (layout == ttnn::TILE_LAYOUT) {
         std::vector<uint32_t> padded_output_shape;
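
The new loop keeps every dimension's full extent (inclusive end index size - 1) but bumps an odd last dimension by one, so the row-major output of untilize_with_unpadding carries a single element of width padding and stays uint32_t-aligned. A Python sketch of the same rule (rm_unpad_ends is a hypothetical helper, not the commit's API):

def rm_unpad_ends(shape):
    # Inclusive end indices for untilize_with_unpadding, plus the padded
    # shape fed to reshape: an odd last dim keeps one extra element so the
    # resulting row-major width is even.
    ends, padded = [], []
    for i, dim in enumerate(shape):
        end = dim + dim % 2 - 1 if i == len(shape) - 1 else dim - 1
        ends.append(end)
        padded.append(end + 1)
    return ends, padded

assert rm_unpad_ends([1, 1, 337920, 1]) == ([0, 0, 337919, 1], [1, 1, 337920, 2])
assert rm_unpad_ends([1, 1, 32, 16]) == ([0, 0, 31, 15], [1, 1, 32, 16])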
@@ -78,6 +78,7 @@ std::vector<tt::tt_metal::LegacyShape> UntilizeWithUnpadding::compute_output_sha
     out_shape.reserve(rank);
     for (uint32_t i = 0; i < rank; i++) {
         out_shape.push_back(this->output_tensor_end[i] + 1);
+        std::cout << "out_shape[" << i << "] = " << out_shape[i] << std::endl;
     }
     tt::tt_metal::LegacyShape output_tensor_shape(out_shape);
     return {output_tensor_shape};
8 changes: 1 addition & 7 deletions ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -52,7 +52,7 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes,
     switch (layout) {
         case Layout::ROW_MAJOR: {
             uint32_t size_of_element = element_size_bytes(dtype);
-            page_size = W * size_of_element;
+            page_size = std::max(W * size_of_element, (uint32_t)sizeof(uint32_t));
         } break;
         case Layout::TILE: {
             // TODO: Update to be generic for data type (issue 462)
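
The std::max clamp is the buffer-side half of the change: a row-major page is one row, but never smaller than the 4-byte uint32_t packing granularity, so a width-1 bfloat16 row gets a valid 4-byte page instead of a 2-byte one. In Python terms (a sketch):

def row_major_page_size(width, element_size_bytes):
    # One page per row, clamped to the 4-byte uint32_t minimum.
    return max(width * element_size_bytes, 4)

assert row_major_page_size(1, 2) == 4    # width-1 bfloat16 row: 2 -> 4 bytes
assert row_major_page_size(16, 2) == 32  # wide rows are unaffected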
@@ -262,12 +262,6 @@ void validate_on_device_dtype_and_layout(Device* device, const tt::tt_metal::Leg
             break;
         case DataType::UINT16:
         case DataType::BFLOAT16:
-            if (layout == Layout::ROW_MAJOR) {
-                TT_ASSERT(
-                    shape[-1] % 2 == 0,
-                    "For ROW_MAJOR layout tensors with dtype BFLOAT16 or UINT16, tensor width must be divisible by "
-                    "2 since data is packed as uint32_t when creating buffers on device!");
-            }
             break;
         case DataType::BFLOAT8_B:
         case DataType::BFLOAT4_B:
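
With this assert removed, odd-width BFLOAT16 and UINT16 row-major tensors can be created on device directly, which is what test_to_layout_rm_creation_int exercises. Roughly (a sketch, assuming an open device handle):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
# Width 1 is not divisible by 2; this used to trip the device-side assert.
torch_input = torch.rand([1, 1, 32, 1], dtype=torch.bfloat16)
x = ttnn.from_torch(torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
ttnn.close_device(device)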
1 change: 0 additions & 1 deletion ttnn/ttnn/operations/core.py
@@ -309,7 +309,6 @@ def to_torch(
         tensor = tensor.to_torch()
         slices = [slice(None, x) for x in shape_without_tile_padding]
         tensor = tensor[slices]
-
     if torch_rank is not None:
         while len(tensor.shape) != torch_rank:
             if tensor.shape[0] != 1:
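
For context, the surrounding to_torch code already strips tile and width padding by slicing each dimension back to the logical shape, which is how the extra padded element introduced by the reshape above disappears on the way back to torch. A standalone illustration of that slicing (a sketch):

import torch

padded = torch.zeros([1, 1, 32, 2])         # width padded 1 -> 2 on device
shape_without_tile_padding = [1, 1, 32, 1]  # logical shape
slices = tuple(slice(None, x) for x in shape_without_tile_padding)
assert list(padded[slices].shape) == shape_without_tile_padding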
