Is autocast needed with FSDP2? #700

Closed
garrett361 opened this issue Nov 25, 2024 · 1 comment
Labels
question Further information is requested

Comments

@garrett361
Collaborator

garrett361 commented Nov 25, 2024

Hi, is it necessary to wrap the forward pass in autocast when using FSDP2? I noticed that the torchtitan training loop does not.

If I wrap the forward pass in torch.autocast(device_type="cuda", dtype=torch.bfloat16), my matmuls run in bfloat16 but (for example) my softmaxes run in float32. This behavior requires the autocast wrapper:

```python
import torch

t = torch.randn(100, device="cuda", dtype=torch.bfloat16)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = t.softmax(dim=-1)

print(out.dtype)  # torch.float32 (CUDA autocast runs softmax in fp32)

# Without autocast, softmax keeps its input dtype:
print(t.softmax(dim=-1).dtype)  # torch.bfloat16
```

This is the usual way to do DDP or non-distributed mixed-precision training.

It seems to me that this behavior is lost in the torchtitan training loop, which doesn't use the autocast context manager. Is that right? Or does FSDP2 somehow still perform the upcast for the ops that AMP normally upcasts, like softmax? I don't see how it could, and I can't easily test this at the moment.

As I understand it, MixedPrecisionPolicy controls the dtype that weights are held in, the dtype that reductions are performed in, and whether to cast a given module's outputs to a certain dtype. All of that seems orthogonal to the dispatcher-level behavior that autocast controls.

Relates to #600 and #591. Also, I believe OLMo uses autocast with FSDP, but that was FSDP1 the last time I checked.

CC @awgu

tianyu-l added the question label on Nov 25, 2024
@awgu
Contributor

awgu commented Dec 2, 2024

The behavior with FSDP2's native mixed precision is the same as FSDP1's. You need to use the autocast context manager if you want that softmax-in-fp32 behavior.

FSDP's native mixed precision is simply:

  1. param_dtype: what dtype to all-gather parameters in?
  2. cast_forward_inputs: should FSDP cast the module forward inputs to param_dtype?
  3. reduce_dtype: what dtype to reduce-scatter gradients in?
  4. output_dtype: should FSDP optionally cast the module forward outputs to output_dtype?

The only casts that FSDP performs are those from items 1, 2, 3, and 4. There is no magic at the forward/backward operator level (e.g. within an FSDP module's forward).
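The boundary-only nature of these casts can be illustrated with a toy, pure-Python model. It is a sketch under stated assumptions, not FSDP's implementation: dtypes are strings, sharded parameters are assumed to be float32, and the names ToyMixedPrecisionPolicy, inner_forward, and toy_fsdp_forward are hypothetical. The policy fields mirror the four knobs listed above; reduce_dtype is carried along but not modeled because it only affects the backward reduce-scatter. The model shows that even with output_dtype="float32", the softmax inside the module still runs in the all-gathered param_dtype.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


# Toy model (not FSDP's implementation): dtypes are plain strings and the
# fields mirror the four mixed-precision knobs listed above.
@dataclass
class ToyMixedPrecisionPolicy:
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None  # backward reduce-scatter only; not modeled
    output_dtype: Optional[str] = None
    cast_forward_inputs: bool = True


def inner_forward(param_dtype: str, input_dtype: str) -> str:
    """Module body with no autocast: return the dtype softmax computes in."""
    if param_dtype != input_dtype:
        raise TypeError("matmul dtype mismatch")  # real matmul rejects mixed dtypes too
    matmul_dtype = param_dtype     # matmul follows its inputs
    softmax_dtype = matmul_dtype   # softmax follows its input; nothing upcasts it
    return softmax_dtype


def toy_fsdp_forward(
    policy: ToyMixedPrecisionPolicy,
    input_dtype: str,
    body: Callable[[str, str], str] = inner_forward,
) -> Tuple[str, str]:
    """Cast only at the module boundary; return (softmax dtype, output dtype)."""
    param_dtype = policy.param_dtype or "float32"  # assumption: fp32 sharded params
    if policy.cast_forward_inputs and policy.param_dtype is not None:
        input_dtype = policy.param_dtype           # knob 2: cast forward inputs
    softmax_dtype = body(param_dtype, input_dtype)  # FSDP never looks inside the body
    output_dtype = policy.output_dtype or softmax_dtype  # knob 4: cast outputs
    return softmax_dtype, output_dtype


policy = ToyMixedPrecisionPolicy(param_dtype="bfloat16", output_dtype="float32")
print(toy_fsdp_forward(policy, input_dtype="float32"))
# ('bfloat16', 'float32'): the output is cast to fp32, but softmax ran in bf16
```

Under this model, output_dtype cannot stand in for autocast: it changes the dtype of the tensor leaving the module, not the precision in which the softmax inside it was computed.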


3 participants