Skip to content

Commit

Permalink
#12615: Add queue_id output tensor to slice op, concat_bw (#12718)
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW authored Sep 20, 2024
1 parent 5bb0b10 commit 927ff4b
Show file tree
Hide file tree
Showing 14 changed files with 304 additions and 78 deletions.
127 changes: 127 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,130 @@ def test_bw_concat_Default(input_shapes, input_shapes_2, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
    "input_shapes, input_shapes_2",
    (
        ((torch.Size([12, 1, 30, 32])), (torch.Size([2, 1, 30, 32]))),
        ((torch.Size([4, 1, 32, 32])), (torch.Size([5, 1, 32, 32]))),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_concat_Default_with_output(input_shapes, input_shapes_2, device, are_required_outputs):
    """concat_bw without an explicit dimension, writing into preallocated gradients.

    For every flag set in ``are_required_outputs`` a zero-filled gradient tensor
    is created on the device up front and handed to ``ttnn.concat_bw``. The test
    then checks that the number of buffer pages reported by ttnn is unchanged by
    the call, and that each requested gradient matches the golden function.
    """
    # Keep the generation order (input, other, grad) identical across runs.
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True, True)
    other_data, other_tensor = data_gen_with_range(input_shapes_2, -100, 100, device, True, True)

    concat_shape = torch.concat((in_data, other_data)).shape
    grad_data, grad_tensor = data_gen_with_range(concat_shape, -100, 100, device, True, True)

    def _zero_grad_on_device(shape):
        # Row-major bfloat16 zeros placed in L1; used as the op's output buffer.
        return ttnn.from_torch(
            torch.zeros(shape, dtype=torch.bfloat16),
            ttnn.bfloat16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )

    # A gradient that is not requested stays None and is passed through as such.
    input_grad = _zero_grad_on_device(input_shapes) if are_required_outputs[0] else None
    other_grad = _zero_grad_on_device(input_shapes_2) if are_required_outputs[1] else None

    cq_id = 0

    # Snapshot taken after preallocation so only the op itself is measured.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.concat_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_a_grad=input_grad,
        input_b_grad=other_grad,
        queue_id=cq_id,
    )
    pages_after = ttnn._ttnn.reports.get_buffer_pages()
    assert len(pages_before) == len(pages_after)

    produced = [input_grad, other_grad]
    expected = ttnn.get_golden_function(ttnn.concat_bw)(grad_data, in_data, other_data)

    # Run every requested comparison before asserting, so none is skipped.
    comparisons = []
    for idx, required in enumerate(are_required_outputs):
        if required:
            comparisons.append(compare_pcc([produced[idx]], [expected[idx]]))
    assert all(comparisons)


@pytest.mark.parametrize(
    "input_shapes, input_shapes_2, dimension",
    (
        ((torch.Size([12, 1, 30, 32])), (torch.Size([2, 1, 30, 32])), 0),
        ((torch.Size([1, 2, 45, 64])), (torch.Size([1, 1, 45, 64])), 1),
        ((torch.Size([1, 1, 125, 32])), (torch.Size([1, 1, 32, 32])), 2),
        (
            (torch.Size([1, 1, 64, 80])),
            (torch.Size([1, 1, 64, 16])),
            3,
        ),  # size must be divisible by sizeof(uint32_t) because buffers hold uint32_t values
        # Tile shape
        ((torch.Size([4, 1, 32, 32])), (torch.Size([5, 1, 32, 32])), 0),
        ((torch.Size([1, 2, 64, 64])), (torch.Size([1, 1, 64, 64])), 1),
        ((torch.Size([1, 1, 64, 32])), (torch.Size([1, 1, 32, 32])), 2),
        ((torch.Size([1, 1, 64, 64])), (torch.Size([1, 1, 64, 32])), 3),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_concat_with_output(input_shapes, input_shapes_2, dimension, device, are_required_outputs):
    """concat_bw along ``dimension``, writing into preallocated gradients.

    For every flag set in ``are_required_outputs`` a zero-filled gradient tensor
    is created on the device up front and handed to ``ttnn.concat_bw``. The test
    then checks that the number of buffer pages reported by ttnn is unchanged by
    the call, and that each requested gradient matches the golden function.
    """
    # Keep the generation order (input, other, grad) identical across runs.
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True, True)
    other_data, other_tensor = data_gen_with_range(input_shapes_2, -100, 100, device, True, True)

    concat_shape = torch.concat((in_data, other_data), dim=dimension).shape
    grad_data, grad_tensor = data_gen_with_range(concat_shape, -100, 100, device, True, True)

    def _zero_grad_on_device(shape):
        # Row-major bfloat16 zeros placed in L1; used as the op's output buffer.
        return ttnn.from_torch(
            torch.zeros(shape, dtype=torch.bfloat16),
            ttnn.bfloat16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )

    # A gradient that is not requested stays None and is passed through as such.
    input_grad = _zero_grad_on_device(input_shapes) if are_required_outputs[0] else None
    other_grad = _zero_grad_on_device(input_shapes_2) if are_required_outputs[1] else None

    cq_id = 0

    # Snapshot taken after preallocation so only the op itself is measured.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.concat_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        dimension,
        are_required_outputs=are_required_outputs,
        input_a_grad=input_grad,
        input_b_grad=other_grad,
        queue_id=cq_id,
    )
    pages_after = ttnn._ttnn.reports.get_buffer_pages()
    assert len(pages_before) == len(pages_after)

    produced = [input_grad, other_grad]
    expected = ttnn.get_golden_function(ttnn.concat_bw)(grad_data, in_data, other_data, dimension)

    # Run every requested comparison before asserting, so none is skipped.
    comparisons = []
    for idx, required in enumerate(are_required_outputs):
        if required:
            comparisons.append(compare_pcc([produced[idx]], [expected[idx]]))
    assert all(comparisons)
