diff --git a/src/fairseq2/generation/text.py b/src/fairseq2/generation/text.py
index e1dfa43c0..d2f90a437 100644
--- a/src/fairseq2/generation/text.py
+++ b/src/fairseq2/generation/text.py
@@ -153,6 +153,7 @@ class TextTranslator:
     _converter: SequenceToTextConverter
     _pad_idx: int
     _source_text_encoder: TextTokenEncoder
+    _max_source_len: Optional[int]
 
     def __init__(
         self,
@@ -160,6 +161,7 @@ def __init__(
         tokenizer: TextTokenizer,
         source_lang: Optional[str] = None,
         target_lang: Optional[str] = None,
+        max_source_len: Optional[int] = None,
     ) -> None:
         """
         :param generator:
@@ -170,6 +172,8 @@ def __init__(
             The source language.
         :param target_lang:
             The target language.
+        :param max_source_len:
+            The maximum number of source tokens; longer sequences are truncated to this length (None or 0 means no truncation).
         """
         self._converter = SequenceToTextConverter(
             generator, tokenizer, "translation", target_lang
@@ -188,6 +192,7 @@ def __init__(
         self._source_text_encoder = tokenizer.create_encoder(
             task="translation", lang=source_lang, mode="source", device=device
         )
+        self._max_source_len = max_source_len
 
     def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]:
         """
@@ -200,6 +205,9 @@ def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]:
         """
         source_seq = self._source_text_encoder(source_text)
 
+        if self._max_source_len and source_seq.shape[0] > self._max_source_len:
+            source_seq = source_seq[: self._max_source_len]
+
         return self._converter(source_seq)
 
     def batch_translate(
@@ -220,6 +228,9 @@ def batch_translate(
 
         source_seq_list = [self._source_text_encoder(t) for t in source_texts]
 
+        if self._max_source_len:
+            source_seq_list = [seq[: self._max_source_len] for seq in source_seq_list]
+
         source_seqs, source_padding_mask = pad_seqs(source_seq_list, self._pad_idx)
 
         return self._converter.batch_convert(source_seqs, source_padding_mask)
diff --git a/tests/integration/models/test_nllb.py b/tests/integration/models/test_nllb.py
index ded0d680e..acddddce8 100644
--- a/tests/integration/models/test_nllb.py
+++ b/tests/integration/models/test_nllb.py
@@ -6,6 +6,7 @@
 
 from typing import Final
 
+import pytest
 import torch
 
 from fairseq2.generation import BeamSearchSeq2SeqGenerator, TextTranslator
@@ -25,7 +26,7 @@ def test_load_dense_distill_600m() -> None:
     tokenizer = load_nllb_tokenizer(model_name, progress=False)
 
-    generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True)
+    generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True, max_seq_len=128)
 
     translator = TextTranslator(
         generator, tokenizer, source_lang="eng_Latn", target_lang="deu_Latn"
     )
@@ -34,3 +35,18 @@ def test_load_dense_distill_600m() -> None:
     text, _ = translator(ENG_SENTENCE)
 
     assert text == DEU_SENTENCE
+
+    # An overly long input must fail without truncation, then succeed once max_source_len is set.
+    with pytest.raises(
+        ValueError, match="The input sequence length must be less than or equal"
+    ):
+        text, _ = translator(ENG_SENTENCE * 20)
+
+    translator = TextTranslator(
+        generator,
+        tokenizer,
+        source_lang="eng_Latn",
+        target_lang="deu_Latn",
+        max_source_len=1024,
+    )
+    text, _ = translator(ENG_SENTENCE * 20)
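
For reference, a minimal end-to-end sketch of the new `max_source_len` parameter, not part of the patch itself. It assumes `load_nllb_model` is available from `fairseq2.models.nllb` (only `load_nllb_tokenizer` appears in the test above), and the model card name and length limits are illustrative:

```python
# Usage sketch for max_source_len (illustrative, not part of this diff).
# Assumption: load_nllb_model from fairseq2.models.nllb; the card name
# "nllb-200_dense_distill_600m" and the limits below are examples only.
import torch

from fairseq2.generation import BeamSearchSeq2SeqGenerator, TextTranslator
from fairseq2.models.nllb import load_nllb_model, load_nllb_tokenizer

model_name = "nllb-200_dense_distill_600m"

model = load_nllb_model(model_name, dtype=torch.float32, progress=False)
tokenizer = load_nllb_tokenizer(model_name, progress=False)

generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True, max_seq_len=128)

# With max_source_len set, sources longer than 1024 tokens are truncated
# instead of raising the generator's input-length ValueError.
translator = TextTranslator(
    generator,
    tokenizer,
    source_lang="eng_Latn",
    target_lang="deu_Latn",
    max_source_len=1024,
)

text, _ = translator("The weather is nice today. " * 100)
print(text)

# batch_translate applies the same truncation to every element of the batch.
texts, _ = translator.batch_translate(["Hello, world!", "Good morning. " * 200])
print(texts)
```

Truncation is applied at the token level after encoding, so both the single-text and batch paths share the same cap before padding and generation.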