Commit

Introduce fast repeat_interleave (#63)

cbalioglu authored Sep 21, 2023
1 parent a012f91 commit ec77833
Showing 7 changed files with 47 additions and 14 deletions.
6 changes: 4 additions & 2 deletions src/fairseq2/generation/sequence_generator.py
@@ -17,7 +17,7 @@
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
from fairseq2.nn.incremental_state import IncrementalStateBag
-from fairseq2.nn.utils.seq import pad_sequence
+from fairseq2.nn.ops import pad_sequence, repeat_interleave
from fairseq2.typing import Device


@@ -480,7 +480,9 @@ def _fan_out_encoder_output(
fan_out_indices = torch.arange(num_searches, device=encoder_output.device)

# (N) -> (N x B)
-fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size)
+fan_out_indices = repeat_interleave(
+    fan_out_indices, dim=0, repeat=self.beam_size
+)

# (N, S_enc, M) -> (N x B, S_enc, M)
encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)
2 changes: 1 addition & 1 deletion src/fairseq2/generation/text.py
@@ -18,8 +18,8 @@
SequenceGeneratorOutput,
)
from fairseq2.models.encoder_decoder import EncoderDecoderModel
+from fairseq2.nn.ops import pad_sequence
from fairseq2.nn.utils.module import infer_device
-from fairseq2.nn.utils.seq import pad_sequence


class SequenceToTextGeneratorBase:
3 changes: 2 additions & 1 deletion src/fairseq2/models/wav2vec2/model.py
@@ -19,6 +19,7 @@
VectorQuantizer,
VectorQuantizerOutput,
)
+from fairseq2.nn.ops import repeat_interleave
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import TransformerEncoder
from fairseq2.nn.utils.module import check_model_dim
@@ -230,7 +231,7 @@ def _sample_distractors(self, targets: Tensor) -> Tensor:
indices = torch.arange(seq_len, device=device)

# (S) -> (S x L)
-indices = indices.repeat_interleave(self.num_distractors)
+indices = repeat_interleave(indices, dim=0, repeat=self.num_distractors)

# (N, S x L)
rand_indices = torch.randint(
31 changes: 31 additions & 0 deletions src/fairseq2/nn/utils/seq.py → src/fairseq2/nn/ops.py
@@ -47,3 +47,34 @@ def pad_sequence(
return padded_seqs, None

return padded_seqs, seq_lens


+def repeat_interleave(x: Tensor, dim: int, repeat: int) -> Tensor:
+    """Repeat elements of a tensor.
+
+    :param x:
+        The input tensor.
+    :param dim:
+        The dimension along which to repeat values.
+    :param repeat:
+        The number of repetitions.
+
+    :returns:
+        The repeated tensor which has the same shape as input, except along the
+        given axis.
+
+    .. note::
+        This is a lightweight version of :func:`torch.repeat_interleave` that
+        is faster for repetitions along a single dimension.
+    """
+    if repeat == 1:
+        return x
+
+    shape = [-1] * (x.ndim + 1)
+
+    if dim < 0:
+        dim += x.ndim
+
+    shape[dim + 1] = repeat
+
+    return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
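
For illustration (not part of this commit): a minimal sketch of how the new helper is meant to be used, checking it against torch.repeat_interleave with a scalar repeat count, which is the only case it supports. The import path assumes fairseq2 at or after this commit.

import torch

from fairseq2.nn.ops import repeat_interleave  # added by this commit

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Repeat each row twice along dim 0: row order becomes 0, 0, 1, 1.
out = repeat_interleave(x, dim=0, repeat=2)

# With a scalar repeat count, torch.repeat_interleave produces the same
# result; the helper trades support for per-element repeat counts and the
# output_size hint for a plain unsqueeze/expand/flatten.
assert torch.equal(out, torch.repeat_interleave(x, 2, dim=0))
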
9 changes: 3 additions & 6 deletions src/fairseq2/nn/position_encoder.py
@@ -16,6 +16,7 @@
from torch.nn.parameter import Parameter

from fairseq2.nn.incremental_state import IncrementalStateBag
+from fairseq2.nn.ops import repeat_interleave
from fairseq2.typing import DataType, Device, finaloverride


@@ -367,12 +368,8 @@ def reset_non_persistent_buffers(self) -> None:
cos = torch.cos(table)
sin = torch.sin(table)

-self.cos_weight[:] = torch.repeat_interleave(
-    cos, 2, dim=-1, output_size=self.encoding_dim
-)
-self.sin_weight[:] = torch.repeat_interleave(
-    sin, 2, dim=-1, output_size=self.encoding_dim
-)
+self.cos_weight[:] = repeat_interleave(cos, dim=-1, repeat=2)
+self.sin_weight[:] = repeat_interleave(sin, dim=-1, repeat=2)

@finaloverride
def _do_forward(
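For context (not part of the diff), a small sketch of what the repeat=2 call above does to the rotary tables: every entry is duplicated in place along the last dimension, matching the previous torch.repeat_interleave call; the dropped output_size argument was only a shape hint. The tensor below is made up for illustration.

import torch

from fairseq2.nn.ops import repeat_interleave  # added by this commit

# Made-up stand-in for one row of the cos table.
cos = torch.tensor([[0.1, 0.2, 0.3]])

out = repeat_interleave(cos, dim=-1, repeat=2)

# Each entry is duplicated in place along the last dimension.
assert torch.equal(out, torch.tensor([[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]))

# Same result as the previous call, minus the output_size hint.
assert torch.equal(out, torch.repeat_interleave(cos, 2, dim=-1))
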
5 changes: 3 additions & 2 deletions src/fairseq2/nn/transformer/multihead_attention.py
@@ -17,6 +17,7 @@
from torch.utils.hooks import RemovableHandle

from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag
+from fairseq2.nn.ops import repeat_interleave
from fairseq2.nn.position_encoder import PositionEncoder
from fairseq2.nn.projection import Linear, Projection
from fairseq2.nn.transformer.attention import SDPA, create_default_sdpa
@@ -440,9 +441,9 @@ def forward(
# With Grouped Query Attention, each key/value head is repeated.
if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1:
# (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
-k = torch.repeat_interleave(k, dim=1, repeats=num_query_groups)
+k = repeat_interleave(k, dim=1, repeat=num_query_groups)
# (N, H_kv, S_kv, K_h) -> (N, H, S_kv, V_h)
-v = torch.repeat_interleave(v, dim=1, repeats=num_query_groups)
+v = repeat_interleave(v, dim=1, repeat=num_query_groups)

mask_pad = 0

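For context (not part of the diff), a shape-level sketch of the grouped-query-attention repeat above, using made-up sizes: 8 query heads sharing 2 key/value heads, so each key/value head is repeated 4 times along the head dimension.

import torch

from fairseq2.nn.ops import repeat_interleave  # added by this commit

N, H, H_kv, S_kv, K_h = 2, 8, 2, 5, 16  # made-up sizes

k = torch.randn(N, H_kv, S_kv, K_h)

num_query_groups = H // H_kv  # 4 query heads per key/value head

# (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
k = repeat_interleave(k, dim=1, repeat=num_query_groups)

assert k.shape == (N, H, S_kv, K_h)
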
5 changes: 3 additions & 2 deletions src/fairseq2/nn/utils/mask.py
@@ -9,6 +9,7 @@
import torch
from torch import Tensor

+from fairseq2.nn.ops import repeat_interleave
from fairseq2.typing import DataType, Device


@@ -191,7 +192,7 @@ def _compute_mask_spans(
span_start_range = row_lens - span_len + 1

# (R) -> (R x N)
-span_start_range = span_start_range.repeat_interleave(num_spans)
+span_start_range = repeat_interleave(span_start_range, dim=0, repeat=num_spans)

# Unlike the fairseq implementation, we do sample with replacement, which is
# more consistent with the overlap strategy.
@@ -208,7 +209,7 @@ def _compute_mask_spans(
span_offsets = span_offsets.type(dtype).view(num_rows, -1)

# (R, N) -> (R, N x L)
-span_offsets = span_offsets.repeat_interleave(span_len, dim=-1)
+span_offsets = repeat_interleave(span_offsets, dim=-1, repeat=span_len)

# (L)
indices = torch.arange(span_len, device=device, dtype=dtype)
