#13856: ttnn.bias_gelu_bw unary has low PCC
KalaivaniMCW committed Nov 14, 2024
1 parent 3b8fb6c commit bcd98a4
Showing 4 changed files with 273 additions and 7 deletions.
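For reference, the new tests and the device-side change below both follow the tanh-approximate GELU derivative from PyTorch's derivatives.yaml. As a rough sketch (with x being the biased input, i.e. input + bias for bias_gelu_bw):

$$
\beta = \sqrt{2/\pi}, \qquad \kappa = 0.044715, \qquad u = \beta\,(x + \kappa x^{3}),
$$
$$
\frac{d\,\mathrm{GELU}_{\tanh}(x)}{dx} = \tfrac{1}{2}\bigl(1 + \tanh u\bigr) + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2}u\bigr)\,\beta\,\bigl(1 + 3\kappa x^{2}\bigr),
$$

and the backward output is the incoming gradient multiplied by this derivative.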
162 changes: 162 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_bgelu_bw_tanh.py
@@ -0,0 +1,162 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


def gelu_backward(grad: torch.Tensor, self: torch.Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        # return tanh_inner
        return grad * (left_derivative + right_derivative)


@pytest.mark.parametrize(
    "shapes",
    [
        [[4, 2, 96, 192], [4, 2, 96, 192]],
    ],
)
def test_case3(device, shapes):
    torch.manual_seed(4378657)

    high = 100
    low = -100
    in_data = torch.rand(shapes[0], requires_grad=True).bfloat16() * (high - low) + low
    grad_data = torch.rand(shapes[1], requires_grad=False).bfloat16() * (high - low) + low

    input_tensor = ttnn.from_torch(
        in_data, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    grad_tensor = ttnn.from_torch(
        grad_data, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()
    # scalar = -10
    in_data1 = in_data + scalar

    # use the torch implementation from derivatives.yaml to get the reference output
    torch_output_tensor = gelu_backward(grad_data, in_data1, approximate="tanh")

    # use golden fn to get output
    # golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
    # torch_golden = golden_function(grad_data, in_data, scalar, value="tanh")[0]

    # ttnn output
    output_tensor = ttnn.bias_gelu_bw(grad_tensor, input_tensor, scalar, approximate="tanh")
    print("scalar", scalar)
    # print("torch_golden", torch_golden[0], torch_golden.shape)
    # torch_output_tensor[torch_output_tensor == -0.0] = 0.0
    print("torch_output_tensor", torch_output_tensor)
    output_tensor_rm = output_tensor[0].cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
    print("output_tensor", output_tensor_rm)

    # diff = torch_output_tensor - output_tensor_rm
    # print("diff ", (diff == 0).all())
    # print(diff, diff.min(), diff.max())

    output_tensor = ttnn.to_torch(output_tensor[0])

    # assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.999
    assert_with_pcc(torch_output_tensor, output_tensor, 0.97)


@pytest.mark.parametrize(
    "shapes",
    [
        [[97, 129], [97, 129]],
    ],
)
def test_case4(device, shapes):
    torch.manual_seed(7580522)

    high = 100
    low = -100
    in_data = torch.rand(shapes[0], requires_grad=True).bfloat16() * (high - low) + low
    grad_data = torch.rand(shapes[1], requires_grad=False).bfloat16() * (high - low) + low

    input_tensor = ttnn.from_torch(
        in_data, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    grad_tensor = ttnn.from_torch(
        grad_data, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()  # scalar -97.5
    in_data1 = in_data + scalar

    # use the torch implementation from derivatives.yaml to get the reference output
    torch_output_tensor = gelu_backward(grad_data, in_data1, approximate="tanh")

    # use golden fn to get output
    golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
    torch_golden = golden_function(grad_data, in_data, scalar, value="tanh")[0]

    # ttnn output
    output_tensor = ttnn.bias_gelu_bw(grad_tensor, input_tensor, scalar, approximate="tanh")
    print("scalar", scalar)
    print("torch_golden", torch_golden[0], torch_golden.shape)
    print("torch_output_tensor", torch_output_tensor)
    output_tensor_rm = output_tensor[0].cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
    print("output_tensor", output_tensor_rm)
    diff = torch_golden - torch_output_tensor
    # print("diff ", (diff == 0).all())
    # print(diff)
    output_tensor = ttnn.to_torch(output_tensor[0])

    # assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.999
    assert_with_pcc(torch_output_tensor, output_tensor, 0.998)


@pytest.mark.parametrize(
    "shapes",
    [
        [[1, 1, 32, 32]],
    ],
)
def test_add_float(device, shapes):
    torch.manual_seed(0)
    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
    torch_input_tensor_a = torch.ones(shapes[0], dtype=torch.bfloat16)
    torch_input_tensor_b = 1.41421356237309504880
    torch_output_tensor = torch.mul(torch_input_tensor_a, torch_input_tensor_b)
    print("torch_output_tensor", torch_output_tensor)

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    input_tensor_b = torch_input_tensor_b
    output_tensor = ttnn.mul(input_tensor_a, input_tensor_b)

    output_tensor = ttnn.to_torch(output_tensor)
    print("output_tensor", output_tensor)

    assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988
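The assertions above boil down to a Pearson correlation check between the torch reference and the device output (assert_with_pcc delegates to comp_pcc). As a point of reference, a minimal sketch of such a check over flattened tensors could look like the following; pearson_pcc is a hypothetical helper, not the library's implementation:

    import torch

    def pearson_pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
        # Flatten, upcast to float32, and center both tensors.
        e = expected.flatten().to(torch.float32)
        a = actual.flatten().to(torch.float32)
        e = e - e.mean()
        a = a - a.mean()
        # Pearson correlation coefficient; the epsilon guards against zero norms.
        return float((e @ a) / (e.norm() * a.norm() + 1e-12))

A value of 1.0 means the two outputs are perfectly correlated; the thresholds used in these tests (0.97, 0.998, 0.999) quantify how closely the device result is expected to track the torch reference.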
103 changes: 103 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_bias_gelu_bw.py
@@ -0,0 +1,103 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
from functools import partial
import pytest
import torch
import ttnn
import traceback

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
from models.utility_functions import torch_random


def run_backward_div_tests(
    input_shape,
    approx,
    dtype,
    dlayout,
    in_mem_cfg,
    out_mem_cfg,
    data_seed,
    device,
):
    torch.manual_seed(data_seed)
    # grad tensor
    x = gen_func_with_cast_tt(partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype[0])(input_shape[0])
    # input tensor
    y = gen_func_with_cast_tt(partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype[1])(input_shape[0])

    y.requires_grad = True

    scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item()
    print("scalar", scalar)
    try:
        # get ref result
        golden_function = ttnn.get_golden_function(ttnn.bias_gelu_bw)
        ref_value = golden_function(x, y, scalar, value=approx)[0]

        tt_x = ttnn.from_torch(x, dtype=dtype[0], layout=dlayout[0], device=device, memory_config=in_mem_cfg[0])
        tt_y = ttnn.from_torch(y, dtype=dtype[1], layout=dlayout[0], device=device, memory_config=in_mem_cfg[1])

        tt_result = ttnn.bias_gelu_bw(tt_x, tt_y, scalar, approximate=approx, memory_config=out_mem_cfg)[0]
        tt_result = ttnn.to_torch(tt_result)

    except Exception as e:
        logger.warning(f"Test execution crashed: {e}")
        print(traceback.format_exc())
        raise e

    assert len(tt_result.shape) == len(ref_value.shape)
    assert tt_result.shape == ref_value.shape
    assert_with_pcc(ref_value, tt_result, 0.999)


test_sweep_args = [
    (
        [(6, 10, 128, 224)],  # AssertionError: 0.99706924575737, scalar -99.0
        "tanh",
        [ttnn.bfloat8_b, ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG],
        ttnn.DRAM_MEMORY_CONFIG,
        14469376,
    ),
    (
        [(4, 2, 96, 192)],  # AssertionError: 0.9744508807102572, scalar -100.0
        "tanh",
        [ttnn.bfloat16, ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG],
        ttnn.L1_MEMORY_CONFIG,
        4378657,
    ),
    (
        [(5, 10, 224, 32)],  # AssertionError: 0.9982306869898846, scalar -98.5
        "tanh",
        [ttnn.bfloat8_b, ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        [ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG],
        ttnn.DRAM_MEMORY_CONFIG,
        678741,
    ),
    (
        [(97, 129)],  # Pass, 0.9990033308812074, scalar -97.5
        "tanh",
        [ttnn.bfloat16, ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
        ttnn.DRAM_MEMORY_CONFIG,
        7580522,
    ),
]


@pytest.mark.parametrize(
    "input_shape, approx, dtype, dlayout, in_mem_config, out_mem_config, data_seed",
    (test_sweep_args),
)
def test_backward_div(input_shape, approx, dtype, dlayout, in_mem_config, out_mem_config, data_seed, device):
    run_backward_div_tests(input_shape, approx, dtype, dlayout, in_mem_config, out_mem_config, data_seed, device)
1 change: 1 addition & 0 deletions tests/ttnn/utils_for_testing.py
@@ -27,6 +27,7 @@ def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.9999):
         actual_pytorch_result.shape
     ), f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}"
     pcc_passed, pcc_message = comp_pcc(expected_pytorch_result, actual_pytorch_result, pcc)
+    print("pcc_message", pcc_message)
     assert pcc_passed, construct_pcc_assert_message(pcc_message, expected_pytorch_result, actual_pytorch_result)
     return pcc_passed, pcc_message

@@ -1328,21 +1328,20 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
     TT_FATAL((approximate == "none" || approximate == "tanh"), "Incorrect approximate mode (expected 'None', 'tanh')");
 
     if (approximate == "tanh") {
-        float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-        float kKappa = 0.044715;
-        Tensor x_sq = ttnn::multiply(input, input, std::nullopt, output_memory_config);
+        tt::log_debug(tt::LogOp, "************* in TANH gelu_bw");
+        float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5f;
+        float kKappa = 0.044715f;
+        Tensor x_sq = ttnn::square(input, output_memory_config);
         Tensor x_cube = ttnn::multiply(x_sq, input, std::nullopt, output_memory_config);
         Tensor inner = ttnn::multiply(ttnn::add(input, ttnn::multiply(x_cube, kKappa, std::nullopt, output_memory_config)), kBeta, std::nullopt, output_mem_config);
         Tensor tanh_inner = ttnn::tanh(inner, output_memory_config);
 
-        Tensor left = ttnn::multiply(input, 0.5, std::nullopt, output_memory_config);
+        Tensor left = ttnn::multiply(input, 0.5f, std::nullopt, output_memory_config);
         Tensor right = ttnn::add(tanh_inner, 1, std::nullopt, output_memory_config);
 
         Tensor left_derivative = ttnn::multiply(right, 0.5, std::nullopt, output_memory_config);
 
-        Tensor tanh_derivative =
-            ttnn::neg(ttnn::subtract(ttnn::multiply(tanh_inner, tanh_inner, std::nullopt, output_memory_config), 1, std::nullopt, output_mem_config),
-            output_memory_config);
+        Tensor tanh_derivative = ttnn::subtract(ttnn::ones_like(tanh_inner), ttnn::square(tanh_inner, output_memory_config), std::nullopt, output_mem_config);
         Tensor inner_derivative = ttnn::multiply(
             (ttnn::add(
                 ttnn::multiply(ttnn::multiply(x_sq, kKappa, std::nullopt, output_memory_config), 3, std::nullopt, output_memory_config), 1, std::nullopt, output_mem_config)), kBeta);
@@ -1354,6 +1353,7 @@
 
         ttnn::multiply(queue_id, grad, (ttnn::add(left_derivative, right_derivative)), std::nullopt, output_memory_config, input_grad);
         result.push_back(input_grad);
+        // result.push_back(input_grad);
     } else {
         float kAlpha = M_SQRT1_2;
         float kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
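The rewritten tanh_derivative line relies on a standard identity (plain algebra, independent of any ttnn API):

$$
\frac{d}{du}\tanh(u) = 1 - \tanh^{2}(u) = -\bigl(\tanh^{2}(u) - 1\bigr),
$$

so computing it as ones_like(tanh_inner) minus square(tanh_inner) is algebraically equivalent to the previous neg(tanh_inner * tanh_inner - 1) form.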
