Revise LayerDrop implementation #484

Merged 2 commits into main on Apr 27, 2024

Conversation

cbalioglu (Contributor)

This PR revises the LayerDrop implementation and moves it from ModuleList to StandardTransformerEncoder and StandardTransformerDecoder. Although the original implementation was ideal, both DDP and FSDP had trouble correctly handling forward/backward passes and either silently failed to sync gradients (DDP), or failed with some cryptic error. The new implementation redundantly computes dropped layers, but since the autograd graph stays constant, both DDP and FSDP can handle it correctly.
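
For illustration, here is a minimal sketch of the idea, using hypothetical names (EncoderWithLayerDrop, _SkipLayerOutput, drop_p) rather than the actual fairseq2 code: every layer is always executed so the autograd graph is identical on every rank, and a "dropped" layer's output is discarded through a small autograd function that routes a zero gradient to it.

import torch
from torch import Tensor
from torch.nn import Module, ModuleList


class _SkipLayerOutput(torch.autograd.Function):
    """Pass `x` through unchanged while keeping `layer_output` in the autograd graph."""

    @staticmethod
    def forward(ctx, x: Tensor, layer_output: Tensor) -> Tensor:
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # d(out)/dx is the identity; the dropped layer's output receives a zero gradient,
        # so its parameters still participate in the backward pass.
        return grad_output, torch.zeros_like(grad_output)


class EncoderWithLayerDrop(Module):
    def __init__(self, layers: ModuleList, drop_p: float = 0.1) -> None:
        super().__init__()
        self.layers = layers
        self.drop_p = drop_p

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            out = layer(x)  # always computed, so DDP/FSDP see the same graph on every rank
            if self.training and float(torch.rand(())) < self.drop_p:
                x = _SkipLayerOutput.apply(x, out)  # "drop": keep x, zero-grad the output
            else:
                x = out
        return x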

@facebook-github-bot added the "CLA Signed" label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 27, 2024
@cbalioglu merged commit f193bd8 into main on Apr 27, 2024
10 checks passed
@cbalioglu deleted the layerdrop branch on April 27, 2024 at 02:42
Comment on lines 21 to 22
drop_p: float = 0.0,
generator: Optional[Generator] = None,
Contributor

Shouldn't we remove these as well?

Contributor Author

You mean the generator? It should still be around in case someone wants to provide a different RNG for LayerDrop.
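
As a hedged illustration (hypothetical helper, not fairseq2's API) of why a caller-supplied generator is useful: the LayerDrop draws can come from a dedicated RNG instead of consuming or perturbing PyTorch's global RNG state.

from typing import Optional

import torch


def should_drop(drop_p: float, generator: Optional[torch.Generator] = None) -> bool:
    """Draw a single LayerDrop decision from the given RNG (or the global one)."""
    return bool(torch.rand((), generator=generator) < drop_p)


# Drop decisions come from `layerdrop_rng`, leaving users of torch.manual_seed(...) unaffected.
layerdrop_rng = torch.Generator().manual_seed(2)
print(should_drop(0.1, layerdrop_rng))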


# compat
@final
class ModuleList(TorchModuleList):
Contributor

What is the purpose of this class? Why don't we just use torch.nn.ModuleList instead? I don't see it even used anywhere; I propose removing this file.

Contributor Author

We have several teams using this module for now. Everything tagged with “# compat” will eventually get removed once we migrate those uses (before the v0.3 release).
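
A hedged sketch of the "# compat" pattern described above (illustrative only, not the actual fairseq2 code): keep a thin subclass so downstream teams keep working, and emit a warning so the remaining uses can be found and migrated before removal.

import warnings

from torch.nn import ModuleList as TorchModuleList


# compat
class ModuleList(TorchModuleList):
    def __init__(self, *args, **kwargs) -> None:
        # The DeprecationWarning is an assumption for this sketch, not taken from the PR.
        warnings.warn(
            "This ModuleList is kept for compatibility; use torch.nn.ModuleList instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)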

Comment on lines +267 to +268
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
return grad_output, torch.zeros_like(grad_output)
Contributor

Smartly done! The gradient with respect to x is going to be just the same as the gradient of the output, since this is just the identity function when you drop layers.

Contributor Author

Thanks!

Contributor

Also, what is the advantage of using this method over PyTorch hooks?

@shagunsodhani (Contributor)

"both DDP and FSDP had trouble correctly handling forward/backward passes and either silently failed to sync gradients (DDP), or failed with some cryptic error."

Did this happen because the RNG across the different GPUs was different, causing different layers to be dropped across GPUs?
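
For context, a small hedged illustration of that hypothesis (not taken from the PR): rank-dependent seeds produce different drop masks, so each replica would skip a different set of layers, while a shared seed keeps the decisions in sync.

import torch

num_layers, drop_p = 12, 0.1

# Simulate two ranks with rank-dependent seeds: the drop masks differ.
for fake_rank in range(2):
    g = torch.Generator().manual_seed(1234 + fake_rank)
    mask = torch.rand(num_layers, generator=g) < drop_p
    print(f"rank {fake_rank}: {mask.tolist()}")

# Seeding the generator identically on every rank would keep the decisions aligned.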

@@ -204,6 +221,19 @@ def forward(

return seqs, padding_mask

def _drop_iter(self) -> Iterator[Tuple[Module, bool]]:
Contributor

I think this logic (and some other pieces of the forward pass) is used in the decoder as well. Maybe we could consider adding a base component in the future that both the encoder and decoder would inherit from.
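
As a rough illustration of what such shared logic might look like (hypothetical standalone helper, not the PR's exact method), a drop iterator that both the encoder and decoder could reuse:

from typing import Iterable, Iterator, Optional, Tuple

import torch
from torch.nn import Module


def drop_iter(
    layers: Iterable[Module],
    drop_p: float,
    training: bool,
    generator: Optional[torch.Generator] = None,
) -> Iterator[Tuple[Module, bool]]:
    """Yield each layer together with a flag saying whether LayerDrop skips it."""
    for layer in layers:
        drop = (
            training
            and drop_p > 0.0
            and bool(torch.rand((), generator=generator) < drop_p)
        )
        yield layer, drop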
