Revise LayerDrop implementation #484
Changes from all commits
@@ -4,85 +4,24 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Iterable, Iterator, Optional, final
+from typing import Any, Iterable, Optional, final

-import torch
 from torch import Generator
 from torch.nn import Module
 from torch.nn import ModuleList as TorchModuleList

-from fairseq2.typing import CPU

+# compat
 @final
 class ModuleList(TorchModuleList):
-    """Holds submodules in a list.
-
-    This class extends :class:`torch.nn.ModuleList` with an extra feature that
-    optionally drops a random number of submodules at every iteration during
-    training.
-
-    Usage:
-
-    >>> from torch.nn import Module
-    >>>
-    >>> from fairseq2.nn import ModuleList
-    >>>
-    >>> layer1 = Module()
-    >>> layer2 = Module()
-    >>> layer3 = Module()
-    >>>
-    >>> layers = ModuleList([layer1, layer2, layer3], drop_p=0.5)
-    >>>
-    >>> for layer in layers.drop_iter():  # This might iterate over layers 1 and 3.
-    ...     x = layer(x)
-    >>> for layer in layers.drop_iter():  # This might iterate over all layers.
-    ...     x = layer(x)
-    >>> for layer in layers.drop_iter():  # This might not iterate over any layers.
-    ...     x = layer(x)
-    """
-
-    drop_p: float
-    generator: Optional[Generator]
-
     def __init__(
         self,
         modules: Optional[Iterable[Module]] = None,
         *,
         drop_p: float = 0.0,
         generator: Optional[Generator] = None,
Comment on lines 21 to 22:
Comment: Shouldn't we remove these as well?
Reply: You mean the generator? It should still be around in case someone wants to provide a different RNG for LayerDrop. (A usage sketch with a dedicated generator follows this file's diff.)
     ) -> None:
-        """
-        :param modules:
-            An iterable of modules to add.
-        :param drop_p:
-            The probability of dropping a submodule during training.
-        :param generator:
-            The random number generator.
-        """
         super().__init__(modules)

         self.drop_p = drop_p
         self.generator = generator

-    def drop_iter(self) -> Iterator[Module]:
-        """Return an iterator that drops a random set of submodules."""
-        if self.drop_p > 0.0 and self.training:
-            prob_dist = torch.rand(
-                len(self), generator=self.generator, device=CPU, dtype=torch.float32
-            )
-        else:
-            prob_dist = None
-
-        for idx, m in enumerate(super().__iter__()):
-            if prob_dist is None or prob_dist[idx] > self.drop_p:
-                yield m
-
-    def extra_repr(self) -> str:
-        """:meta private:"""
-        s = super().extra_repr()
-
-        if self.drop_p > 0.0:
-            s = f"{s}, drop_p={self.drop_p:G}"
-
-        return s
+    def drop_iter(self) -> Any:
+        return super().__iter__()
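To illustrate the point discussed above, here is a minimal sketch of how the compat shim can still be driven with a caller-supplied RNG. It assumes only what is visible in this diff: the `fairseq2.nn.ModuleList` import path from the removed docstring and the `drop_p`/`generator` keyword arguments; the `Linear` layers, seed, and shapes are illustrative.

import torch
from torch.nn import Linear

from fairseq2.nn import ModuleList  # import path as shown in the removed docstring

# A dedicated CPU generator keeps LayerDrop sampling independent of the global RNG state.
rng = torch.Generator(device="cpu")
rng.manual_seed(2)

layers = ModuleList([Linear(16, 16) for _ in range(4)], drop_p=0.5, generator=rng)

x = torch.randn(8, 16)

# With this revision, drop_iter() is a plain pass-through: it yields every
# submodule, so existing call sites keep working while the actual LayerDrop
# logic moves into the encoder's _drop_iter further below.
for layer in layers.drop_iter():
    x = layer(x)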
Second file in the diff:
@@ -8,13 +8,14 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Dict, Iterable, Optional, Protocol, Tuple, final
+from typing import Any, Dict, Iterable, Iterator, Optional, Protocol, Tuple, final

-from torch import Tensor
-from torch.nn import Module
+import torch
+from torch import Generator, Tensor
+from torch.autograd import Function
+from torch.nn import Module, ModuleList
 from torch.utils.hooks import RemovableHandle

-from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer.attention_mask import AttentionMaskFactory
@@ -24,7 +25,7 @@
     create_standard_layer_norm,
 )
 from fairseq2.nn.transformer.norm_order import TransformerNormOrder
-from fairseq2.typing import DataType, Device, override
+from fairseq2.typing import CPU, DataType, Device, override


 class TransformerEncoder(Module, ABC):
@@ -125,6 +126,8 @@ class StandardTransformerEncoder(TransformerEncoder):
     :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`."""

     self_attn_mask_factory: Optional[AttentionMaskFactory]
+    layer_drop_p: float
+    generator: Optional[Generator]
     layer_norm: Optional[LayerNorm]
     norm_order: TransformerNormOrder
@@ -134,6 +137,7 @@ def __init__(
         *,
         self_attn_mask_factory: Optional[AttentionMaskFactory] = None,
         layer_drop_p: float = 0.0,
+        generator: Optional[Generator] = None,
         norm_order: TransformerNormOrder = TransformerNormOrder.POST,
         layer_norm_factory: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
@@ -147,12 +151,14 @@ def __init__(
         :param layer_drop_p:
             If greater than zero, applies LayerDrop to the encoder layers as
             described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.
+        :param generator:
+            The random number generator for LayerDrop.
         :param norm_order:
             The Layer Normalization order.
         :param layer_norm_factory:
             The factory to construct the Layer Normalization module.
         """
-        layer_list = ModuleList(layers, drop_p=layer_drop_p)
+        layer_list = ModuleList(layers)
         if not layer_list:
             raise ValueError("`layers` must be non-empty.")
@@ -167,6 +173,10 @@ def __init__(

         self.layers = layer_list

+        self.layer_drop_p = layer_drop_p
+
+        self.generator = generator
+
         if norm_order != TransformerNormOrder.POST:
             self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
         else:
@@ -192,8 +202,15 @@ def forward(
                 seqs, keys=seqs, training=self.training
             )

-        for layer_idx, layer in enumerate(self.layers.drop_iter()):
-            seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)
+        for layer_idx, (layer, drop) in enumerate(self._drop_iter()):
+            layer_output, layer_padding_mask = layer(seqs, padding_mask, self_attn_mask)
+
+            if drop:
+                seqs = _record_drop_for_backward(seqs, layer_output)
+
+                continue
+
+            seqs, padding_mask = layer_output, layer_padding_mask

             for hook in self._layer_output_hooks.values():
                 if not hook(layer_idx, seqs, padding_mask, num_layers):
@@ -204,6 +221,19 @@ def forward(

         return seqs, padding_mask

+    def _drop_iter(self) -> Iterator[Tuple[Module, bool]]:
Comment: I think this logic (and some other pieces of the forward pass) is used in the decoder as well. Maybe we could consider adding a base component in the future that both the encoder and decoder would inherit from.
+        if self.training and self.layer_drop_p > 0.0:
+            prob_dist = torch.rand(
+                len(self.layers), generator=self.generator, device=CPU
+            )
+        else:
+            prob_dist = None
+
+        for idx, m in enumerate(self.layers):
+            drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p
+
+            yield m, drop
+
     def extra_repr(self) -> str:
         """:meta private:"""
         s = super().extra_repr()
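As a side note, the drop decision in `_drop_iter` can be reproduced in isolation. The sketch below is illustrative (the variable names are local, not fairseq2 API): one uniform sample is drawn per layer, and a layer is dropped when its sample falls at or below `layer_drop_p`, so roughly a `(1 - layer_drop_p)` fraction of the layers survives in expectation.

import torch

layer_drop_p = 0.3
num_layers = 6

# A seeded generator makes the sampling reproducible across runs.
generator = torch.Generator(device="cpu")
generator.manual_seed(0)

# One uniform sample per layer, drawn on the CPU as in _drop_iter above.
prob_dist = torch.rand(num_layers, generator=generator, device="cpu")

# True marks a layer that would be skipped in the forward pass.
drops = [float(p) <= layer_drop_p for p in prob_dist]

print(drops, sum(not d for d in drops), "layers kept")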
@@ -215,4 +245,24 @@ def extra_repr(self) -> str:

             s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}"

+        if self.layer_drop_p > 0.0:
+            s = f"{s}, layer_drop_p={self.layer_drop_p:G}"
+
         return f"{s}, norm_order={self.norm_order}"
+
+
+# mypy: disable-error-code="no-any-return,override"
+
+
+def _record_drop_for_backward(x: Tensor, dropped_output: Tensor) -> Tensor:
+    return _RecordDropForBackwardFunction.apply(x, dropped_output)
+
+
+class _RecordDropForBackwardFunction(Function):
+    @staticmethod
+    def forward(ctx: Any, x: Tensor, dropped_output: Tensor) -> Tensor:
+        return x
+
+    @staticmethod
+    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
+        return grad_output, torch.zeros_like(grad_output)
Comment on lines +267 to +268:
Comment: Smartly done! The gradient with respect to x is going to be just the same as the gradient of the output, since this is just the identity function when you drop layers.
Reply: Thanks!
Comment: Also, what is the advantage of using this method over PyTorch hooks?
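To make the gradient argument above concrete, here is a small standalone sketch; `RecordDrop` is an illustrative stand-in for `_RecordDropForBackwardFunction`, not the fairseq2 code itself. The forward pass is the identity on `x`, while the discarded layer output receives a zero gradient, so the dropped layer's parameters stay wired into the autograd graph but get no learning signal.

import torch
from torch.autograd import Function
from torch.nn import Linear


class RecordDrop(Function):  # illustrative stand-in, mirrors the class above
    @staticmethod
    def forward(ctx, x, dropped_output):
        # Identity on x; dropped_output is passed in only so it stays in the graph.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # x gets the incoming gradient unchanged; the dropped output gets zeros.
        return grad_output, torch.zeros_like(grad_output)


x = torch.ones(4, requires_grad=True)
layer = Linear(4, 4)

dropped_output = layer(x)               # the "dropped" layer is still executed
y = RecordDrop.apply(x, dropped_output)

y.sum().backward()

print(x.grad)             # all ones: the identity passes the gradient straight through
print(layer.weight.grad)  # all zeros: the dropped layer receives no update signal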
Comment (on the compat ModuleList above): What is the purpose of this class? Why don't we just use torch.nn.ModuleList instead? I don't see it even being used anywhere; I propose that this file be removed.
Reply: We have several teams using this module for now. Everything tagged with "# compat" will eventually get removed once we migrate those uses (before the v0.3 release).