From 199cf93c0558b76cc8c00379e922c07128e5caf7 Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Mon, 9 Oct 2023 18:47:01 -0400
Subject: [PATCH 1/2] Introduce ShawRelativePositionSDPA. (#90)

---
 bibliography.bib                              |   9 +
 src/fairseq2/models/w2vbert/builder.py        |   2 +
 src/fairseq2/models/wav2vec2/builder.py       |  33 +++-
 src/fairseq2/nn/transformer/__init__.py       |   3 +
 src/fairseq2/nn/transformer/shaw_attention.py | 168 ++++++++++++++++++
 5 files changed, 214 insertions(+), 1 deletion(-)
 create mode 100644 src/fairseq2/nn/transformer/shaw_attention.py

diff --git a/bibliography.bib b/bibliography.bib
index 08f06eb00..ad3951599 100644
--- a/bibliography.bib
+++ b/bibliography.bib
@@ -36,6 +36,15 @@ @misc{https://doi.org/10.48550/arxiv.1706.03762
   copyright = {arXiv.org perpetual, non-exclusive license}
 }
 
+@misc{https://doi.org/10.48550/arxiv.1803.02155,
+  title={Self-Attention with Relative Position Representations},
+  author={Peter Shaw and Jakob Uszkoreit and Ashish Vaswani},
+  year={2018},
+  eprint={1803.02155},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL}
+}
+
 @misc{https://doi.org/10.48550/arxiv.1901.02860,
   title={Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context},
   author={Zihang Dai and Zhilin Yang and Yiming Yang and Jaime Carbonell and Quoc V. Le and Ruslan Salakhutdinov},
diff --git a/src/fairseq2/models/w2vbert/builder.py b/src/fairseq2/models/w2vbert/builder.py
index 03e784582..4093a9e0e 100644
--- a/src/fairseq2/models/w2vbert/builder.py
+++ b/src/fairseq2/models/w2vbert/builder.py
@@ -45,6 +45,7 @@ def _encoder_600m() -> Wav2Vec2EncoderConfig:
         depthwise_conv_kernel_size=31,
         causal_depthwise_conv=False,
         conv_norm_type="batch_norm",
+        shaw_rel_pos_sdpa_config=None,
     )
 
 
@@ -78,6 +79,7 @@ def _encoder_300m() -> Wav2Vec2EncoderConfig:
         depthwise_conv_kernel_size=31,
         causal_depthwise_conv=False,
         conv_norm_type="batch_norm",
+        shaw_rel_pos_sdpa_config=None,
     )
 
 
diff --git a/src/fairseq2/models/wav2vec2/builder.py b/src/fairseq2/models/wav2vec2/builder.py
index 297b9c482..850378faa 100644
--- a/src/fairseq2/models/wav2vec2/builder.py
+++ b/src/fairseq2/models/wav2vec2/builder.py
@@ -34,6 +34,7 @@
     MultiheadAttention,
     RelativePositionalEncoding,
     RelativePositionSDPA,
+    ShawRelativePositionSDPA,
     StandardFeedForwardNetwork,
     StandardMultiheadAttention,
     StandardTransformerEncoder,
@@ -46,6 +47,20 @@
 from fairseq2.typing import DataType, Device
 
 
+@dataclass
+class ShawRelativePositionSDPAConfig:
+    """Holds the configuration of the :class:`ShawRelativePositionSDPA` module."""
+
+    max_left_rel_pos: int
+    """The left clipping value for relative positions."""
+
+    max_right_rel_pos: Optional[int]
+    """The right clipping value for relative positions."""
+
+    use_rel_pos_values: bool = False
+    """If True, also uses relative position values to compute relative attention."""
+
+
 @dataclass
 class Wav2Vec2EncoderConfig:
     """Holds the configuration of a wav2vec 2.0 encoder."""
@@ -97,7 +112,7 @@ class Wav2Vec2EncoderConfig:
     sample_fbank_every_k: int
 
     # Position Encoder
-    pos_encoder_type: Literal["conv", "relative", "rotary"]
+    pos_encoder_type: Literal["conv", "relative", "relative_shaw", "rotary"]
     """The type of position encoder."""
 
     # Convolutional Position Encoder
@@ -146,6 +161,9 @@ class Wav2Vec2EncoderConfig:
     conv_norm_type: Literal["batch_norm", "layer_norm"]
     """The type of normalization to use in the Conformer convolution module."""
 
+    shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
+    """The parameters for :class:`ShawRelativePositionSDPA`."""
+
 
 def _encoder_base() -> Wav2Vec2EncoderConfig:
     layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
 
@@ -179,6 +197,7 @@ def _encoder_base() -> Wav2Vec2EncoderConfig:
         depthwise_conv_kernel_size=0,
         causal_depthwise_conv=False,
         conv_norm_type="batch_norm",
+        shaw_rel_pos_sdpa_config=None,
     )
 
 
@@ -369,6 +388,18 @@ def build_attention(self) -> MultiheadAttention:
             device=self.device,
             dtype=self.dtype,
         )
+        elif self.config.pos_encoder_type == "relative_shaw":
+            sdpa_config = self.config.shaw_rel_pos_sdpa_config
+            sdpa = ShawRelativePositionSDPA(
+                self.config.model_dim,
+                self.config.num_encoder_attn_heads,
+                sdpa_config.max_left_rel_pos,
+                max_right_rel_pos=sdpa_config.max_right_rel_pos,
+                use_rel_pos_values=sdpa_config.use_rel_pos_values,
+                attn_dropout_p=self.config.attn_dropout_p,
+                device=self.device,
+                dtype=self.dtype,
+            )
         else:
             sdpa = create_default_sdpa(self.config.attn_dropout_p)
 
diff --git a/src/fairseq2/nn/transformer/__init__.py b/src/fairseq2/nn/transformer/__init__.py
index 0fae90fa7..4c21a5e62 100644
--- a/src/fairseq2/nn/transformer/__init__.py
+++ b/src/fairseq2/nn/transformer/__init__.py
@@ -86,3 +86,6 @@
 from fairseq2.nn.transformer.relative_attention import (
     RelativePositionSDPA as RelativePositionSDPA,
 )
+from fairseq2.nn.transformer.shaw_attention import (
+    ShawRelativePositionSDPA as ShawRelativePositionSDPA,
+)
diff --git a/src/fairseq2/nn/transformer/shaw_attention.py b/src/fairseq2/nn/transformer/shaw_attention.py
new file mode 100644
index 000000000..7b4692542
--- /dev/null
+++ b/src/fairseq2/nn/transformer/shaw_attention.py
@@ -0,0 +1,168 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, final
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn.functional import dropout, softmax
+
+from fairseq2.nn.embedding import StandardEmbedding
+from fairseq2.nn.transformer.attention import SDPA
+from fairseq2.typing import DataType, Device, finaloverride
+
+
+@final
+class ShawRelativePositionSDPA(SDPA):
+    """Computes relative position scaled dot-product attention as described
+    in :cite:t:`https://doi.org/10.48550/arxiv.1803.02155`."""
+
+    model_dim: int
+    num_heads: int
+    max_left_rel_pos: int
+    max_right_rel_pos: Optional[int]
+    rel_k_embed: StandardEmbedding
+    rel_v_embed: Optional[StandardEmbedding]
+
+    def __init__(
+        self,
+        model_dim: int,
+        num_heads: int,
+        max_left_rel_pos: int,
+        *,
+        max_right_rel_pos: Optional[int] = None,
+        use_rel_pos_values: bool = False,
+        attn_dropout_p: float = 0.0,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param model_dim:
+            The dimensionality of the model.
+        :param num_heads:
+            The number of attention heads.
+        :param max_left_rel_pos:
+            The left clipping value for relative positions.
+        :param max_right_rel_pos:
+            The right clipping value for relative positions.
+        :param use_rel_pos_values:
+            If True, also uses relative position values to compute relative attention.
+        :param attn_dropout_p:
+            The dropout probability on attention weights.
+        """
+        super().__init__(attn_dropout_p=attn_dropout_p)
+
+        if model_dim % num_heads != 0:
+            raise ValueError(
+                f"`model_dim` must be a multiple of `num_heads` ({num_heads}), but is {model_dim} instead."
+            )
+
+        self.model_dim = model_dim
+        self.num_heads = num_heads
+
+        head_dim = model_dim // num_heads
+
+        self.max_left_rel_pos = max_left_rel_pos
+        self.max_right_rel_pos = (
+            max_right_rel_pos if max_right_rel_pos is not None else max_left_rel_pos
+        )
+
+        num_pos = self.max_left_rel_pos + 1 + self.max_right_rel_pos
+
+        self.rel_k_embed = StandardEmbedding(
+            num_pos, head_dim, device=device, dtype=dtype
+        )
+
+        if use_rel_pos_values:
+            self.rel_v_embed = StandardEmbedding(
+                num_pos, head_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("rel_v_embed", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+        nn.init.xavier_uniform_(self.rel_k_embed.weight)
+
+        if self.rel_v_embed is not None:
+            nn.init.xavier_uniform_(self.rel_v_embed.weight)
+
+    @finaloverride
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        values: Tensor,
+        *,
+        mask: Optional[Tensor] = None,
+        needs_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        if queries.ndim != 4 or keys.ndim != 4 or values.ndim != 4:
+            raise ValueError(
+                "`ShawRelativePositionSDPA` can only be used as part of a multi-head attention layer and expects its input tensors to be 4-dimensional."
+            )
+
+        # (N, H, S, head_dim) @ (N, H, head_dim, S_kv) = (N, H, S, S_kv)
+        attn_weights = torch.matmul(queries, keys.transpose(-1, -2))
+
+        query_len, kv_len = queries.size(2), keys.size(2)
+
+        # (S_kv, S_kv)
+        rel_pos_indices = self._rel_pos_indices(kv_len, queries.device)
+
+        # (S, S_kv, head_dim)
+        rel_pos_keys = self.rel_k_embed(rel_pos_indices)[-query_len:]
+
+        # (N, H, S, head_dim) @ (S, S_kv, head_dim) = (N, H, S, S_kv)
+        rel_attn_weights = torch.einsum("nhsm,stm->nhst", queries, rel_pos_keys)
+
+        attn_weights += rel_attn_weights
+
+        attn_weights = attn_weights * (queries.size(-1) ** -0.5)
+
+        if mask is not None:
+            attn_weights = attn_weights + mask
+
+        attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
+
+        attn_weights = attn_weights.type_as(queries)
+
+        if self.training and self.attn_dropout_p > 0.0:
+            attn_weights = dropout(attn_weights, self.attn_dropout_p)
+
+        # (N, H, S, S_kv) @ (N, H, S_kv, head_dim) = (N, H, S, head_dim)
+        attn = torch.matmul(attn_weights, values)
+
+        if self.rel_v_embed is not None:
+            # (S, S_kv, head_dim)
+            rel_pos_values = self.rel_v_embed(rel_pos_indices)[-query_len:]
+
+            # (N, H, S, S_kv) @ (S, S_kv, head_dim) = (N, H, S, head_dim)
+            rel_attn = torch.einsum("nhst,stm->nhsm", attn_weights, rel_pos_values)
+
+            attn += rel_attn
+
+        return attn, attn_weights if needs_weights else None
+
+    def _rel_pos_indices(self, seq_len: int, device: Device) -> Tensor:
+        pos = torch.arange(seq_len, device=device).unsqueeze(0)
+        rel_dist = pos - pos.transpose(0, 1)
+
+        rel_dist = torch.clamp(rel_dist, -self.max_left_rel_pos, self.max_right_rel_pos)
+
+        return rel_dist + self.max_left_rel_pos
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        s = super().extra_repr()
+
+        return (
+            f"{s}, "
+            f"model_dim={self.model_dim}, "
+            f"num_heads={self.num_heads}, "
+            f"max_left_rel_pos={self.max_left_rel_pos}, "
+            f"max_right_rel_pos={self.max_right_rel_pos}"
+        )
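For illustration, the new module can also be exercised on its own, outside of a
multi-head attention layer. The sketch below is editorial, not part of the
patch: it assumes a fairseq2 build that includes this change, and all shapes
and values are illustrative. Inputs follow the 4-dimensional
(N, H, S, head_dim) layout that forward() checks for, with
head_dim = model_dim // num_heads.

import torch

from fairseq2.nn.transformer import ShawRelativePositionSDPA

# Relative positions are clipped to [-16, 16]; `max_right_rel_pos` defaults
# to `max_left_rel_pos` when not given.
sdpa = ShawRelativePositionSDPA(512, 8, 16, use_rel_pos_values=True)

# (N, H, S, head_dim) with head_dim = 512 // 8 = 64.
queries = torch.randn(2, 8, 10, 64)
keys = torch.randn(2, 8, 10, 64)
values = torch.randn(2, 8, 10, 64)

attn, attn_weights = sdpa(queries, keys, values, needs_weights=True)

print(attn.shape)          # torch.Size([2, 8, 10, 64])
print(attn_weights.shape)  # torch.Size([2, 8, 10, 10])
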
From 8deaba4c1d4df8f7fdc66cca86a602fd2e3fb360 Mon Sep 17 00:00:00 2001
From: Guillaume Wenzek
Date: Tue, 10 Oct 2023 08:12:22 -0400
Subject: [PATCH 2/2] add `.collate` for `.map(Collater)` (#67)

---
 .../fairseq2n/bindings/data/data_pipeline.cc  | 26 ++++++++++
 src/fairseq2/data/data_pipeline.py            | 17 +++++-
 tests/unit/data/test_collater.py              | 52 ++++++++++++++++++-
 3 files changed, 91 insertions(+), 4 deletions(-)

diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
index 176d03929..07307e5eb 100644
--- a/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
+++ b/fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
@@ -403,6 +403,32 @@ def_data_pipeline(py::module_ &data_module)
             py::arg("bucket_sizes"),
             py::arg("selector") = std::nullopt,
             py::arg("drop_remainder") = false)
+        .def(
+            "collate",
+            [](
+                data_pipeline_builder &self,
+                std::optional<std::int64_t> maybe_pad_idx,
+                std::int64_t pad_to_multiple,
+                std::optional<std::vector<collate_options_override>> maybe_opt_overrides,
+                std::size_t num_parallel_calls) -> data_pipeline_builder &
+            {
+                auto opts = collate_options()
+                    .maybe_pad_idx(maybe_pad_idx).pad_to_multiple(pad_to_multiple);
+
+                std::vector<collate_options_override> opt_overrides{};
+                if (maybe_opt_overrides)
+                    opt_overrides = *std::move(maybe_opt_overrides);
+
+                map_fn f = collater(opts, std::move(opt_overrides));
+
+                self = std::move(self).map(std::move(f), num_parallel_calls);
+
+                return self;
+            },
+            py::arg("pad_idx") = std::nullopt,
+            py::arg("pad_to_multiple") = 1,
+            py::arg("overrides") = std::nullopt,
+            py::arg("num_parallel_calls") = 1)
         .def(
             "filter",
             [](data_pipeline_builder &self, predicate_fn fn) -> data_pipeline_builder &
diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py
index 063e3b676..edd8434ec 100644
--- a/src/fairseq2/data/data_pipeline.py
+++ b/src/fairseq2/data/data_pipeline.py
@@ -33,8 +33,9 @@ class DataPipeline(Iterable[Any]):
     The pipeline state can be persisted to the disk, allowing it to be
     resumed later. It is a Python Iterable, but it also contains the
     iterator states.
-    Calling `iter` a second time while the first iterator is still being used
-    will segfault or worse.
+
+    Calling `iter` twice will create two iterators that read from the same
+    data loader and share the same state, so they will behave inconsistently.
     """
 
     def __iter__(self) -> Iterator[Any]:
@@ -155,6 +156,18 @@ class DataPipeline(Iterable[Any]):
     ) -> Self:
         """Combine examples of similar shape into batches."""
 
+    def collate(
+        self,
+        pad_idx: Optional[int] = None,
+        pad_to_multiple: int = 1,
+        overrides: Optional[Sequence["CollateOptionsOverride"]] = None,
+    ) -> Self:
+        """Concatenate a list of inputs into a single input.
+
+        This is equivalent to calling `.map(Collater())`.
+        See :py:class:`fairseq2.data.Collater` for details.
+        """
+
     def filter(self, predicate: Callable[[Any], Any]) -> Self:
         """Filter examples from the data pipeline and keep only those that
         match ``predicate``.
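As a usage note before the tests, here is a short editorial sketch of the new
shortcut; it is not part of the patch, it assumes fairseq2n >= 0.2.0 (the
version the test below is gated on), and the data is made up. Calling
.collate(pad_idx=0, pad_to_multiple=2) behaves exactly like
.map(Collater(pad_idx=0, pad_to_multiple=2)):

import torch

from fairseq2.data import read_sequence

# One example that is itself a list of variable-length sequences.
bucket = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Equivalent to read_sequence([bucket]).map(Collater(pad_idx=0, pad_to_multiple=2)).
pipeline = read_sequence([bucket]).collate(pad_idx=0, pad_to_multiple=2).and_return()

batch = next(iter(pipeline))

# With a pad index set, the collated output is a dict: batch["seqs"] is a
# (2, 4) tensor padded with 0, batch["seq_lens"] holds [3, 2], and
# batch["is_ragged"] is True because the sequence lengths differ.
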
diff --git a/tests/unit/data/test_collater.py b/tests/unit/data/test_collater.py
index 8bb08bd73..c2cc13c08 100644
--- a/tests/unit/data/test_collater.py
+++ b/tests/unit/data/test_collater.py
@@ -8,8 +8,8 @@
 import torch
 from torch.nn.functional import pad
 
-from fairseq2.data import CollateOptionsOverride, Collater
-from tests.common import assert_close, assert_equal, device
+from fairseq2.data import CollateOptionsOverride, Collater, read_sequence
+from tests.common import assert_close, assert_equal, device, python_devel_only
 
 
 class TestCollater:
@@ -378,3 +378,51 @@ def test_init_raises_error_when_pad_idx_is_none_and_pad_to_multiple_is_greater_t
         match=r"^`pad_idx` of the selector 'foo' must be set when `pad_to_multiple` is greater than 1\.$",
     ):
         Collater(overrides=[CollateOptionsOverride("foo", pad_to_multiple=2)])
+
+
+@pytest.mark.skipif(python_devel_only(), reason="fairseq2n 0.2.0")
+@pytest.mark.parametrize("pad_to_multiple,pad_size", [(1, 0), (2, 0), (3, 2), (8, 4)])
+def test_collate_works_when_input_has_sequence_tensors(
+    pad_to_multiple: int, pad_size: int
+) -> None:
+    bucket1 = [
+        torch.full((4, 2), 0, device=device, dtype=torch.int64),
+        torch.full((4, 2), 1, device=device, dtype=torch.int64),
+        torch.full((4, 2), 2, device=device, dtype=torch.int64),
+    ]
+
+    bucket2 = [
+        [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}],
+        [{"foo1": 4, "foo2": 5}, {"foo3": 6, "foo4": 7}],
+        [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}],
+    ]
+
+    expected1_seqs = torch.tensor(
+        [
+            [[0, 0], [0, 0], [0, 0], [0, 0]],
+            [[1, 1], [1, 1], [1, 1], [1, 1]],
+            [[2, 2], [2, 2], [2, 2], [2, 2]],
+        ],
+        device=device,
+        dtype=torch.int64,
+    )
+    expected1_seqs = pad(expected1_seqs, (0, 0, 0, pad_size), value=3)
+    expected1_seq_lens = torch.tensor([4, 4, 4], device=device, dtype=torch.int64)
+
+    expected2 = [
+        {"foo1": [0, 4, 8], "foo2": [1, 5, 9]},
+        {"foo3": [2, 6, 0], "foo4": [3, 7, 1]},
+    ]
+
+    data = (
+        read_sequence([bucket1, bucket2])
+        .collate(pad_idx=3, pad_to_multiple=pad_to_multiple)
+        .and_return()
+    )
+    output1, output2 = list(data)
+
+    assert_close(output1["seqs"], expected1_seqs)
+    assert_equal(output1["seq_lens"], expected1_seq_lens)
+    assert not output1["is_ragged"]
+
+    assert output2 == expected2
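Per-selector collate options work through the shortcut as well. A final
editorial sketch, not part of the patch, with made-up "foo" and "bar" keys,
following the override semantics asserted in the tests above:

import torch

from fairseq2.data import CollateOptionsOverride, read_sequence

bucket = [
    {"foo": torch.tensor([1, 2, 3]), "bar": torch.tensor([7])},
    {"foo": torch.tensor([4, 5]), "bar": torch.tensor([8])},
]

# `pad_idx` must be set on the override because its `pad_to_multiple` is
# greater than 1 (see the error message tested above).
overrides = [CollateOptionsOverride("foo", pad_idx=1, pad_to_multiple=4)]

pipeline = (
    read_sequence([bucket])
    .collate(pad_idx=0, overrides=overrides)
    .and_return()
)

batch = next(iter(pipeline))

# batch["foo"]["seqs"] is padded with 1 up to a multiple of 4, while
# batch["bar"] falls back to the top-level options.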