Describe the bug
The current softmax precision is low, which leads to instability when it is used for training.
To Reproduce
import torch
import ttnn


def test_softmax2():
    # Same values in bfloat16 and float32 to compare softmax precision.
    x16 = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438, -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344], dtype=torch.bfloat16)
    x32 = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438, -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344], dtype=torch.float32)
    y16 = torch.nn.functional.softmax(x16, dim=0)
    y32 = torch.nn.functional.softmax(x32, dim=0)
    print(y16, y16.sum())
    print(y32, y32.sum())
    # Reference outputs collected from the composite-op implementations described below.
    good_preds = torch.tensor([0.28515625, 0.041992188, 0.19238281, 0.036132812, 0.10595703, 0.07080078, 0.08251953, 0.051513672, 0.04736328, 0.083496094])
    tt_nn_preds = torch.tensor([0.27734375, 0.041748047, 0.19335938, 0.03466797, 0.107421875, 0.06933594, 0.08251953, 0.052734375, 0.048095703, 0.083984375])
    print(good_preds, good_preds.sum())
    print(tt_nn_preds, tt_nn_preds.sum())
    # print(good_preds / good_preds.sum())
    very_good_preds = torch.tensor([0.2890625, 0.041992188, 0.19238281, 0.035888672, 0.10839844, 0.07128906, 0.083984375, 0.052246094, 0.047607422, 0.08544922], dtype=torch.bfloat16)
    print(very_good_preds, very_good_preds.sum())

    # Run the built-in ttnn.softmax on device and bring the result back to torch.
    with ttnn.manage_device(device_id=0) as device:
        a = ttnn.from_torch(x16.view(1, 1, 1, 10), device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
        output = ttnn.softmax(a, 3)
        tt_res = ttnn.to_torch(output)

    print(tt_res, tt_res.sum())
    print('l2 dist from pytorch to good:', torch.dist(y16, good_preds))
    print('l2 dist from pytorch to tt_nn:', torch.dist(y16, tt_nn_preds))
    print('l2 dist from pytorch to very good:', torch.dist(y16, very_good_preds))
    print('l2 dist from pytorch to tt_res:', torch.dist(y16, tt_res))
    print('l2 dist from pytorch to good norm:', torch.dist(y16, good_preds / good_preds.sum()))
    print('l2 dist from pytorch to tt_nn norm:', torch.dist(y16, tt_nn_preds / tt_nn_preds.sum()))
    print('l2 dist from pytorch to very good norm:', torch.dist(y16, very_good_preds / very_good_preds.sum()))
    print('l2 dist from pytorch to tt_res norm:', torch.dist(y16, tt_res / tt_res.sum()))


if __name__ == '__main__':
    test_softmax2()
You can see that ttnn.softmax demonstrates worse accuracy than PyTorch, or even than a manual implementation of softmax as a composite op:
tt::tt_metal::Tensor softmax(const tt::tt_metal::Tensor& t, int dim) {
    auto t_max = ttnn_fixed::max(t, dim, /* keepdim */ true);
    auto t_sub_max = ttnn::subtract(t, t_max);
    auto t_sub_max_exp = ttnn::exp(t_sub_max);
    auto t_sum_over_dim =
        ttnn::moreh_sum(t_sub_max_exp, dim, /* keep_dim */ true, std::nullopt, std::nullopt, std::nullopt);
    auto inv_t_sum_over_dim = ttnn::reciprocal(/* queue_id */ 0, t_sum_over_dim);
    return ttnn::multiply(t_sub_max_exp, inv_t_sum_over_dim);
}
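For comparison, the same composite op can be sketched against the ttnn Python bindings. This is a minimal sketch, assuming ttnn.max, ttnn.sum, ttnn.subtract, ttnn.exp, ttnn.reciprocal, and ttnn.multiply are exposed with the signatures shown (the dim/keepdim keywords in particular are assumptions), not the exact code I used:

import ttnn


def composite_softmax(t, dim):
    # Numerically stable softmax assembled from reduction + elementwise ops,
    # mirroring the C++ composite op above.
    t_max = ttnn.max(t, dim=dim, keepdim=True)            # per-row max
    t_sub_max = ttnn.subtract(t, t_max)                   # shift for stability
    t_exp = ttnn.exp(t_sub_max)                           # exp of shifted input
    t_sum = ttnn.sum(t_exp, dim=dim, keepdim=True)        # normalizer
    return ttnn.multiply(t_exp, ttnn.reciprocal(t_sum))   # exp / sum

Something like this can be dropped into the repro above in place of ttnn.softmax(a, 3) to compare against the built-in kernel.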
Here is the output:
l2 dist from pytorch to good: tensor(0.0026)
l2 dist from pytorch to tt_nn: tensor(0.0105)
l2 dist from pytorch to very good: tensor(0.0038, dtype=torch.bfloat16)
l2 dist from pytorch to tt_res: TorchTensor(0.0061, dtype=torch.bfloat16)
l2 dist from pytorch to good norm: tensor(0.0022)
l2 dist from pytorch to tt_nn norm: tensor(0.0090)
l2 dist from pytorch to very good norm: tensor(0.0024, dtype=torch.bfloat16)
l2 dist from pytorch to tt_res norm: TorchTensor(0.0053, dtype=torch.bfloat16)
I've tested a few implementation approaches, and you can see that my composite-op softmax, labeled 'good', shows the best results.
tt_nn is a stable softmax that subtracts the max and then calls ttnn.softmax (a sketch of this variant follows below), and tt_res is plain ttnn.softmax. Both of them show pretty bad accuracy.
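For clarity, the 'tt_nn' variant referred to here is roughly the following; a sketch only, where the dim/keepdim keywords on ttnn.max are assumptions about the Python binding:

import ttnn


def tt_nn_softmax(t, dim):
    # Subtract the per-row max first, then call the built-in ttnn.softmax.
    t_max = ttnn.max(t, dim=dim, keepdim=True)
    return ttnn.softmax(ttnn.subtract(t, t_max), dim)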
Expected behavior
It should not have worse accuracy than the default C++ implementation.
Screenshots
If applicable, add screenshots to help explain your problem.
Please complete the following environment information:
OS: [e.g. Ubuntu 20.04]
Version of software (e.g. commit)
Additional context
This bug is different from the issue of not subtracting the max.
I think it is related to Exp being used in fast and approximate mode.
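One way to check that hypothesis is to compare ttnn.exp in approximate and precise mode against torch.exp on the same bfloat16 input. The fast_and_approximate_mode keyword below is an assumption on my side; adjust it to whatever flag the binding actually exposes:

import torch
import ttnn


def compare_exp_modes():
    x = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438,
                      -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344],
                     dtype=torch.bfloat16).view(1, 1, 1, 10)
    ref = torch.exp(x.float())  # float32 reference
    with ttnn.manage_device(device_id=0) as device:
        a = ttnn.from_torch(x, device=device, dtype=ttnn.bfloat16,
                            layout=ttnn.TILE_LAYOUT)
        # Flag name is assumed; it may differ in the actual API.
        fast = ttnn.to_torch(ttnn.exp(a, fast_and_approximate_mode=True))
        precise = ttnn.to_torch(ttnn.exp(a, fast_and_approximate_mode=False))
    print('approx exp error: ', torch.dist(ref, fast.float()))
    print('precise exp error:', torch.dist(ref, precise.float()))

If the approximate mode shows a noticeably larger error than the precise mode, that would point at the exp kernel rather than the reduction.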