From 22a94a4e62ca0aadbd48afe3b6213961e8c0ce8a Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Fri, 20 Sep 2024 07:00:44 +0000 Subject: [PATCH] #12920: Update test files and golden function to check default value --- .../operations/backward/test_backward_celu.py | 21 ++++++++ .../operations/backward/test_backward_elu.py | 20 +++++++ .../backward/test_backward_hardshrink.py | 21 ++++++++ .../backward/test_backward_leaky_relu.py | 20 +++++++ .../backward/test_backward_logiteps.py | 18 +++++++ .../backward/test_backward_softshrink.py | 25 +++++++++ ttnn/ttnn/operations/unary_backward.py | 53 ++++++++++++------- 7 files changed, 160 insertions(+), 18 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py index 1d40c60a1c8..b42841ad356 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py @@ -29,3 +29,24 @@ def test_bw_celu(input_shapes, device): comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_bw_celu_default(input_shapes, device): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True) + + tt_output_tensor_on_device = ttnn.celu_bw(grad_tensor, input_tensor) + + golden_function = ttnn.get_golden_function(ttnn.celu_bw) + golden_tensor = golden_function(grad_data, in_data) + comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) + + assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py index 59783b57e84..8114d3022c1 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py @@ -36,3 +36,23 @@ def test_bw_elu(input_shapes, alpha, device): golden_tensor = golden_function(grad_data, in_data, alpha) comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_bw_elu_default(input_shapes, device): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device, True) + + tt_output_tensor_on_device = ttnn.elu_bw(grad_tensor, input_tensor) + + golden_function = ttnn.get_golden_function(ttnn.elu_bw) + golden_tensor = golden_function(grad_data, in_data) + comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) + assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py b/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py index 3c386939a0a..5700e708382 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py @@ -28,3 +28,24 @@ def test_bw_hardshrink(input_shapes, lambd, device): comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), 
+ (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_bw_hardshrink_default(input_shapes, device): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) + + tt_output_tensor_on_device = ttnn.hardshrink_bw(grad_tensor, input_tensor) + + golden_function = ttnn.get_golden_function(ttnn.hardshrink_bw) + golden_tensor = golden_function(grad_data, in_data) + + comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) + assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py index 5941f89709d..5c33b4f1664 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py @@ -27,3 +27,23 @@ def test_bw_leaky_relu(input_shapes, negative_slope, device): golden_tensor = golden_function(grad_data, in_data, negative_slope) comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_bw_leaky_relu_default(input_shapes, device): + in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True) + + tt_output_tensor_on_device = ttnn.leaky_relu_bw(grad_tensor, input_tensor) + + golden_function = ttnn.get_golden_function(ttnn.leaky_relu_bw) + golden_tensor = golden_function(grad_data, in_data) + comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) + assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py b/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py index 4ee3ccf1aec..5546090dcf7 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py @@ -32,3 +32,21 @@ def test_bw_logiteps(input_shapes, eps, device): golden_tensor = golden_function(grad_data, in_data, eps) comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_bw_logiteps_default(input_shapes, device): + in_data, input_tensor = data_gen_with_range(input_shapes, -2, 2, device, True) + grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device) + tt_output_tensor_on_device = ttnn.logiteps_bw(grad_tensor, input_tensor) + golden_function = ttnn.get_golden_function(ttnn.logiteps_bw) + golden_tensor = golden_function(grad_data, in_data) + comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor) + assert comp_pass diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py b/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py index 815246a959c..bf3651a8fac 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py +++ b/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py @@ -35,3 +35,28 @@ def test_bw_softshrink(input_shapes, lambd, device): comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor) assert comp_pass + + 
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_softshrink_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
+    in_data.retain_grad()
+
+    pyt_y = torch.nn.functional.softshrink(in_data)
+
+    tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)
+
+    pyt_y.backward(gradient=grad_data)
+
+    golden_tensor = [in_data.grad]
+
+    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/ttnn/ttnn/operations/unary_backward.py b/ttnn/ttnn/operations/unary_backward.py
index e4676604cf3..d6858bf9f8a 100644
--- a/ttnn/ttnn/operations/unary_backward.py
+++ b/ttnn/ttnn/operations/unary_backward.py
@@ -24,22 +24,39 @@ def _golden_function_unary_backward(torch_op, grad_tensor, input_tensor, *args,
     return golden_tensor
 
 
-def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
+def _golden_function_div_no_nan(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
+    pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
+    input_tensor.retain_grad()
+    pyt_y.backward(gradient=grad_tensor)
+    golden_tensor = [input_tensor.grad]
+    golden_tensor[0] = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
+    return golden_tensor
+
+
+def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha=None, *args, **kwargs):
     if torch_op == "leaky_relu":
-        pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha, inplace=False)
+        if alpha is not None:
+            pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha)
+        else:
+            pyt_y = torch.nn.functional.leaky_relu(input_tensor)
     elif torch_op == "elu":
-        pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
+        if alpha is not None:
+            pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
+        else:
+            pyt_y = torch.nn.functional.elu(input_tensor)
     elif torch_op == "celu":
-        pyt_y = torch.nn.functional.celu(input_tensor, alpha)
-    elif torch_op == "div_no_nan":
-        pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
+        if alpha is not None:
+            pyt_y = torch.nn.functional.celu(input_tensor, alpha)
+        else:
+            pyt_y = torch.nn.functional.celu(input_tensor)
     else:
-        pyt_y = torch_op(input_tensor, alpha)
+        if alpha is not None:
+            pyt_y = torch_op(input_tensor, alpha)
+        else:
+            pyt_y = torch_op(input_tensor)
     input_tensor.retain_grad()
     pyt_y.backward(gradient=grad_tensor)
     golden_tensor = [input_tensor.grad]
-    if torch_op == "div_no_nan":
-        golden_tensor[0] = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
     return golden_tensor
 
 
@@ -146,35 +163,35 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.hardshrink_bw,
-    golden_function=lambda grad, input, *args, **kwargs: _golden_function_unary_backward(
-        torch.hardshrink, grad, input, *args, **kwargs
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.hardshrink, grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.softshrink_bw,
-    golden_function=lambda grad, input, *args, **kwargs: _golden_function_unary_backward(
-        torch.softshrink, grad, input, *args, **kwargs
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.softshrink, grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.leaky_relu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
         "leaky_relu", grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.elu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
         "elu", grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.celu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
         "celu", grad, input, alpha, *args, **kwargs
     ),
 )
@@ -188,7 +205,7 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.logiteps_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
         torch.logit, grad, input, alpha, *args, **kwargs
     ),
 )
@@ -216,7 +233,7 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.div_no_nan_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_div_no_nan(
         "div_no_nan", grad, input, alpha, *args, **kwargs
     ),
 )
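
A minimal sketch of the behaviour the new default-value tests rely on (not part of the patch; it only assumes torch's documented defaults, e.g. alpha=1.0 for ELU): when the test omits alpha, the golden path calls the torch op without the argument, which must match an explicit call with the default before the backward gradient is taken.

    import torch

    # Illustrative only: mirrors what _golden_function_unary_backward_with_float
    # does for ttnn.elu_bw when alpha is omitted.
    in_data = torch.randn(1, 1, 32, 32, requires_grad=True)
    grad_data = torch.randn(1, 1, 32, 32)

    pyt_y_default = torch.nn.functional.elu(in_data)              # alpha omitted
    pyt_y_explicit = torch.nn.functional.elu(in_data, alpha=1.0)  # explicit default
    assert torch.equal(pyt_y_default, pyt_y_explicit)

    # Golden gradient, computed the same way the helper does it.
    pyt_y_default.backward(gradient=grad_data)
    golden_tensor = [in_data.grad]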