Introduce ShawRelativePositionSDPA.

kauterry committed Oct 9, 2023
1 parent 4812198 commit bc07def
Showing 4 changed files with 205 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/fairseq2/models/w2vbert/builder.py
@@ -45,6 +45,7 @@ def _encoder_600m() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_position_sdpa_config=None,
)


@@ -78,6 +79,7 @@ def _encoder_300m() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_position_sdpa_config=None,
)


31 changes: 30 additions & 1 deletion src/fairseq2/models/wav2vec2/builder.py
@@ -34,6 +34,7 @@
MultiheadAttention,
RelativePositionalEncoding,
RelativePositionSDPA,
ShawRelativePositionSDPA,
StandardFeedForwardNetwork,
StandardMultiheadAttention,
StandardTransformerEncoder,
@@ -46,6 +47,18 @@
from fairseq2.typing import DataType, Device


@dataclass
class ShawRelativePositionSDPAConfig:
max_left_rel_position: int
"""The left clipping value for relative positions."""

max_right_rel_position: Optional[int]
"""The right clipping value for relative positions."""

use_rel_position_values: bool = False
"""Whether to use relative position values to compute relative attention."""


@dataclass
class Wav2Vec2EncoderConfig:
"""Holds the configuration of a wav2vec 2.0 encoder."""
@@ -97,7 +110,7 @@ class Wav2Vec2EncoderConfig:
sample_fbank_every_k: int

# Position Encoder
pos_encoder_type: Literal["conv", "relative", "rotary"]
pos_encoder_type: Literal["conv", "relative", "relative_shaw", "rotary"]
"""The type of position encoder."""

# Convolutional Position Encoder
@@ -146,6 +159,9 @@ class Wav2Vec2EncoderConfig:
conv_norm_type: Literal["batch_norm", "layer_norm"]
"""The type of normalization to use in the Conformer convolution module."""

shaw_rel_position_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
"""The parameters for ShawRelativePositionSDPA."""


def _encoder_base() -> Wav2Vec2EncoderConfig:
layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
@@ -179,6 +195,7 @@ def _encoder_base() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=0,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_position_sdpa_config=None,
)


@@ -369,6 +386,18 @@ def build_attention(self) -> MultiheadAttention:
device=self.device,
dtype=self.dtype,
)
elif self.config.pos_encoder_type == "relative_shaw":
sdpa_config = self.config.shaw_rel_position_sdpa_config
sdpa = ShawRelativePositionSDPA(
self.config.model_dim,
self.config.num_encoder_attn_heads,
sdpa_config.max_left_rel_position,
max_right_rel_position=sdpa_config.max_right_rel_position,
use_rel_position_values=sdpa_config.use_rel_position_values,
attn_dropout_p=self.config.attn_dropout_p,
device=self.device,
dtype=self.dtype,
)
else:
sdpa = create_default_sdpa(self.config.attn_dropout_p)

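For context, here is a minimal sketch of how the new configuration could be filled in by a caller of the builder. The clipping values are purely illustrative; the presets in this commit all set shaw_rel_position_sdpa_config=None.

from fairseq2.models.wav2vec2.builder import ShawRelativePositionSDPAConfig

# Illustrative values: clip relative positions to 64 on the left and 8 on the
# right, and also learn relative position *value* embeddings.
sdpa_config = ShawRelativePositionSDPAConfig(
    max_left_rel_position=64,
    max_right_rel_position=8,
    use_rel_position_values=True,
)

# Setting pos_encoder_type="relative_shaw" and
# shaw_rel_position_sdpa_config=sdpa_config on a Wav2Vec2EncoderConfig makes
# build_attention() above construct a ShawRelativePositionSDPA from these
# parameters (the builder assumes the config is not None when this encoder
# type is selected).
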
3 changes: 3 additions & 0 deletions src/fairseq2/nn/transformer/__init__.py
@@ -86,3 +86,6 @@
from fairseq2.nn.transformer.relative_attention import (
RelativePositionSDPA as RelativePositionSDPA,
)
from fairseq2.nn.transformer.relative_position_attention import (
ShawRelativePositionSDPA as ShawRelativePositionSDPA,
)
170 changes: 170 additions & 0 deletions src/fairseq2/nn/transformer/relative_position_attention.py
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, final

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Embedding
from torch.nn.functional import dropout, softmax

from fairseq2.nn.transformer.attention import SDPA
from fairseq2.typing import DataType, Device, finaloverride


@final
class ShawRelativePositionSDPA(SDPA):
"""Computes relative position scaled dot-product attention
as described in :cite:t:`https://arxiv.org/pdf/1803.02155v2.pdf`."""

model_dim: int
num_heads: int
max_left_rel_position: int
max_right_rel_position: Optional[int]
rel_k_embedding: Embedding
rel_v_embedding: Optional[Embedding]
device: Optional[Device]

def __init__(
self,
model_dim: int,
num_heads: int,
max_left_rel_position: int,
*,
max_right_rel_position: Optional[int] = None,
use_rel_position_values: bool = False,
attn_dropout_p: float = 0.0,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> None:
"""
:param model_dim:
The dimensionality of the model.
:param num_heads:
The number of attention heads.
:param max_left_rel_position:
The left clipping value for relative positions.
:param max_right_rel_position:
The right clipping value for relative positions.
:param use_rel_position_values:
Whether to use relative position values to compute relative attention.
:param attn_dropout_p:
The dropout probability on attention weights.
"""
super().__init__(attn_dropout_p=attn_dropout_p)

if model_dim % num_heads != 0:
raise ValueError(
f"`model_dim` must be a multiple of `num_heads` ({num_heads}), but is {model_dim} instead."
)

self.model_dim = model_dim
self.num_heads = num_heads

head_dim = model_dim // num_heads

self.max_left_rel_position = max_left_rel_position
self.max_right_rel_position = (
max_right_rel_position
if max_right_rel_position is not None
else max_left_rel_position
)
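# One embedding vector per clipped relative distance in
# [-max_left_rel_position, max_right_rel_position]; that range contains
# left + 1 + right values in total, including distance zero.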
num_positions = self.max_left_rel_position + 1 + self.max_right_rel_position

self.rel_k_embedding = Embedding(
num_positions, head_dim, device=device, dtype=dtype
)

if use_rel_position_values:
self.rel_v_embedding = Embedding(
num_positions, head_dim, device=device, dtype=dtype
)
else:
self.register_module("rel_v_embedding", None)

self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset the parameters and buffers of the module."""
nn.init.xavier_uniform_(self.rel_k_embedding.weight)
if self.rel_v_embedding is not None:
nn.init.xavier_uniform_(self.rel_v_embedding.weight)

def rel_position_indices(self, seq_len: int) -> Tensor:
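# Build an (S_kv, S_kv) matrix of relative distances j - i between key
# position j and query position i, clip them to
# [-max_left_rel_position, max_right_rel_position], and shift by
# max_left_rel_position so they can index the embedding tables.
# For example, with seq_len=3 and max_left = max_right = 2 this returns:
#   [[2, 3, 4],
#    [1, 2, 3],
#    [0, 1, 2]]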
positions = torch.arange(seq_len).unsqueeze(0)
rel_dist = positions - positions.t()
rel_dist = torch.clamp(
rel_dist, -self.max_left_rel_position, self.max_right_rel_position
)
return rel_dist + self.max_left_rel_position

@finaloverride
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
*,
mask: Optional[Tensor] = None,
needs_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
if queries.ndim != 4 or keys.ndim != 4 or values.ndim != 4:
raise ValueError(
"`ShawRelativePositionSDPA` can only be used as part of a multi-head attention layer and expects its input tensors to be 4 dimensional."
)

# (N, H, S, head_dim) @ (N, H, head_dim, S_kv) = (N, H, S, S_kv)
attn_weights = torch.matmul(queries, keys.transpose(-1, -2))

query_length, kv_length = queries.shape[2], keys.shape[2]

# (S_kv, S_kv)
rel_position_indices = self.rel_position_indices(kv_length)

rel_position_indices = rel_position_indices.to(device=queries.device)

# (S, S_kv, head_dim)
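# The slice below keeps only the last `query_length` rows, assuming that when
# S < S_kv the queries correspond to the final S positions of the sequence.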
rel_position_keys = self.rel_k_embedding(rel_position_indices)[-query_length:]

# (N, H, S, head_dim) @ (S, S_kv, head_dim) = (N, H, S, S_kv)
rel_attn_weights = torch.einsum("nhsm,stm->nhst", queries, rel_position_keys)

attn_weights += rel_attn_weights

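# Scale the combined content and relative position logits by 1 / sqrt(head_dim).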
attn_weights = attn_weights * (queries.size(-1) ** -0.5)

if mask is not None:
attn_weights = attn_weights + mask

attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)

attn_weights = attn_weights.type_as(queries)

if self.training and self.attn_dropout_p > 0.0:
attn_weights = dropout(attn_weights, self.attn_dropout_p)

# (N, H, S, S_kv) @ (N, H, S_kv, head_dim) = (N, H, S, head_dim)
attn = torch.matmul(attn_weights, values)

if self.rel_v_embedding is not None:
# (S, S_kv, head_dim)
rel_position_values = self.rel_v_embedding(rel_position_indices)[
-query_length:
]

# (N, H, S, S_kv) @ (S, S_kv, head_dim) = (N, H, S, head_dim)
rel_attn = torch.einsum("nhst,stm->nhsm", attn_weights, rel_position_values)

attn += rel_attn

return attn, attn_weights if needs_weights else None

def extra_repr(self) -> str:
""":meta private:"""
s = super().extra_repr()

return f"{s}, model_dim={self.model_dim}, num_heads={self.num_heads}"