34 changes: 34 additions & 0 deletions tests/ttnn/unit_tests/operations/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,37 @@ def test_ttnn_slice_bert(input_shape, input_start, input_ends, layout, device):

ttnn_output = ttnn.to_torch(ttnn_output)
assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_output_tensor_rm(device):
    """Slice into a preallocated output tensor (default layout of ``from_torch``).

    The op is invoked without a ``step``, so the reference is the contiguous
    leading block of the input. Random input is used so that reading from the
    wrong offsets (or applying a stride) actually changes the result; with a
    constant input any slice of the right shape would compare equal.
    """
    torch_input = torch.randn(1, 3, 640, 640, dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16)

    # Preallocated destination; its shape must match what the slice produces.
    torch_zeros = torch.zeros(1, 3, 320, 320)
    ttnn_output = ttnn.from_torch(torch_zeros, device=device, dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG)

    # slice_end below is treated as inclusive (end 319 -> extent 320, matching
    # the [1, 3, 320, 320] output buffer), so the golden is [0:1, 0:3, 0:320, 0:320].
    torch_output = torch_input[:, :, :320, :320]

    # The op must write into ttnn_output rather than allocate a new buffer.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.slice(ttnn_input, [0, 0, 0, 0], [0, 2, 319, 319], output_tensor=ttnn_output)
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

    ttnn_output = ttnn.to_torch(ttnn_output)

    assert_with_pcc(torch_output, ttnn_output, 0.99)


