Skip to content

Commit

Permalink
#12920: Update test files and golden function to check default value
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Sep 20, 2024
1 parent 691b2ef commit 505a7f8
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 13 deletions.
21 changes: 21 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,24 @@ def test_bw_celu(input_shapes, device):
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_celu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.celu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.celu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
20 changes: 20 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,23 @@ def test_bw_elu(input_shapes, alpha, device):
golden_tensor = golden_function(grad_data, in_data, alpha)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_elu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)

tt_output_tensor_on_device = ttnn.elu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.elu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,24 @@ def test_bw_hardshrink(input_shapes, lambd, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_hardshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

tt_output_tensor_on_device = ttnn.hardshrink_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.hardshrink_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,23 @@ def test_bw_leaky_relu(input_shapes, negative_slope, device):
golden_tensor = golden_function(grad_data, in_data, negative_slope)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_leaky_relu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.leaky_relu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.leaky_relu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,21 @@ def test_bw_logiteps(input_shapes, eps, device):
golden_tensor = golden_function(grad_data, in_data, eps)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_logiteps_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -2, 2, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
tt_output_tensor_on_device = ttnn.logiteps_bw(grad_tensor, input_tensor)
golden_function = ttnn.get_golden_function(ttnn.logiteps_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,28 @@ def test_bw_softshrink(input_shapes, lambd, device):

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_softshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
in_data.retain_grad()

pyt_y = torch.nn.functional.softshrink(in_data)

tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)

pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
40 changes: 27 additions & 13 deletions ttnn/ttnn/operations/unary_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,36 @@ def _golden_function_unary_backward(torch_op, grad_tensor, input_tensor, *args,
return golden_tensor


def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
def _golden_function_div_no_nan(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
input_tensor.retain_grad()
pyt_y.backward(gradient=grad_tensor)
golden_tensor = [input_tensor.grad]
golden_tensor = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
return golden_tensor


def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha=None, *args, **kwargs):
if torch_op == "leaky_relu":
pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha, inplace=False)
if alpha != None:
pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha)
else:
pyt_y = torch.nn.functional.leaky_relu(input_tensor)
elif torch_op == "elu":
pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
if alpha != None:
pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
else:
pyt_y = torch.nn.functional.elu(input_tensor)
elif torch_op == "celu":
pyt_y = torch.nn.functional.celu(input_tensor, alpha)
elif torch_op == "div_no_nan":
pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
if alpha != None:
pyt_y = torch.nn.functional.celu(input_tensor, alpha)
else:
pyt_y = torch.nn.functional.celu(input_tensor)
else:
pyt_y = torch_op(input_tensor, alpha)
input_tensor.retain_grad()
pyt_y.backward(gradient=grad_tensor)
golden_tensor = [input_tensor.grad]
if torch_op == "div_no_nan":
golden_tensor[0] = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
return golden_tensor


Expand Down Expand Up @@ -160,21 +174,21 @@ def _golden_function_backward_with_reverse_string(

ttnn.attach_golden_function(
ttnn.leaky_relu_bw,
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
"leaky_relu", grad, input, alpha, *args, **kwargs
),
)

ttnn.attach_golden_function(
ttnn.elu_bw,
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
"elu", grad, input, alpha, *args, **kwargs
),
)

ttnn.attach_golden_function(
ttnn.celu_bw,
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
"celu", grad, input, alpha, *args, **kwargs
),
)
Expand All @@ -188,7 +202,7 @@ def _golden_function_backward_with_reverse_string(

ttnn.attach_golden_function(
ttnn.logiteps_bw,
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
torch.logit, grad, input, alpha, *args, **kwargs
),
)
Expand Down Expand Up @@ -216,7 +230,7 @@ def _golden_function_backward_with_reverse_string(

ttnn.attach_golden_function(
ttnn.div_no_nan_bw,
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_div_no_nan(
"div_no_nan", grad, input, alpha, *args, **kwargs
),
)
Expand Down

0 comments on commit 505a7f8

Please sign in to comment.