From 40866e41e33c7a45422ab069e3b050f28695ceb4 Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Sat, 27 Apr 2024 02:24:04 +0000 Subject: [PATCH 1/2] Revise LayerDrop implementation --- src/fairseq2/nn/module_list.py | 69 ++------------------------ src/fairseq2/nn/transformer/decoder.py | 49 +++++++++++++++--- src/fairseq2/nn/transformer/encoder.py | 66 +++++++++++++++++++++--- tests/unit/nn/test_module_list.py | 50 ------------------- 4 files changed, 103 insertions(+), 131 deletions(-) delete mode 100644 tests/unit/nn/test_module_list.py diff --git a/src/fairseq2/nn/module_list.py b/src/fairseq2/nn/module_list.py index 3b980428c..7949baf03 100644 --- a/src/fairseq2/nn/module_list.py +++ b/src/fairseq2/nn/module_list.py @@ -4,47 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Iterator, Optional, final +from typing import Any, Iterable, Optional, final -import torch from torch import Generator from torch.nn import Module from torch.nn import ModuleList as TorchModuleList -from fairseq2.typing import CPU - +# compat @final class ModuleList(TorchModuleList): - """Holds submodules in a list. - - This class extends :class:`torch.nn.ModuleList` with an extra feature that - optionally drops a random number of submodules at every iteration during - training. - - Usage: - - >>> from torch.nn import Module - >>> - >>> from fairseq2.nn import ModuleList - >>> - >>> layer1 = Module() - >>> layer2 = Module() - >>> layer3 = Module() - >>> - >>> layers = ModuleList([layer1, layer2, layer3], drop_p=0.5) - >>> - >>> for layer in layers.drop_iter(): # This might iterate over layers 1 and 3. - ... x = layer(x) - >>> for layer in layers.drop_iter(): # This might iterate over all layers. - ... x = layer(x) - >>> for layer in layers.drop_iter(): # This might not iterate over any layers. - ... x = layer(x) - """ - - drop_p: float - generator: Optional[Generator] - def __init__( self, modules: Optional[Iterable[Module]] = None, @@ -52,37 +21,7 @@ def __init__( drop_p: float = 0.0, generator: Optional[Generator] = None, ) -> None: - """ - :param modules: - An iterable of modules to add. - :param drop_p: - The probability of dropping a submodule during training. - :param generator: - The random number generator. 
- """ super().__init__(modules) - self.drop_p = drop_p - self.generator = generator - - def drop_iter(self) -> Iterator[Module]: - """Return an iterator that drops a random set of submodules.""" - if self.drop_p > 0.0 and self.training: - prob_dist = torch.rand( - len(self), generator=self.generator, device=CPU, dtype=torch.float32 - ) - else: - prob_dist = None - - for idx, m in enumerate(super().__iter__()): - if prob_dist is None or prob_dist[idx] > self.drop_p: - yield m - - def extra_repr(self) -> str: - """:meta private:""" - s = super().extra_repr() - - if self.drop_p > 0.0: - s = f"{s}, drop_p={self.drop_p:G}" - - return s + def drop_iter(self) -> Any: + return super().__iter__() diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index ef894b731..b08036bf0 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -8,14 +8,14 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, Iterable, Optional, Protocol, Tuple, final +from typing import Dict, Iterable, Iterator, Optional, Protocol, Tuple, final -from torch import Tensor -from torch.nn import Module +import torch +from torch import Generator, Tensor +from torch.nn import Module, ModuleList from torch.utils.hooks import RemovableHandle from fairseq2.nn.incremental_state import IncrementalStateBag -from fairseq2.nn.module_list import ModuleList from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer.attention_mask import ( @@ -23,12 +23,13 @@ CausalAttentionMaskFactory, ) from fairseq2.nn.transformer.decoder_layer import TransformerDecoderLayer +from fairseq2.nn.transformer.encoder import _record_drop_for_backward from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, create_standard_layer_norm, ) from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import CPU, DataType, Device, override class TransformerDecoder(Module, ABC): @@ -145,6 +146,8 @@ class StandardTransformerDecoder(TransformerDecoder): :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" self_attn_mask_factory: Optional[AttentionMaskFactory] + layer_drop_p: float + generator: Optional[Generator] layer_norm: Optional[LayerNorm] norm_order: TransformerNormOrder @@ -155,6 +158,7 @@ def __init__( self_attn_mask_factory: Optional[AttentionMaskFactory] = None, use_causal_attn_mask: bool = True, layer_drop_p: float = 0.0, + generator: Optional[Generator] = None, norm_order: TransformerNormOrder = TransformerNormOrder.POST, layer_norm_factory: Optional[LayerNormFactory] = None, device: Optional[Device] = None, @@ -172,12 +176,14 @@ def __init__( :param layer_drop_p: If greater than zero, applies LayerDrop to the decoder layers as described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`. + :param generator: + The random number generator for LayerDrop. :param norm_order: The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization module. 
""" - layer_list = ModuleList(layers, drop_p=layer_drop_p) + layer_list = ModuleList(layers) if not layer_list: raise ValueError("`layers` must be non-empty.") @@ -197,6 +203,10 @@ def __init__( self.layers = layer_list + self.layer_drop_p = layer_drop_p + + self.generator = generator + if norm_order != TransformerNormOrder.POST: self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) else: @@ -228,8 +238,8 @@ def forward( seqs, keys=seqs, training=self.training, state_bag=state_bag ) - for layer_idx, layer in enumerate(self.layers.drop_iter()): - seqs, padding_mask = layer( + for layer_idx, (layer, drop) in enumerate(self._drop_iter()): + layer_output, layer_padding_mask = layer( seqs, padding_mask, self_attn_mask, @@ -238,6 +248,13 @@ def forward( state_bag=state_bag, ) + if drop: + seqs = _record_drop_for_backward(seqs, layer_output) + + continue + + seqs, padding_mask = layer_output, layer_padding_mask + for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, padding_mask, num_layers): break @@ -247,6 +264,19 @@ def forward( return seqs, padding_mask + def _drop_iter(self) -> Iterator[Tuple[Module, bool]]: + if self.training and self.layer_drop_p > 0.0: + prob_dist = torch.rand( + len(self.layers), generator=self.generator, device=CPU + ) + else: + prob_dist = None + + for idx, m in enumerate(self.layers): + drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p + + yield m, drop + def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() @@ -258,4 +288,7 @@ def extra_repr(self) -> str: s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}" + if self.layer_drop_p > 0.0: + s = f"{s}, layer_drop_p={self.layer_drop_p}" + return f"{s}, norm_order={self.norm_order}" diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index 7d4168cb6..123bf758f 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -8,13 +8,14 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, Iterable, Optional, Protocol, Tuple, final +from typing import Any, Dict, Iterable, Iterator, Optional, Protocol, Tuple, final -from torch import Tensor -from torch.nn import Module +import torch +from torch import Generator, Tensor +from torch.autograd import Function +from torch.nn import Module, ModuleList from torch.utils.hooks import RemovableHandle -from fairseq2.nn.module_list import ModuleList from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer.attention_mask import AttentionMaskFactory @@ -24,7 +25,7 @@ create_standard_layer_norm, ) from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import CPU, DataType, Device, override class TransformerEncoder(Module, ABC): @@ -125,6 +126,8 @@ class StandardTransformerEncoder(TransformerEncoder): :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" self_attn_mask_factory: Optional[AttentionMaskFactory] + layer_drop_p: float + generator: Optional[Generator] layer_norm: Optional[LayerNorm] norm_order: TransformerNormOrder @@ -134,6 +137,7 @@ def __init__( *, self_attn_mask_factory: Optional[AttentionMaskFactory] = None, layer_drop_p: float = 0.0, + generator: Optional[Generator] = None, norm_order: TransformerNormOrder = TransformerNormOrder.POST, layer_norm_factory: Optional[LayerNormFactory] = None, device: 
Optional[Device] = None, @@ -147,12 +151,14 @@ def __init__( :param layer_drop_p: If greater than zero, applies LayerDrop to the encoder layers as described in :cite:t:`https://doi.org/10.48550/arxiv.1909.11556`. + :param generator: + The random number generator for LayerDrop. :param norm_order: The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization module. """ - layer_list = ModuleList(layers, drop_p=layer_drop_p) + layer_list = ModuleList(layers) if not layer_list: raise ValueError("`layers` must be non-empty.") @@ -167,6 +173,10 @@ def __init__( self.layers = layer_list + self.layer_drop_p = layer_drop_p + + self.generator = generator + if norm_order != TransformerNormOrder.POST: self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) else: @@ -192,8 +202,15 @@ def forward( seqs, keys=seqs, training=self.training ) - for layer_idx, layer in enumerate(self.layers.drop_iter()): - seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask) + for layer_idx, (layer, drop) in enumerate(self._drop_iter()): + layer_output, layer_padding_mask = layer(seqs, padding_mask, self_attn_mask) + + if drop: + seqs = _record_drop_for_backward(seqs, layer_output) + + continue + + seqs, padding_mask = layer_output, layer_padding_mask for hook in self._layer_output_hooks.values(): if not hook(layer_idx, seqs, padding_mask, num_layers): @@ -204,6 +221,19 @@ def forward( return seqs, padding_mask + def _drop_iter(self) -> Iterator[Tuple[Module, bool]]: + if self.training and self.layer_drop_p > 0.0: + prob_dist = torch.rand( + len(self.layers), generator=self.generator, device=CPU + ) + else: + prob_dist = None + + for idx, m in enumerate(self.layers): + drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p + + yield m, drop + def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() @@ -215,4 +245,24 @@ def extra_repr(self) -> str: s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}" + if self.layer_drop_p > 0.0: + s = f"{s}, layer_drop_p={self.layer_drop_p}" + return f"{s}, norm_order={self.norm_order}" + + +# mypy: disable-error-code="no-any-return,override" + + +def _record_drop_for_backward(x: Tensor, dropped_output: Tensor) -> Tensor: + return _RecordDropForBackwardFunction.apply(x, dropped_output) + + +class _RecordDropForBackwardFunction(Function): + @staticmethod + def forward(ctx: Any, x: Tensor, dropped_output: Tensor) -> Tensor: + return x + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor]: + return grad_output, torch.zeros_like(grad_output) diff --git a/tests/unit/nn/test_module_list.py b/tests/unit/nn/test_module_list.py deleted file mode 100644 index 81c7bf2d8..000000000 --- a/tests/unit/nn/test_module_list.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import pytest -from torch.nn import Linear - -from fairseq2.nn import ModuleList - - -class TestModuleList: - def test_iter_returns_no_modules_when_drop_p_is_one(self) -> None: - modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] - - m = ModuleList(modules, drop_p=1.0) - - with pytest.raises(StopIteration): - next(m.drop_iter()) - - def test_iter_returns_all_modules_when_drop_p_is_zero(self) -> None: - modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] - - m = ModuleList(modules) - - count = 0 - - for m1, m2 in zip(m.drop_iter(), modules): - assert m1 is m2 - - count += 1 - - assert count == len(modules) - - def test_iter_returns_all_modules_in_eval(self) -> None: - modules = [Linear(10, 10), Linear(10, 10), Linear(10, 10), Linear(10, 10)] - - m = ModuleList(modules, drop_p=1.0) - - m.eval() - - count = 0 - - for m1, m2 in zip(m.drop_iter(), modules): - assert m1 is m2 - - count += 1 - - assert count == len(modules) From f24a3e44387b49a9a66f4eded7d8250e193a968c Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Sat, 27 Apr 2024 02:42:02 +0000 Subject: [PATCH 2/2] Update --- src/fairseq2/nn/transformer/decoder.py | 2 +- src/fairseq2/nn/transformer/encoder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index b08036bf0..cda9f04b5 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -289,6 +289,6 @@ def extra_repr(self) -> str: s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}" if self.layer_drop_p > 0.0: - s = f"{s}, layer_drop_p={self.layer_drop_p}" + s = f"{s}, layer_drop_p={self.layer_drop_p:G}" return f"{s}, norm_order={self.norm_order}" diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index 123bf758f..2fbd50590 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -246,7 +246,7 @@ def extra_repr(self) -> str: s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}" if self.layer_drop_p > 0.0: - s = f"{s}, layer_drop_p={self.layer_drop_p}" + s = f"{s}, layer_drop_p={self.layer_drop_p:G}" return f"{s}, norm_order={self.norm_order}"
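
For reference, the sketch below condenses the mechanism the two commits introduce: during training the layer stack draws one uniform sample per layer, skips a layer when its sample falls at or below `layer_drop_p`, and routes the skipped layer's output through a small autograd `Function` so that the layer's parameters stay attached to the backward graph and receive zero gradients. This is a standalone illustration under stated assumptions, not part of the patch; `DropStack` and `_RecordDropForBackward` are hypothetical stand-ins for `StandardTransformerEncoder`/`StandardTransformerDecoder` and `_RecordDropForBackwardFunction`, and the example does not import fairseq2.

# Standalone LayerDrop sketch; illustrative only, not fairseq2 code.

from typing import Iterator, Optional, Tuple

import torch
from torch import Generator, Tensor
from torch.autograd import Function
from torch.nn import Linear, Module, ModuleList


class _RecordDropForBackward(Function):
    """Passes ``x`` through unchanged, but keeps ``dropped_output`` attached
    to the autograd graph so the dropped layer's parameters still receive
    (zero) gradients."""

    @staticmethod
    def forward(ctx, x: Tensor, dropped_output: Tensor) -> Tensor:
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
        return grad_output, torch.zeros_like(grad_output)


class DropStack(Module):
    """A stack of layers with per-layer stochastic depth (LayerDrop)."""

    def __init__(
        self,
        layers: ModuleList,
        drop_p: float = 0.0,
        generator: Optional[Generator] = None,
    ) -> None:
        super().__init__()

        self.layers = layers
        self.drop_p = drop_p
        self.generator = generator

    def _drop_iter(self) -> Iterator[Tuple[Module, bool]]:
        # Draw one uniform sample per layer, but only during training.
        if self.training and self.drop_p > 0.0:
            probs = torch.rand(len(self.layers), generator=self.generator)
        else:
            probs = None

        for idx, layer in enumerate(self.layers):
            drop = probs is not None and float(probs[idx]) <= self.drop_p

            yield layer, drop

    def forward(self, x: Tensor) -> Tensor:
        for layer, drop in self._drop_iter():
            out = layer(x)

            if drop:
                # Discard the layer's output, but record it for backward so
                # the layer's parameters stay in the graph with zero gradient.
                x = _RecordDropForBackward.apply(x, out)
            else:
                x = out

        return x


if __name__ == "__main__":
    torch.manual_seed(0)

    stack = DropStack(ModuleList([Linear(8, 8) for _ in range(4)]), drop_p=0.5)

    loss = stack(torch.randn(2, 8)).sum()

    loss.backward()

    # Every layer ends up with a gradient tensor, including dropped ones.
    print([layer.weight.grad is not None for layer in stack.layers])

Keeping the dropped output attached to the graph is presumably what lets wrappers that expect every parameter to produce a gradient (for example, distributed data parallel training) work unchanged when layers are skipped, which the old `ModuleList.drop_iter` approach removed by never calling the dropped layers at all.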