def test_slice_output_tensor_tile(device):
    """Slice into a preallocated output tensor, both tensors in TILE layout.

    The op is invoked without a ``step``, so the reference is the contiguous
    leading block of the input. Random input is used so that reading from the
    wrong offsets (or applying a stride) actually changes the result; with a
    constant input any slice of the right shape would compare equal.
    """
    torch_input = torch.randn(1, 3, 640, 640, dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

    # Preallocated destination; its shape must match what the slice produces.
    torch_zeros = torch.zeros(1, 3, 320, 320)
    ttnn_output = ttnn.from_torch(
        torch_zeros, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
    )

    # slice_end below is treated as inclusive (end 319 -> extent 320, matching
    # the [1, 3, 320, 320] output buffer), so the golden is [0:1, 0:3, 0:320, 0:320].
    torch_output = torch_input[:, :, :320, :320]

    # The op must write into ttnn_output rather than allocate a new buffer.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.slice(ttnn_input, [0, 0, 0, 0], [0, 2, 319, 319], output_tensor=ttnn_output)
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

    ttnn_output = ttnn.to_torch(ttnn_output)

    assert_with_pcc(torch_output, ttnn_output, 0.99)
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ void SliceDeviceOperation::validate_with_output_tensors(
// Check if start shape is <= end shape
TT_FATAL(this->slice_start[i] <= this->slice_end[i], "Error");
}

if(!output_tensors.empty() && output_tensors[0].has_value()){
const auto output_shape_required = this->compute_output_shapes(input_tensors)[0];
const auto& out_tensor = output_tensors[0].value();
TT_FATAL(out_tensor.get_legacy_shape() == output_shape_required, "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape());
}
auto output_tensor_shape = this->compute_output_shapes(input_tensors)[0];
if (step.has_value()) { // if all ones modify before passing in to function
TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Strided slice is only supported for row major layout");
Expand Down Expand Up @@ -139,6 +143,9 @@ std::vector<tt::tt_metal::LegacyShape> SliceDeviceOperation::compute_output_shap

std::vector<Tensor> SliceDeviceOperation::create_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const {
if (!output_tensors.empty() && output_tensors[0].has_value()) {
return {output_tensors[0].value()};
}
const auto &input_tensor_a = input_tensors.at(0);
const auto shapes = compute_output_shapes(input_tensors);

Expand Down
32 changes: 20 additions & 12 deletions ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "device/slice_op.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/common/constants.hpp"


namespace ttnn::operations::data_movement {
Expand All @@ -18,7 +19,8 @@ ttnn::Tensor SliceOperation::invoke(
tt::tt_metal::LegacyShape output_tensor_start,
tt::tt_metal::LegacyShape output_tensor_end,
const std::optional<tt::tt_metal::LegacyShape> step,
const std::optional<MemoryConfig>& memory_config_arg) {
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<Tensor>& optional_output_tensor) {
std::optional<tt::tt_metal::LegacyShape> modified_step = step;
if (modified_step.has_value()) {
if (std::all_of(modified_step->begin(), modified_step->end(), [](int32_t s) { return s == 1; })) {
Expand All @@ -41,7 +43,7 @@ ttnn::Tensor SliceOperation::invoke(
}
}
else {
auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());
auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config());
// TODO: Generalize this early exit of slice for other cases
auto& input_tensor_shape = input_tensor.get_legacy_shape();
if (input_tensor.is_sharded() && input_tensor.memory_config() == memory_config &&
Expand Down Expand Up @@ -73,7 +75,7 @@ ttnn::Tensor SliceOperation::invoke(
}

return operation::run(
SliceDeviceOperation{output_tensor_start, output_tensor_end, modified_step, memory_config}, {input_tensor}, {}, {}, queue_id)
SliceDeviceOperation{output_tensor_start, output_tensor_end, modified_step, memory_config}, {input_tensor}, {}, {optional_output_tensor}, queue_id)
.at(0);

}
Expand All @@ -84,8 +86,9 @@ ttnn::Tensor SliceOperation::invoke(
tt::tt_metal::LegacyShape output_tensor_start,
tt::tt_metal::LegacyShape output_tensor_end,
const std::optional<tt::tt_metal::LegacyShape> step,
const std::optional<MemoryConfig>& memory_config_arg) {
return invoke(0, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg);
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<Tensor>& optional_output_tensor) {
return invoke(DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg, optional_output_tensor);
}

ttnn::Tensor SliceOperation::invoke(
Expand All @@ -94,14 +97,16 @@ ttnn::Tensor SliceOperation::invoke(
tt::tt_metal::Array1D output_tensor_start,
tt::tt_metal::Array1D output_tensor_end,
const std::optional<tt::tt_metal::Array1D> step,
const std::optional<MemoryConfig>& memory_config_arg) {
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<Tensor>& optional_output_tensor) {
return invoke(
queue_id,
input_tensor,
tt::tt_metal::LegacyShape(output_tensor_start),
tt::tt_metal::LegacyShape(output_tensor_end),
step.has_value() ? std::optional<tt::tt_metal::LegacyShape>(tt::tt_metal::LegacyShape(step.value())) : std::nullopt,
memory_config_arg);
memory_config_arg,
optional_output_tensor);
}

ttnn::Tensor SliceOperation::invoke(
Expand All @@ -110,31 +115,34 @@ ttnn::Tensor SliceOperation::invoke(
tt::tt_metal::Array4D output_tensor_start,
tt::tt_metal::Array4D output_tensor_end,
const std::optional<tt::tt_metal::Array4D> step,
const std::optional<MemoryConfig>& memory_config_arg) {
const std::optional<MemoryConfig>& memory_config_arg,
const std::optional<Tensor>& optional_output_tensor) {
return invoke(
queue_id,
input_tensor,
tt::tt_metal::LegacyShape(output_tensor_start),
tt::tt_metal::LegacyShape(output_tensor_end),
step.has_value() ? std::optional<tt::tt_metal::LegacyShape>(tt::tt_metal::LegacyShape(step.value())) : std::nullopt,
memory_config_arg);
memory_config_arg,
optional_output_tensor);
}

// Array4D overload without an explicit command queue.
// Forwards every argument, including the optional preallocated output tensor,
// to the queue-aware overload using DefaultQueueId.
ttnn::Tensor SliceOperation::invoke(
    const ttnn::Tensor& input_tensor,
    tt::tt_metal::Array4D output_tensor_start,
    tt::tt_metal::Array4D output_tensor_end,
    const std::optional<tt::tt_metal::Array4D> step,
    const std::optional<MemoryConfig>& memory_config_arg,
    const std::optional<Tensor>& optional_output_tensor) {
    return invoke(
        DefaultQueueId,
        input_tensor,
        output_tensor_start,
        output_tensor_end,
        step,
        memory_config_arg,
        optional_output_tensor);
}

// Minimal Array4D overload: no queue id, no memory-config override and no
// preallocated output. Both optionals are passed explicitly as std::nullopt to
// the queue-aware overload on DefaultQueueId.
ttnn::Tensor SliceOperation::invoke(
    const ttnn::Tensor& input_tensor,
    tt::tt_metal::Array4D output_tensor_start,
    tt::tt_metal::Array4D output_tensor_end,
    const std::optional<tt::tt_metal::Array4D> step) {
    return invoke(
        DefaultQueueId,
        input_tensor,
        output_tensor_start,
        output_tensor_end,
        step,
        std::nullopt,   // memory_config_arg
        std::nullopt);  // optional_output_tensor
}

} // namespace operations
25 changes: 15 additions & 10 deletions ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,43 @@ struct SliceOperation {
const ttnn::Tensor& input_tensor,
tt::tt_metal::LegacyShape output_tensor_start,
tt::tt_metal::LegacyShape output_tensor_end,
const std::optional<tt::tt_metal::LegacyShape> step,
const std::optional<MemoryConfig>& memory_config_arg);
const std::optional<tt::tt_metal::LegacyShape> step = std::nullopt,
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
tt::tt_metal::LegacyShape output_tensor_start,
tt::tt_metal::LegacyShape output_tensor_end,
const std::optional<tt::tt_metal::LegacyShape> step,
const std::optional<MemoryConfig>& memory_config_arg);
const std::optional<tt::tt_metal::LegacyShape> step = std::nullopt,
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static ttnn::Tensor invoke(
uint8_t queue_id,
const ttnn::Tensor& input_tensor,
tt::tt_metal::Array1D output_tensor_start,
tt::tt_metal::Array1D output_tensor_end,
const std::optional<tt::tt_metal::Array1D> step,
const std::optional<MemoryConfig>& memory_config_arg);
const std::optional<tt::tt_metal::Array1D> step = std::nullopt,
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static ttnn::Tensor invoke(
uint8_t queue_id,
const ttnn::Tensor& input_tensor,
tt::tt_metal::Array4D output_tensor_start,
tt::tt_metal::Array4D output_tensor_end,
const std::optional<tt::tt_metal::Array4D> step,
const std::optional<MemoryConfig>& memory_config_arg);
const std::optional<tt::tt_metal::Array4D> step = std::nullopt,
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
tt::tt_metal::Array4D output_tensor_start,
tt::tt_metal::Array4D output_tensor_end,
const std::optional<tt::tt_metal::Array4D> step,
const std::optional<MemoryConfig>& memory_config_arg);
const std::optional<tt::tt_metal::Array4D> step = std::nullopt,
const std::optional<MemoryConfig>& memory_config_arg = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,17 @@ void bind_slice(py::module& module) {
const tt::tt_metal::Array4D & slice_end,
const std::optional<tt::tt_metal::Array4D> &step,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<Tensor>& optional_output_tensor,
uint8_t queue_id) {
return self(queue_id, input_tensor, slice_start, slice_end, step, memory_config);
return self(queue_id, input_tensor, slice_start, slice_end, step, memory_config, optional_output_tensor);
},
py::arg("input_tensor"),
py::arg("slice_start"),
py::arg("slice_end"),
py::arg("step") = std::nullopt,
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = 0,
}
);
Expand Down
Loading

0 comments on commit 927ff4b

Please sign in to comment.