
we should ensure activation checkpointing with Float8Linear behaves optimally #893

Open
vkuzo opened this issue Sep 16, 2024 · 1 comment


vkuzo commented Sep 16, 2024

When AC is on for Float8Linear, here is what I would expect:

  1. the forward gemm is recomputed in the backward (it is not being recomputed now)
  2. max(abs(activation)) and max(abs(weight)) are NOT recomputed; these results are tiny, so it is much better to always reuse them (it seems one of them is being recomputed now)

Let's figure out why this isn't happening today and what we should do about it. Note: the reproductions below require #892
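As background for expectation 1, a minimal sketch (assuming PyTorch's `torch.utils.checkpoint`, not Float8Linear itself) showing that a checkpointed forward region is re-run during the backward pass, which is the behavior we want for the forward gemm:

```python
# Minimal sketch: count how many times a checkpointed forward runs.
# With activation checkpointing, the region is executed once in the
# forward and re-executed during the backward to rebuild activations.
import torch
from torch.utils.checkpoint import checkpoint

call_count = 0

def fwd(x, w):
    global call_count
    call_count += 1
    return torch.mm(x, w)  # stand-in for the forward gemm

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)

out = checkpoint(fwd, x, w, use_reentrant=False)
assert call_count == 1  # forward ran once
out.sum().backward()
assert call_count == 2  # forward was recomputed during backward
```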

bfloat16 linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter bfloat16 --enable_activation_checkpointing True

trace snippet

(trace screenshot from 2024-09-16 at 2:50 PM)

we see 1 gemm in the forward and 3 in the backward, as expected

Float8Linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter float8 --enable_activation_checkpointing True

trace snippet

(trace screenshot from 2024-09-16 at 3:05 PM)

issue 1: there are only two gemms in the backward instead of three
issue 2: there are some extra kernels in the backward which are recomputing max(abs(activation)) and max(abs(weight))


vkuzo commented Sep 17, 2024

Two findings so far:

  1. the torch._scaled_mm behavior seems fine
  2. the max(abs(tensor)) behavior is suboptimal, and we can do better with custom AC settings. I wrote up pytorch/torchtitan#580 with initial findings; I will follow up with more after the conferences this week.
