Commit 5d4623f

Refactor MultiheadAttentionState

cbalioglu committed Oct 6, 2023
1 parent e081284

Showing 16 changed files with 121 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/models/decoder.py
@@ -36,7 +36,7 @@ def decode(
             the same index in ``seqs``. *Shape:* :math:`(N)`, where :math:`N` is
             the batch size.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             - The decoder output. *Shape:* :math:`(N,S,M)`, where :math:`N` is
2 changes: 1 addition & 1 deletion src/fairseq2/models/encoder_decoder.py
@@ -50,7 +50,7 @@ def decode(
             :math:`(N,S_{enc})`, where :math:`N` is the batch size and
             :math:`S_{enc}` is the encoder output sequence length.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             - The decoder output. *Shape:* :math:`(N,S_{tgt},M)`, where
2 changes: 1 addition & 1 deletion src/fairseq2/models/s2t_transformer/frontend.py
@@ -102,7 +102,7 @@ def forward(
     ) -> Tuple[Tensor, Optional[Tensor]]:
         if state_bag is not None:
             raise ValueError(
-                "`S2TTransformerFrontend` does not support incremental evaluation."
+                "`S2TTransformerFrontend` does not support incremental decoding."
             )
 
         if self.feature_extractor is not None:
2 changes: 1 addition & 1 deletion src/fairseq2/models/transformer/frontend.py
@@ -55,7 +55,7 @@ def forward(
             the same index in ``seqs``. *Shape:* :math:`(N)`, where :math:`N` is
             the batch size.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             - The processed sequences to pass to a Transformer encoder/decoder.
2 changes: 1 addition & 1 deletion src/fairseq2/models/utils/model_loader.py
@@ -143,7 +143,7 @@ def __init__(
         :param download_manager:
             The download manager to use to download model checkpoints.
         :param model_factory:
-            The callable responsible for constructing models.
+            The factory to use to construct models.
         :param archs:
             The registry containing all supported model architectures.
         :param restrict_checkpoints:
2 changes: 1 addition & 1 deletion src/fairseq2/models/wav2vec2/frontend.py
@@ -128,7 +128,7 @@ def forward(
     ) -> Tuple[Tensor, Optional[Tensor]]:
         if state_bag is not None:
             raise ValueError(
-                "`Wav2Vec2Frontend` does not support incremental evaluation."
+                "`Wav2Vec2Frontend` does not support incremental decoding."
             )
 
         seqs, seq_lens = self.extract_features(seqs, seq_lens)
4 changes: 2 additions & 2 deletions src/fairseq2/models/wav2vec2/position_encoder.py
@@ -70,7 +70,7 @@ def _do_forward(
         """:meta private:"""
         if state_bag is not None:
             raise ValueError(
-                "`Wav2Vec2PositionEncoder` does not support incremental evaluation."
+                "`Wav2Vec2PositionEncoder` does not support incremental decoding."
             )
 
         # We have to ensure that the padded elements are correctly set to
@@ -178,7 +178,7 @@ def _do_forward(
         """:meta private:"""
         if state_bag is not None:
             raise ValueError(
-                "`Wav2Vec2StackedPositionEncoder` does not support incremental evaluation."
+                "`Wav2Vec2StackedPositionEncoder` does not support incremental decoding."
             )
 
         # We have to ensure that the padded elements are correctly set to
17 changes: 8 additions & 9 deletions src/fairseq2/nn/incremental_state.py
@@ -12,12 +12,12 @@
 
 
 class IncrementalState(ABC):
-    """Holds the state of a module during an incremental evaluation.
+    """Holds the state of a module during incremental decoding.
 
-    Incremental evaluation is a special mode where the module only receives an
-    input corresponding to the previous output and must produce the next output
-    incrementally. Thus the module must cache any long-term state that is needed
-    about the sequence.
+    Incremental decoding is a special mode at inference time where the module
+    only receives an input corresponding to the previous output and must produce
+    the next output incrementally. Thus the module must cache any long-term
+    state that is needed about the sequence.
     """
 
     @abstractmethod
@@ -39,7 +39,7 @@ def reorder(self, new_order: Tensor) -> None:
 
 
 class IncrementalStateBag:
-    """Holds the module states during an incremental evaluation."""
+    """Holds the module states during incremental decoding."""
 
     step: int
     max_num_steps: int
@@ -59,9 +59,8 @@ def __init__(self, max_num_steps: int) -> None:
     def increment_step(self, delta: int = 1) -> None:
         """Increment the step.
 
-        This method should be called after every incremental evaluation (e.g.
-        beam search) step. It is used by modules to keep track of the position
-        in the sequence.
+        This method should be called after every decoding step. It is used by
+        modules to keep track of the position in the sequence.
 
         :param delta:
            The value by which to increment the step.
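
To make the new wording concrete, here is a minimal sketch of the decoding loop that `IncrementalStateBag` supports. Only the constructor and `increment_step` shown in this diff are taken from the source; `decoder` and the `BOS` index are hypothetical stand-ins:

    import torch

    from fairseq2.nn.incremental_state import IncrementalStateBag

    BOS = 2  # hypothetical beginning-of-sequence index

    state_bag = IncrementalStateBag(max_num_steps=128)

    # Start each of the 4 hypothesis sequences with the BOS token.
    seqs = torch.full((4, 1), BOS, dtype=torch.int64)

    for _ in range(127):
        # Feed only the most recent token; attention layers cache their keys
        # and values in `state_bag` instead of re-processing the whole prefix.
        logits = decoder(seqs[:, -1:], state_bag=state_bag)  # hypothetical module
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        seqs = torch.cat([seqs, next_token], dim=1)

        # Advance the shared position counter after each decoding step.
        state_bag.increment_step()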
4 changes: 2 additions & 2 deletions src/fairseq2/nn/position_encoder.py
@@ -55,7 +55,7 @@ def forward(
             :math:`*` is any number of batch dimensions including none and
             :math:`S` is the sequence length.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             The input sequences with positional information encoded. *Shape:*
@@ -92,7 +92,7 @@ def _do_forward(
             :math:`*` is any number of batch dimensions including none and
             :math:`S` is the sequence length.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             The input sequences with positional information encoded. *Shape:*
11 changes: 10 additions & 1 deletion src/fairseq2/nn/transformer/__init__.py
@@ -58,11 +58,20 @@
 from fairseq2.nn.transformer.multihead_attention import (
     AttentionWeightHook as AttentionWeightHook,
 )
+from fairseq2.nn.transformer.multihead_attention import (
+    EncoderDecoderAttentionState as EncoderDecoderAttentionState,
+)
+from fairseq2.nn.transformer.multihead_attention import (
+    GlobalSelfAttentionState as GlobalSelfAttentionState,
+)
 from fairseq2.nn.transformer.multihead_attention import (
     MultiheadAttention as MultiheadAttention,
 )
 from fairseq2.nn.transformer.multihead_attention import (
-    MultiheadAttentionState as MultiheadAttentionState,
+    SelfAttentionState as SelfAttentionState,
 )
+from fairseq2.nn.transformer.multihead_attention import (
+    SelfAttentionStateFactory as SelfAttentionStateFactory,
+)
 from fairseq2.nn.transformer.multihead_attention import (
     StandardMultiheadAttention as StandardMultiheadAttention,
2 changes: 1 addition & 1 deletion src/fairseq2/nn/transformer/attention.py
@@ -210,7 +210,7 @@ def _naive_scaled_dot_product_attention(
 
 
 class SDPAFactory(Protocol):
-    """Creates instances of :class:`SDPA`."""
+    """Constructs instances of :class:`SDPA`."""
 
     def __call__(self, *, attn_dropout_p: float) -> SDPA:
         """
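
Since `SDPAFactory` is a `Protocol`, any callable with the matching keyword-only signature conforms; no subclassing is required. A minimal sketch, assuming a `TorchSDPA` implementation with an `attn_dropout_p` keyword is available in this version:

    from fairseq2.nn.transformer import SDPA, TorchSDPA  # `TorchSDPA` assumed

    def deterministic_sdpa_factory(*, attn_dropout_p: float) -> SDPA:
        # Ignore the requested dropout probability so that evaluation runs
        # stay deterministic regardless of the module's configuration.
        return TorchSDPA(attn_dropout_p=0.0)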
2 changes: 1 addition & 1 deletion src/fairseq2/nn/transformer/decoder.py
@@ -74,7 +74,7 @@ def forward(
             If not ``None``, it will be called with the output of each layer in
             the decoder stack.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             - The decoder output. *Shape:* Same as ``seqs``.
2 changes: 1 addition & 1 deletion src/fairseq2/nn/transformer/decoder_layer.py
@@ -73,7 +73,7 @@ def forward(
             :math:`(N,S_{enc})`, where :math:`N` is the batch size and
             :math:`S_{enc}` is the encoder output sequence length.
         :param state_bag:
-            The state bag to use for incremental evaluation.
+            The state bag to use for incremental decoding.
         :returns:
             - The decoder layer output. *Shape:* Same as ``seqs``.
4 changes: 2 additions & 2 deletions src/fairseq2/nn/transformer/ffn.py
@@ -157,7 +157,7 @@ def __init__(
         *,
         gate_activation: Optional[Module] = None,
         inner_dim_scale: float = 2 / 3,
-        inner_dim_to_multiple: int = 2,
+        inner_dim_to_multiple: int = 1,
         inner_dropout_p: float = 0.0,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -190,7 +190,7 @@ def __init__(
 
         self.inner_dim_to_multiple = inner_dim_to_multiple
 
-        if inner_dim_to_multiple != 1.0:
+        if inner_dim_to_multiple != 1:
             inner_dim = inner_dim_to_multiple * (
                 (inner_dim + inner_dim_to_multiple - 1) // inner_dim_to_multiple
             )
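
Besides relaxing the default from 2 to 1 (no rounding), the second hunk tightens the guard to compare against the int `1` rather than the float `1.0`. The guarded expression itself is an integer ceiling that rounds the scaled inner dimension up to the next multiple of `inner_dim_to_multiple`; a standalone illustration:

    def round_up_to_multiple(inner_dim: int, multiple: int) -> int:
        # Ceiling division, then scale back up to the nearest multiple.
        return multiple * ((inner_dim + multiple - 1) // multiple)

    assert round_up_to_multiple(683, 256) == 768  # rounded up
    assert round_up_to_multiple(512, 256) == 512  # already a multiple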
2 changes: 1 addition & 1 deletion src/fairseq2/nn/transformer/layer_norm.py
@@ -11,7 +11,7 @@
 
 
 class LayerNormFactory(Protocol):
-    """Creates instances of :class:`LayerNorm`."""
+    """Constructs instances of :class:`LayerNorm`."""
 
     def __call__(
         self,
(Diff for the remaining changed file not rendered.)
