#12867: Cleanup 9 Unary Backward ops (#12920)
* #12867: Cleanup ExecuteUnaryBackwardFloatWithDefault

* #12867: Cleanup repeat_bw

* #12867: Cleanup rdiv_bw

* #12920: Update test files and golden function to check default value
VirdhatchaniKN authored Sep 22, 2024
1 parent 16e9f89 commit 31032c7
Showing 10 changed files with 196 additions and 155 deletions.
21 changes: 21 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
@@ -29,3 +29,24 @@ def test_bw_celu(input_shapes, device):
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_celu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.celu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.celu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
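
For reference, the default-value path exercised by test_bw_celu_default can be reproduced in plain PyTorch. This is a minimal sketch, assuming the registered golden function simply differentiates torch.nn.functional.celu with its default alpha of 1.0 when no alpha is supplied; the helper name reference_celu_bw_default is illustrative only.

import torch

def reference_celu_bw_default(grad, x):
    # Hypothetical reference: let torch autograd differentiate F.celu
    # with its default alpha (1.0), mirroring ttnn.celu_bw called without alpha.
    x = x.detach().clone().requires_grad_(True)
    y = torch.nn.functional.celu(x)  # alpha defaults to 1.0
    y.backward(gradient=grad)
    return [x.grad]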
20 changes: 20 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
@@ -36,3 +36,23 @@ def test_bw_elu(input_shapes, alpha, device):
golden_tensor = golden_function(grad_data, in_data, alpha)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_elu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)

tt_output_tensor_on_device = ttnn.elu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.elu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
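
The expected gradient here is the standard ELU backward. A rough reference, assuming the default alpha of 1.0 that applies when the parameter is omitted (the formula matches the comment on the C++ implementation further down):

import torch

def reference_elu_bw(grad, x, alpha=1.0):
    # grad * where(x >= 0, 1, alpha * exp(x)) -- ELU backward
    return grad * torch.where(x >= 0, torch.ones_like(x), alpha * torch.exp(x))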
@@ -28,3 +28,24 @@ def test_bw_hardshrink(input_shapes, lambd, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_hardshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

tt_output_tensor_on_device = ttnn.hardshrink_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.hardshrink_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
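
Similarly, the hardshrink gradient passes the upstream gradient through wherever |x| exceeds the threshold and is zero elsewhere. A sketch under the assumption that lambd defaults to 0.5:

import torch

def reference_hardshrink_bw(grad, x, lambd=0.5):
    # Gradient flows where |x| > lambd; zero inside the shrink band.
    return torch.where(x.abs() > lambd, grad, torch.zeros_like(grad))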
@@ -27,3 +27,23 @@ def test_bw_leaky_relu(input_shapes, negative_slope, device):
golden_tensor = golden_function(grad_data, in_data, negative_slope)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_leaky_relu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.leaky_relu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.leaky_relu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
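
The leaky ReLU backward follows the rule quoted in the C++ implementation below: the gradient passes through for positive inputs and is scaled by the slope otherwise. A sketch assuming the default negative_slope of 0.01:

import torch

def reference_leaky_relu_bw(grad, x, negative_slope=0.01):
    # where(x > 0, grad, grad * negative_slope) -- leaky ReLU backward
    return torch.where(x > 0, grad, grad * negative_slope)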
@@ -32,3 +32,21 @@ def test_bw_logiteps(input_shapes, eps, device):
golden_tensor = golden_function(grad_data, in_data, eps)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_logiteps_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -2, 2, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
tt_output_tensor_on_device = ttnn.logiteps_bw(grad_tensor, input_tensor)
golden_function = ttnn.get_golden_function(ttnn.logiteps_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
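
For logiteps the in-range gradient is grad / (x * (1 - x)); handling of values outside [eps, 1 - eps] is assumed here to follow the commented formula in the implementation below (NaN fill). A sketch assuming eps defaults to 0.0:

import torch

def reference_logiteps_bw(grad, x, eps=0.0):
    # In-range logit gradient; out-of-range entries become NaN,
    # per the reference comment in the C++ implementation.
    lo, hi = eps, 1.0 - eps
    inner = grad / (x * (1.0 - x))
    return torch.where((x < lo) | (x > hi), torch.full_like(x, float("nan")), inner)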
@@ -35,3 +35,28 @@ def test_bw_softshrink(input_shapes, lambd, device):

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_softshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
in_data.retain_grad()

pyt_y = torch.nn.functional.softshrink(in_data)

tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)

pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
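
The autograd-based golden above should agree with the closed-form rule quoted in the C++ implementation: the gradient passes through outside the shrink band and is zero inside it. A sketch assuming the default lambd of 0.5:

import torch

def reference_softshrink_bw(grad, x, lambd=0.5):
    # where(x < -lambd, grad, where(x > lambd, grad, 0)) -- softshrink backward
    zero = torch.zeros_like(grad)
    return torch.where(x < -lambd, grad, torch.where(x > lambd, grad, zero))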
@@ -101,7 +101,7 @@ std::vector<Tensor> ExecuteUnaryBackwardSoftplus::invoke(
return grad_tensor;
}

std::vector<Tensor> _rdiv_bw(
std::vector<Tensor> ExecuteUnaryBackwardRdiv::invoke(
const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
TT_FATAL((round_mode == "None" || round_mode == "trunc" || round_mode == "floor"), "Incorrect rounding mode (expected 'None', 'trunc', or 'floor')");
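
For the plain "None" rounding mode, rdiv computes scalar / input, so the upstream gradient is scaled by d(scalar/x)/dx = -scalar / x^2. A rough torch reference covering only that mode:

import torch

def reference_rdiv_bw_none(grad, x, scalar):
    # Backward of y = scalar / x when round_mode == "None":
    # dy/dx = -scalar / x**2, applied to the incoming gradient.
    return grad * (-scalar) / (x * x)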
@@ -591,7 +591,7 @@ std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const st
return grad_tensor;
}

std::vector<Tensor> _hardshrink_bw(
std::vector<Tensor> ExecuteUnaryBackwardHardshrink::invoke(
const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor hardshrink_result = ttnn::hardshrink(input_tensor, lambd, output_mem_config);
@@ -603,7 +603,7 @@ std::vector<Tensor> _hardshrink_bw(

// softshrink
// result: torch.where(self < -lambd, grad, torch.where(self > lambd, grad, torch.tensor(0.0)))
std::vector<Tensor> _softshrink_bw(
std::vector<Tensor> ExecuteUnaryBackwardSoftshrink::invoke(
const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::where(
@@ -622,7 +622,7 @@ std::vector<Tensor> _softshrink_bw(

// Leaky_Relu
// result: torch.where(self > 0, grad_output, grad_output * negative_slope)
std::vector<Tensor> _leaky_relu_bw(
std::vector<Tensor> ExecuteUnaryBackwardLeakyRelu::invoke(
const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = where(
@@ -634,7 +634,7 @@ std::vector<Tensor> _leaky_relu_bw(

// ELU
// result : grad * (torch.where(input >= 0, 1, alpha * torch.exp(input)))
std::vector<Tensor> _elu_bw(
std::vector<Tensor> ExecuteUnaryBackwardElu::invoke(
const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = where(
@@ -649,7 +649,7 @@ std::vector<Tensor> _elu_bw(

// Celu
// result: torch.where((input > 0), grad, grad * torch.exp(input / alpha))
std::vector<Tensor> _celu_bw(
std::vector<Tensor> ExecuteUnaryBackwardCelu::invoke(
const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor div_result = ttnn::multiply(
@@ -1074,7 +1074,7 @@ std::vector<Tensor> _cosh_bw(const Tensor& grad, const Tensor& input, const std:
// # grad_output / (self * (1.0 - self)),
// # self.new_full((), float("nan")),
// # )
std::vector<Tensor> _logiteps_bw(
std::vector<Tensor> ExecuteUnaryBackwardLogiteps::invoke(
const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
float low, high;
@@ -1414,7 +1414,7 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config, input_grad);
}

std::vector<Tensor> _repeat_bw(
std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
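
Conceptually, the backward of repeat sums the incoming gradient over every repeated copy back onto the original shape. A small torch illustration of that property (not the ttnn implementation itself):

import torch

x = torch.randn(1, 1, 32, 32, requires_grad=True)
y = x.repeat(1, 3, 1, 1)          # repeat along the channel dimension
y.backward(torch.ones_like(y))    # upstream gradient of ones
# Each element of x.grad accumulates its 3 repeated copies.
assert torch.allclose(x.grad, torch.full_like(x, 3.0))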
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@ namespace ttnn::operations::unary_backward {

enum class UnaryBackwardOpType {
DIV_BW,
RDIV_BW,
MULTIGAMMALN_BW,
ADD_BW,
EQ_BW,
@@ -36,11 +35,6 @@ enum class UnaryBackwardOpType {
SIGMOID_BW,
RELU_BW,
LOGIT_BW,
HARDSHRINK_BW,
SOFTSHRINK_BW,
LEAKY_RELU_BW,
ELU_BW,
CELU_BW,
RPOW_BW,
FLOOR_BW,
ROUND_BW,
@@ -61,7 +55,6 @@ enum class UnaryBackwardOpType {
CEIL_BW,
SOFTSIGN_BW,
COSH_BW,
LOGITEPS_BW,
LOG2_BW,
SIGN_BW,
DIV_NO_NAN_BW,
@@ -73,7 +66,6 @@ enum class UnaryBackwardOpType {
ERF_BW,
DEG2RAD_BW,
POLYGAMMA_BW,
REPEAT_BW,
};

std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
@@ -132,16 +124,6 @@ std::vector<Tensor> _log_bw( const Tensor& grad, const Tensor& input, const std:
std::vector<Tensor> _add_bw( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _hardshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _softshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _leaky_relu_bw( const Tensor& grad, const Tensor& input, float negative_slope = 0.01, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _elu_bw( const Tensor& grad, const Tensor& input, float alpha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _celu_bw( const Tensor& grad, const Tensor& input, float aplha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _logiteps_bw( const Tensor& grad, const Tensor& input, float eps = 0.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);

// OpHandler struct template
@@ -414,48 +396,6 @@ struct OpHandler<UnaryBackwardOpType::POLYGAMMA_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::HARDSHRINK_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
return _hardshrink_bw(grad, input, lambd, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::SOFTSHRINK_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
return _softshrink_bw(grad, input, lambd, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::LEAKY_RELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config ) {
return _leaky_relu_bw(grad, input, negative_slope, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::ELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
return _elu_bw(grad, input, alpha, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::CELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
return _celu_bw(grad, input, alpha, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::LOGITEPS_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config ) {
return _logiteps_bw(grad, input, eps, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::GT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -540,20 +480,6 @@ struct OpHandler<UnaryBackwardOpType::SUB_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
return _rdiv_bw(grad, input, scalar, round_mode, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config ) {
return _repeat_bw(grad, input, shape, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::EQ_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {