[Bug Report] ttnn.softmax result different from pytorch #12847

Open
dmakoviichuk-tt opened this issue Sep 18, 2024 · 0 comments
Labels: bug (Something isn't working), P1_critical

dmakoviichuk-tt (Contributor) commented:

Describe the bug
The current ttnn.softmax precision is low, which leads to instability when it is used for training.

To Reproduce

import torch
import ttnn


def test_softmax2():
    x16 = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438, -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344], dtype=torch.bfloat16)
    x32 = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438, -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344], dtype=torch.float32)


    y16 = torch.nn.functional.softmax(x16, dim=0)
    y32 = torch.nn.functional.softmax(x32, dim=0)

    print(y16, y16.sum())
    print(y32, y32.sum())  

    good_preds = torch.tensor([0.28515625, 0.041992188, 0.19238281, 0.036132812, 0.10595703, 0.07080078, 0.08251953, 0.051513672, 0.04736328, 0.083496094])
    tt_nn_preds = torch.tensor([0.27734375, 0.041748047, 0.19335938, 0.03466797, 0.107421875, 0.06933594, 0.08251953, 0.052734375, 0.048095703, 0.083984375])
    
    print(good_preds, good_preds.sum())

    print(tt_nn_preds, tt_nn_preds.sum())

    #print(good_preds / good_preds.sum())

    very_good_preds = torch.tensor([0.2890625, 0.041992188, 0.19238281, 0.035888672, 0.10839844, 0.07128906, 0.083984375, 0.052246094, 0.047607422, 0.08544922], dtype=torch.bfloat16)
    print(very_good_preds, very_good_preds.sum())

    with ttnn.manage_device(device_id=0) as device:
        a = ttnn.from_torch(x16.view(1,1,1,10), device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
        output = ttnn.softmax(a, 3)
        tt_res = ttnn.to_torch(output)
    print(tt_res, tt_res.sum())

    print('l2 dist from pytorch to good:', torch.dist(y16, good_preds))
    print('l2 dist from pytorch to tt_nn:', torch.dist(y16, tt_nn_preds))
    print('l2 dist from pytorch to very good:', torch.dist(y16, very_good_preds))
    print('l2 dist from pytorch to tt_res:', torch.dist(y16, tt_res))

    print('l2 dist from pytorch to good norm:', torch.dist(y16, good_preds/ good_preds.sum()))
    print('l2 dist from pytorch to tt_nn norm:', torch.dist(y16, tt_nn_preds/ tt_nn_preds.sum()))
    print('l2 dist from pytorch to very good norm:', torch.dist(y16, very_good_preds/ very_good_preds.sum()))
    print('l2 dist from pytorch to tt_res norm:', torch.dist(y16, tt_res/ tt_res.sum()))


if __name__ == '__main__':
    test_softmax2()

You can see that ttnn.softmax demonstrates worse accuracy than PyTorch or even a manual implementation of softmax as a composite op:

tt::tt_metal::Tensor softmax(const tt::tt_metal::Tensor& t, int dim) {
    auto t_max = ttnn_fixed::max(t, dim, /* keepdim */ true);
    auto t_sub_max = ttnn::subtract(t, t_max);
    auto t_sub_max_exp = ttnn::exp(t_sub_max);
    auto t_sum_over_dim =
        ttnn::moreh_sum(t_sub_max_exp, dim, /* keep_dim */ true, std::nullopt, std::nullopt, std::nullopt);
    auto inv_t_sum_over_dim = ttnn::reciprocal(/* queue_id */ 0, t_sum_over_dim);
    return ttnn::multiply(t_sub_max_exp, inv_t_sum_over_dim);
}
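
Subtracting the per-row max first keeps exp's argument non-positive, so the largest exponential is exactly 1 and both the sum and the final division stay well-conditioned in bfloat16.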

Here is the output:

l2 dist from pytorch to good: tensor(0.0026)
l2 dist from pytorch to tt_nn: tensor(0.0105)
l2 dist from pytorch to very good: tensor(0.0038, dtype=torch.bfloat16)
l2 dist from pytorch to tt_res: TorchTensor(0.0061, dtype=torch.bfloat16)
l2 dist from pytorch to good norm: tensor(0.0022)
l2 dist from pytorch to tt_nn norm: tensor(0.0090)
l2 dist from pytorch to very good norm: tensor(0.0024, dtype=torch.bfloat16)
l2 dist from pytorch to tt_res norm: TorchTensor(0.0053, dtype=torch.bfloat16)

I've tested a few ways to implement it; you can see that my composite-op softmax, labeled 'good' above, shows the best results.
tt_nn is a stable softmax that subtracts the max and then calls ttnn.softmax (a minimal sketch of that variant follows below), and tt_res is ttnn.softmax called directly. Both of them show noticeably worse accuracy.
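
For reference, a minimal sketch of that stable variant, assuming the same device setup as the repro above (the helper name stable_ttnn_softmax and doing the max subtraction on the host are my own choices for illustration):

import torch
import ttnn

def stable_ttnn_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row max on the host for numerical stability,
    # then run ttnn.softmax on device exactly as in the repro above.
    x_stable = x - x.max()
    with ttnn.manage_device(device_id=0) as device:
        a = ttnn.from_torch(x_stable.view(1, 1, 1, -1), device=device,
                            dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
        output = ttnn.softmax(a, 3)
        return ttnn.to_torch(output)
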
Expected behavior
It should not have worse accuracy than a straightforward composite C++ implementation like the one shown above.

Screenshots
If applicable, add screenshots to help explain your problem.

Please complete the following environment information:

  • OS: [e.g. Ubuntu 20.04]
  • Version of software (e.g. commit)

Additional context
This bug is different from the known issue of not subtracting the max before exponentiation.
I think it is related to exp being run in fast and approximate mode; a rough torch-only illustration of how exp error propagates follows below.
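
This sketch is not the actual ttnn kernel; rounding exp's output to bfloat16 is only a crude stand-in for a low-precision/approximate exp, but it shows how any error in exp carries straight into the normalized probabilities:

import torch

x = torch.tensor([1.296875, -0.625, 0.890625, -0.7734375, 0.30273438,
                  -0.09033203, 0.052978516, -0.4140625, -0.5, 0.071777344])

exact = torch.exp(x - x.max())
# crude stand-in for an approximate exp: round its output to bfloat16
approx = exact.to(torch.bfloat16).to(torch.float32)

print('l2 dist between softmax with exact and approximate exp:',
      torch.dist(exact / exact.sum(), approx / approx.sum()))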
