#12857: add some tests and make the allocator more generic
sjameelTT committed Oct 3, 2024
1 parent 0d07586 commit 546cd5b
Showing 4 changed files with 31 additions and 6 deletions.
31 changes: 28 additions & 3 deletions tests/ttnn/unit_tests/test_to_layout.py
@@ -130,7 +130,7 @@ def test_untilize_with_unpadding_W_16(device, in_dtype, use_multicore, use_pack_

# sharded fails
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
-def test_to_layout_rm(device, use_program_cache):
+def test_to_layout_rm_sharded(device, use_program_cache):
    torch.manual_seed(0)
    torch_input = torch.rand([1, 1, 337920, 1], dtype=torch.bfloat16)
    print("torch input", torch_input)
@@ -166,7 +166,7 @@ def test_to_layout_rm_int(device, use_program_cache):
    print()
    print(x)

-    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) # This fails
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
    print("output shape:", x.shape)

    torch_output = ttnn.to_torch(x)
@@ -186,7 +186,32 @@ def test_to_layout_rm_creation_int(device, use_program_cache):
    print()
    print(x)

-    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) # This fails
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
    print("output shape:", x.shape)

    torch_output = ttnn.to_torch(x)
    assert_with_pcc(torch_input, torch_output)


+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_to_layout_rm_s2i(device, use_program_cache):
+    torch.manual_seed(0)
+    torch_input = torch.rand([1, 1, 337920, 1], dtype=torch.bfloat16)
+    print("torch input", torch_input)
+    x = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
+    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
+    x = ttnn.to_device(x, device)
+
+    sharded_memory_config = ttnn.create_sharded_memory_config(
+        [1, 1, 337920, 32], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
+    )
+    x = ttnn.to_memory_config(x, sharded_memory_config)
+    print("input_shape:", x.shape)
+    print()
+    print(x)
+
+    x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG)
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
+    print("output shape:", x.shape)
+
+    torch_output = ttnn.to_torch(x)
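To exercise only the new sharded-to-interleaved case added above (assuming the repository's usual pytest setup and a machine with a Tenstorrent device attached), pytest's -k filter can select it, e.g. pytest tests/ttnn/unit_tests/test_to_layout.py -k test_to_layout_rm_s2i.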
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -146,7 +146,8 @@ Tensor to_layout_impl(
        std::vector<uint32_t> padded_shape;
        for (auto index = 0; index < tensor.get_shape().rank(); ++index) {
            if (index == tensor.get_shape().rank() - 1) {
-                output_tensor_end.push_back(tensor.get_shape()[index] + tensor.get_shape()[index] % 2);
+                uint32_t round_to_4B = ((tensor.get_shape()[index] * tensor.element_size()) % sizeof(uint32_t)) / tensor.element_size();
+                output_tensor_end.push_back(tensor.get_shape()[index] + round_to_4B - 1);
            } else {
                output_tensor_end.push_back(tensor.get_shape()[index] - 1);
            }
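To make the intent of the to_layout change above concrete: instead of only padding odd widths of 2-byte elements, the padding of the last dimension is now computed from the element size in bytes, aiming for rows that end on a 4-byte boundary (the "more generic" behaviour named in the commit title). The following is a minimal, self-contained sketch of that round-up idiom, not the committed code; the helper name and the standalone main() are made up for illustration:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper (illustration only): pad a row of `width` elements,
    // each `element_size` bytes, so the padded row occupies a whole number of
    // 4-byte words, and return the padded width in elements.
    static uint32_t pad_width_to_4_bytes(uint32_t width, uint32_t element_size) {
        constexpr uint32_t kAlign = sizeof(uint32_t);                      // 4 bytes
        uint32_t row_bytes = width * element_size;                         // unpadded row size
        uint32_t padded_bytes = (row_bytes + kAlign - 1) & ~(kAlign - 1);  // round up to a multiple of 4
        return padded_bytes / element_size;                                // back to an element count
    }

    int main() {
        // bfloat16 rows (2-byte elements): width 1 pads to 2 elements (4 bytes),
        // width 16 is already aligned and stays 16.
        std::printf("%u %u\n", pad_width_to_4_bytes(1, 2), pad_width_to_4_bytes(16, 2));
        return 0;
    }

The same round-up-to-4-bytes idiom appears directly in the get_page_size change in ttnn/cpp/ttnn/tensor/tensor_impl.cpp below.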
@@ -78,7 +78,6 @@ std::vector<tt::tt_metal::LegacyShape> UntilizeWithUnpadding::compute_output_sha
    out_shape.reserve(rank);
    for (uint32_t i = 0; i < rank; i++) {
        out_shape.push_back(this->output_tensor_end[i] + 1);
-        std::cout << "out_shape[" << i << "] = " << out_shape[i] << std::endl;
    }
    tt::tt_metal::LegacyShape output_tensor_shape(out_shape);
    return {output_tensor_shape};
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -52,7 +52,7 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes,
    switch (layout) {
        case Layout::ROW_MAJOR: {
            uint32_t size_of_element = element_size_bytes(dtype);
-            page_size = std::max(W * size_of_element, (uint32_t)sizeof(uint32_t));
+            page_size = (W * size_of_element + sizeof(uint32_t) - 1) & ~(sizeof(uint32_t) - 1);
        } break;
        case Layout::TILE: {
            // TODO: Update to be generic for data type (issue 462)
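For context on the get_page_size change: the old expression, std::max(W * size_of_element, (uint32_t)sizeof(uint32_t)), only enforced a minimum page size of 4 bytes, whereas the new expression rounds the row size up to the next multiple of 4 bytes. With a 2-byte dtype and W = 1 both give 4 bytes (old: max(2, 4) = 4; new: (2 + 3) & ~3 = 4), but with W = 3 the old code returns 6 bytes, which is not 4-byte aligned, while the new code returns (6 + 3) & ~3 = 8.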

