Revise LayerDrop implementation #484

Merged · 2 commits · Apr 27, 2024
69 changes: 4 additions & 65 deletions src/fairseq2/nn/module_list.py
@@ -4,85 +4,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Iterator, Optional, final
from typing import Any, Iterable, Optional, final

import torch
from torch import Generator
from torch.nn import Module
from torch.nn import ModuleList as TorchModuleList

from fairseq2.typing import CPU


# compat
@final
class ModuleList(TorchModuleList):
Contributor:

What is the purpose of this class? Why don't we just use torch.nn.ModuleList instead? I don't see it used anywhere; I propose removing this file.

Contributor Author:

Several teams still use this module. Everything tagged with "# compat" will eventually be removed once we migrate those uses (before the v0.3 release).

"""Holds submodules in a list.

This class extends :class:`torch.nn.ModuleList` with an extra feature that
optionally drops a random number of submodules at every iteration during
training.

Usage:

>>> from torch.nn import Module
>>>
>>> from fairseq2.nn import ModuleList
>>>
>>> layer1 = Module()
>>> layer2 = Module()
>>> layer3 = Module()
>>>
>>> layers = ModuleList([layer1, layer2, layer3], drop_p=0.5)
>>>
>>> for layer in layers.drop_iter(): # This might iterate over layers 1 and 3.
... x = layer(x)
>>> for layer in layers.drop_iter(): # This might iterate over all layers.
... x = layer(x)
>>> for layer in layers.drop_iter(): # This might not iterate over any layers.
... x = layer(x)
"""

drop_p: float
generator: Optional[Generator]

def __init__(
self,
modules: Optional[Iterable[Module]] = None,
*,
drop_p: float = 0.0,
generator: Optional[Generator] = None,
Comment on lines 21 to 22
Contributor:

Shouldn't we remove these as well?

Contributor Author:

You mean the generator? It should stay around in case someone wants to provide a different RNG for LayerDrop.

) -> None:
"""
:param modules:
An iterable of modules to add.
:param drop_p:
The probability of dropping a submodule during training.
:param generator:
The random number generator.
"""
super().__init__(modules)

self.drop_p = drop_p
self.generator = generator

def drop_iter(self) -> Iterator[Module]:
"""Return an iterator that drops a random set of submodules."""
if self.drop_p > 0.0 and self.training:
prob_dist = torch.rand(
len(self), generator=self.generator, device=CPU, dtype=torch.float32
)
else:
prob_dist = None

for idx, m in enumerate(super().__iter__()):
if prob_dist is None or prob_dist[idx] > self.drop_p:
yield m

def extra_repr(self) -> str:
""":meta private:"""
s = super().extra_repr()

if self.drop_p > 0.0:
s = f"{s}, drop_p={self.drop_p:G}"

return s
def drop_iter(self) -> Any:
return super().__iter__()
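
A minimal standalone sketch (not part of the diff; it assumes fairseq2 is installed and that `fairseq2.nn` still re-exports `ModuleList`, as the removed docstring's usage example shows) of what the compat shim now does: `drop_iter()` is a plain pass-through over the submodules, and the LayerDrop decision itself moves into the encoder and decoder below.

```python
from torch.nn import Linear

from fairseq2.nn import ModuleList  # the "# compat" shim above

layers = ModuleList([Linear(4, 4) for _ in range(3)])

# drop_iter() no longer drops anything; it always yields every submodule.
assert list(layers.drop_iter()) == list(layers)
```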
49 changes: 41 additions & 8 deletions src/fairseq2/nn/transformer/decoder.py
@@ -8,27 +8,28 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, Iterable, Optional, Protocol, Tuple, final
from typing import Dict, Iterable, Iterator, Optional, Protocol, Tuple, final

from torch import Tensor
from torch.nn import Module
import torch
from torch import Generator, Tensor
from torch.nn import Module, ModuleList
from torch.utils.hooks import RemovableHandle

from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.module_list import ModuleList
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer.attention_mask import (
AttentionMaskFactory,
CausalAttentionMaskFactory,
)
from fairseq2.nn.transformer.decoder_layer import TransformerDecoderLayer
from fairseq2.nn.transformer.encoder import _record_drop_for_backward
from fairseq2.nn.transformer.layer_norm import (
LayerNormFactory,
create_standard_layer_norm,
)
from fairseq2.nn.transformer.norm_order import TransformerNormOrder
from fairseq2.typing import DataType, Device, override
from fairseq2.typing import CPU, DataType, Device, override


class TransformerDecoder(Module, ABC):
@@ -145,6 +146,8 @@ class StandardTransformerDecoder(TransformerDecoder):
:cite:t:`https://doi.org/10.48550/arxiv.1706.03762`."""

self_attn_mask_factory: Optional[AttentionMaskFactory]
layer_drop_p: float
generator: Optional[Generator]
layer_norm: Optional[LayerNorm]
norm_order: TransformerNormOrder

@@ -155,6 +158,7 @@ def __init__(
self_attn_mask_factory: Optional[AttentionMaskFactory] = None,
use_causal_attn_mask: bool = True,
layer_drop_p: float = 0.0,
generator: Optional[Generator] = None,
norm_order: TransformerNormOrder = TransformerNormOrder.POST,
layer_norm_factory: Optional[LayerNormFactory] = None,
device: Optional[Device] = None,
@@ -172,12 +176,14 @@ def __init__(
:param layer_drop_p:
If greater than zero, applies LayerDrop to the decoder layers as
described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.
:param generator:
The random number generator for LayerDrop.
:param norm_order:
The Layer Normalization order.
:param layer_norm_factory:
The factory to construct the Layer Normalization module.
"""
layer_list = ModuleList(layers, drop_p=layer_drop_p)
layer_list = ModuleList(layers)
if not layer_list:
raise ValueError("`layers` must be non-empty.")

@@ -197,6 +203,10 @@

self.layers = layer_list

self.layer_drop_p = layer_drop_p

self.generator = generator

if norm_order != TransformerNormOrder.POST:
self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
else:
@@ -228,8 +238,8 @@ def forward(
seqs, keys=seqs, training=self.training, state_bag=state_bag
)

for layer_idx, layer in enumerate(self.layers.drop_iter()):
seqs, padding_mask = layer(
for layer_idx, (layer, drop) in enumerate(self._drop_iter()):
layer_output, layer_padding_mask = layer(
seqs,
padding_mask,
self_attn_mask,
@@ -238,6 +248,13 @@
state_bag=state_bag,
)

if drop:
seqs = _record_drop_for_backward(seqs, layer_output)

continue

seqs, padding_mask = layer_output, layer_padding_mask

for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
break
@@ -247,6 +264,19 @@

return seqs, padding_mask

def _drop_iter(self) -> Iterator[Tuple[Module, bool]]:
if self.training and self.layer_drop_p > 0.0:
prob_dist = torch.rand(
len(self.layers), generator=self.generator, device=CPU
)
else:
prob_dist = None

for idx, m in enumerate(self.layers):
drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p

yield m, drop

def extra_repr(self) -> str:
""":meta private:"""
s = super().extra_repr()
@@ -258,4 +288,7 @@ def extra_repr(self) -> str:

s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}"

if self.layer_drop_p > 0.0:
s = f"{s}, layer_drop_p={self.layer_drop_p:G}"

return f"{s}, norm_order={self.norm_order}"
66 changes: 58 additions & 8 deletions src/fairseq2/nn/transformer/encoder.py
@@ -8,13 +8,14 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, Iterable, Optional, Protocol, Tuple, final
from typing import Any, Dict, Iterable, Iterator, Optional, Protocol, Tuple, final

from torch import Tensor
from torch.nn import Module
import torch
from torch import Generator, Tensor
from torch.autograd import Function
from torch.nn import Module, ModuleList
from torch.utils.hooks import RemovableHandle

from fairseq2.nn.module_list import ModuleList
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer.attention_mask import AttentionMaskFactory
@@ -24,7 +25,7 @@
create_standard_layer_norm,
)
from fairseq2.nn.transformer.norm_order import TransformerNormOrder
from fairseq2.typing import DataType, Device, override
from fairseq2.typing import CPU, DataType, Device, override


class TransformerEncoder(Module, ABC):
@@ -125,6 +126,8 @@ class StandardTransformerEncoder(TransformerEncoder):
:cite:t:`https://doi.org/10.48550/arxiv.1706.03762`."""

self_attn_mask_factory: Optional[AttentionMaskFactory]
layer_drop_p: float
generator: Optional[Generator]
layer_norm: Optional[LayerNorm]
norm_order: TransformerNormOrder

@@ -134,6 +137,7 @@ def __init__(
*,
self_attn_mask_factory: Optional[AttentionMaskFactory] = None,
layer_drop_p: float = 0.0,
generator: Optional[Generator] = None,
norm_order: TransformerNormOrder = TransformerNormOrder.POST,
layer_norm_factory: Optional[LayerNormFactory] = None,
device: Optional[Device] = None,
@@ -147,12 +151,14 @@ def __init__(
:param layer_drop_p:
If greater than zero, applies LayerDrop to the encoder layers as
described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.
:param generator:
The random number generator for LayerDrop.
:param norm_order:
The Layer Normalization order.
:param layer_norm_factory:
The factory to construct the Layer Normalization module.
"""
layer_list = ModuleList(layers, drop_p=layer_drop_p)
layer_list = ModuleList(layers)
if not layer_list:
raise ValueError("`layers` must be non-empty.")

@@ -167,6 +173,10 @@

self.layers = layer_list

self.layer_drop_p = layer_drop_p

self.generator = generator

if norm_order != TransformerNormOrder.POST:
self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
else:
@@ -192,8 +202,15 @@ def forward(
seqs, keys=seqs, training=self.training
)

for layer_idx, layer in enumerate(self.layers.drop_iter()):
seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)
for layer_idx, (layer, drop) in enumerate(self._drop_iter()):
layer_output, layer_padding_mask = layer(seqs, padding_mask, self_attn_mask)

if drop:
seqs = _record_drop_for_backward(seqs, layer_output)

continue

seqs, padding_mask = layer_output, layer_padding_mask

for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
@@ -204,6 +221,19 @@

return seqs, padding_mask

def _drop_iter(self) -> Iterator[Tuple[Module, bool]]:
Contributor:

I think this logic (and some other pieces of the forward pass) is duplicated in the decoder. Maybe we could consider adding a base component in the future that both the encoder and the decoder inherit from.

if self.training and self.layer_drop_p > 0.0:
prob_dist = torch.rand(
len(self.layers), generator=self.generator, device=CPU
)
else:
prob_dist = None

for idx, m in enumerate(self.layers):
drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p

yield m, drop

def extra_repr(self) -> str:
""":meta private:"""
s = super().extra_repr()
@@ -215,4 +245,24 @@ def extra_repr(self) -> str:

s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}"

if self.layer_drop_p > 0.0:
s = f"{s}, layer_drop_p={self.layer_drop_p:G}"

return f"{s}, norm_order={self.norm_order}"


# mypy: disable-error-code="no-any-return,override"


def _record_drop_for_backward(x: Tensor, dropped_output: Tensor) -> Tensor:
return _RecordDropForBackwardFunction.apply(x, dropped_output)


class _RecordDropForBackwardFunction(Function):
@staticmethod
def forward(ctx: Any, x: Tensor, dropped_output: Tensor) -> Tensor:
return x

@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
return grad_output, torch.zeros_like(grad_output)
Comment on lines +267 to +268
Contributor:

Smartly done! The gradient with respect to x is just the gradient of the output, since this is the identity function when a layer is dropped.

Contributor Author:

Thanks!

Contributor:

Also, what is the advantage of using this method over PyTorch hooks?
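
For illustration, a self-contained sketch (not part of the PR; it re-implements a copy of the function rather than importing the private one) of the gradient behaviour described above: the forward pass returns `x` unchanged, while the backward pass routes the incoming gradient to `x` and a zero gradient to the dropped layer's output, so the dropped layer's parameters receive zero gradients rather than none.

```python
import torch
from torch.autograd import Function


class RecordDrop(Function):
    """Re-implemented copy of _RecordDropForBackwardFunction, for illustration only."""

    @staticmethod
    def forward(ctx, x, dropped_output):
        # Identity on x; dropped_output is only kept in the autograd graph.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Full gradient to x, zero gradient to the dropped layer's output.
        return grad_output, torch.zeros_like(grad_output)


x = torch.ones(2, requires_grad=True)
w = torch.full((2,), 3.0, requires_grad=True)  # stands in for a dropped layer's weight

dropped_output = x * w              # output of the "dropped" layer
out = RecordDrop.apply(x, dropped_output)

out.sum().backward()

print(x.grad)  # tensor([1., 1.]) -- gradient of the identity path
print(w.grad)  # tensor([0., 0.]) -- zero, not None: the dropped layer still takes part in backward
```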
