
ttnn.deallocate causes bad PCC / L1 overwrites during subsequent sharded_to_interleaved #14902

Open
yieldthought opened this issue Nov 8, 2024 · 5 comments

yieldthought commented Nov 8, 2024

Describe the bug
Our accuracy test of Llama 3.1 8B on T3K drops to 0.0 accuracy when a tensor that is no longer used is deallocated. The proximate cause of the bad accuracy is corruption of other tensors in L1 (in our case the residual tensor h) during the MLP module, which is not normally even passed the h tensor. The corruption occurs during a sharded_to_interleaved call that follows the DRAM-sharded matmuls.
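Schematically, the sequence that triggers the corruption looks like this (a simplified sketch, not the actual model code; attn_out, w2_out and h are illustrative names, and the MLP body is collapsed into a comment):

# h: residual tensor, still live in L1, never passed into the MLP
# attn_out: attention output that is no longer needed at this point
ttnn.deallocate(attn_out)  # freeing this dead tensor is what triggers the bug

# ... the MLP runs its DRAM-sharded matmuls, producing a sharded w2_out ...

w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.L1_MEMORY_CONFIG)
# after this call, parts of h read back as zeros (see the assert output below)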

To Reproduce
Steps to reproduce the behavior:

  1. Get a T3K
  2. Copy the Llama 3.1 8B weights from sjc-snva-t3002:/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct/
  3. Check out yieldthought/memory-corruption branch and build
  4. LLAMA_DIR=/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct pytest models/demos/llama3/tests/test_llama_accuracy.py

Test asserts showing that h was corrupted during sharded_to_interleaved:

first_five_before=tensor([ 0.0106, -0.0006, -0.0012,  0.0080, -0.0243], dtype=torch.bfloat16)
first_five_after=tensor([0., 0., 0., 0., 0.], dtype=torch.bfloat16)
...
        if mode == "decode":
            if debug_tensor is not None:
                first_five_before = first_five(debug_tensor, self.mesh_device)
            w2_out = ttnn.sharded_to_interleaved(
                w2_out, ttnn.L1_MEMORY_CONFIG
            )  # NOTE: writing this out to DRAM interleaved avoids corrupting h!
            if debug_tensor is not None:
                first_five_after = first_five(debug_tensor, self.mesh_device)
                if not torch.allclose(first_five_before, first_five_after):
                    print(f'{first_five_before=}')
                    print(f'{first_five_after=}')
>                   assert False, "h was corrupted during sharded_to_interleaved"
E                   AssertionError: h was corrupted during sharded_to_interleaved
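For reference, first_five is a small debug helper that is not shown in the issue. A plausible sketch, assuming it reads the multi-device tensor back to torch and returns its first five values (the real helper on the yieldthought/memory-corruption branch may differ):

import torch
import ttnn

def first_five(tensor, mesh_device):
    # Read the multi-device tensor back to host and return its first five values.
    host = ttnn.to_torch(tensor, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
    return host.flatten()[:5].to(torch.bfloat16)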

Expected behavior
Test passes with 88-90% top-1 accuracy. This can be obtained by commenting out the deallocate on line 114 of llama_decoder.py:

ttnn.deallocate(attn_out) # NOTE: Commenting out this deallocate avoids the bad output

Please complete the following environment information:
Internal IRD T3K; I used sjc-snva-t3002


@yieldthought (Contributor, Author)

@mtairum for visibility: I think this is related to the non-deterministic (ND) issues we were seeing, but this one happens every time.

@yieldthought (Contributor, Author)

FWIW using ttnn.to_memory_config has the same effect as ttnn.sharded_to_interleaved here.
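For reference, the two equivalent calls at this point are (same tensor and memory config as in the snippet above):

w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.L1_MEMORY_CONFIG)
# behaves the same here as
w2_out = ttnn.to_memory_config(w2_out, ttnn.L1_MEMORY_CONFIG)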

@yieldthought (Contributor, Author)

Escalating to P1: this is on the critical path to delivering Llama 3.2 to the customer team, and we need to use L1 to hit our performance targets.

@ntarafdar ntarafdar assigned jvegaTT and unassigned ntarafdar Nov 12, 2024
@ntarafdar (Contributor)

@yieldthought, @jvegaTT is looking at this now.

jvegaTT commented Nov 12, 2024

There is a failure pattern in the output tensor: 1024 bytes of mostly zeros with some very large float values interleaved, then 1024 bytes of the correct output data, and this pattern repeats throughout the entire output tensor. Currently debugging why that may be the case.
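To make the described pattern concrete, a host-side check along these lines could report which 1024-byte blocks of the output diverge from a reference (a hypothetical debug helper, not from the issue; it assumes contiguous bfloat16 host tensors of equal shape):

import torch

def diverging_blocks(actual: torch.Tensor, expected: torch.Tensor, block_bytes: int = 1024):
    # Compare the tensors block by block and return the indices of blocks that differ.
    elems_per_block = block_bytes // actual.element_size()  # 512 elements for bfloat16
    a, e = actual.flatten().float(), expected.flatten().float()
    bad = []
    for i in range(0, a.numel(), elems_per_block):
        if not torch.allclose(a[i:i + elems_per_block], e[i:i + elems_per_block]):
            bad.append(i // elems_per_block)
    return bad

If the corruption strictly alternates as described, this would flag every other block index.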

Also, moving ttnn.deallocate(attn_out) in llama_decoder.py to after the feed_forward.forward call fixes the issue. I suspect a double-free situation.
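Roughly, the reordering described above looks like this (illustrative names and arguments; the actual code in llama_decoder.py may differ):

# Before: deallocating attn_out ahead of the MLP corrupts h
ttnn.deallocate(attn_out)
ff_out = self.feed_forward.forward(ff_in, mode)

# After: deallocating once the MLP has finished avoids the corruption
ff_out = self.feed_forward.forward(ff_in, mode)
ttnn.deallocate(attn_out)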
