From 64cb54784058c96bcd7bd48c5818df41029e2c75 Mon Sep 17 00:00:00 2001 From: David Dale Date: Wed, 24 Apr 2024 14:37:24 +0000 Subject: [PATCH 1/3] add truncation in text translator Signed-off-by: David Dale --- src/fairseq2/generation/text.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/fairseq2/generation/text.py b/src/fairseq2/generation/text.py index e1dfa43c0..38267c2b5 100644 --- a/src/fairseq2/generation/text.py +++ b/src/fairseq2/generation/text.py @@ -153,6 +153,7 @@ class TextTranslator: _converter: SequenceToTextConverter _pad_idx: int _source_text_encoder: TextTokenEncoder + max_src_len: Optional[int] def __init__( self, @@ -160,6 +161,7 @@ def __init__( tokenizer: TextTokenizer, source_lang: Optional[str] = None, target_lang: Optional[str] = None, + max_src_len: Optional[int] = None, ) -> None: """ :param generator: @@ -170,6 +172,8 @@ def __init__( The source language. :param target_lang: The target language. + :param max_src_len: + The number of tokens above which the source sequence gets truncated (None for no truncation) """ self._converter = SequenceToTextConverter( generator, tokenizer, "translation", target_lang @@ -188,6 +192,7 @@ def __init__( self._source_text_encoder = tokenizer.create_encoder( task="translation", lang=source_lang, mode="source", device=device ) + self.max_src_len = max_src_len def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: """ @@ -200,6 +205,9 @@ def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: """ source_seq = self._source_text_encoder(source_text) + if self.max_src_len and source_seq.shape[0] > self.max_src_len: + source_seq = source_seq[: self.max_src_len] + return self._converter(source_seq) def batch_translate( @@ -220,6 +228,9 @@ def batch_translate( source_seq_list = [self._source_text_encoder(t) for t in source_texts] + if self.max_src_len: + source_seq_list = [seq[: self.max_src_len] for seq in source_seq_list] + source_seqs, source_padding_mask = pad_seqs(source_seq_list, self._pad_idx) return self._converter.batch_convert(source_seqs, source_padding_mask) From f9804f722e17ace603c91c592ad2f9f4acc1940e Mon Sep 17 00:00:00 2001 From: David Dale Date: Fri, 26 Apr 2024 11:32:28 +0000 Subject: [PATCH 2/3] add a test case for truncation Signed-off-by: David Dale --- tests/integration/models/test_nllb.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/integration/models/test_nllb.py b/tests/integration/models/test_nllb.py index ded0d680e..d5c9d8a49 100644 --- a/tests/integration/models/test_nllb.py +++ b/tests/integration/models/test_nllb.py @@ -6,6 +6,7 @@ from typing import Final +import pytest import torch from fairseq2.generation import BeamSearchSeq2SeqGenerator, TextTranslator @@ -25,7 +26,7 @@ def test_load_dense_distill_600m() -> None: tokenizer = load_nllb_tokenizer(model_name, progress=False) - generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True) + generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True, max_seq_len=128) translator = TextTranslator( generator, tokenizer, source_lang="eng_Latn", target_lang="deu_Latn" @@ -34,3 +35,18 @@ def test_load_dense_distill_600m() -> None: text, _ = translator(ENG_SENTENCE) assert text == DEU_SENTENCE + + # testing that truncation prevents length-related errors + with pytest.raises( + ValueError, match="The input sequence length must be less than or equal" + ): + text, _ = translator(ENG_SENTENCE * 20) + + translator = TextTranslator( + generator, + tokenizer, + source_lang="eng_Latn", + target_lang="deu_Latn", + max_src_len=1024, + ) + text, _ = translator(ENG_SENTENCE * 20) From e980dbc6d7d2ceaa0dc4e02b7d4f427f5e95510f Mon Sep 17 00:00:00 2001 From: David Dale Date: Fri, 26 Apr 2024 12:29:27 +0000 Subject: [PATCH 3/3] rename the parameter Signed-off-by: David Dale --- src/fairseq2/generation/text.py | 18 +++++++++--------- tests/integration/models/test_nllb.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/fairseq2/generation/text.py b/src/fairseq2/generation/text.py index 38267c2b5..d2f90a437 100644 --- a/src/fairseq2/generation/text.py +++ b/src/fairseq2/generation/text.py @@ -153,7 +153,7 @@ class TextTranslator: _converter: SequenceToTextConverter _pad_idx: int _source_text_encoder: TextTokenEncoder - max_src_len: Optional[int] + _max_source_len: Optional[int] def __init__( self, @@ -161,7 +161,7 @@ def __init__( tokenizer: TextTokenizer, source_lang: Optional[str] = None, target_lang: Optional[str] = None, - max_src_len: Optional[int] = None, + max_source_len: Optional[int] = None, ) -> None: """ :param generator: @@ -172,8 +172,8 @@ def __init__( The source language. :param target_lang: The target language. - :param max_src_len: - The number of tokens above which the source sequence gets truncated (None for no truncation) + :param max_source_len: + The number of tokens above which the source sequence gets truncated (None or 0 for no truncation) """ self._converter = SequenceToTextConverter( generator, tokenizer, "translation", target_lang @@ -192,7 +192,7 @@ def __init__( self._source_text_encoder = tokenizer.create_encoder( task="translation", lang=source_lang, mode="source", device=device ) - self.max_src_len = max_src_len + self._max_source_len = max_source_len def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: """ @@ -205,8 +205,8 @@ def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: """ source_seq = self._source_text_encoder(source_text) - if self.max_src_len and source_seq.shape[0] > self.max_src_len: - source_seq = source_seq[: self.max_src_len] + if self._max_source_len and source_seq.shape[0] > self._max_source_len: + source_seq = source_seq[: self._max_source_len] return self._converter(source_seq) @@ -228,8 +228,8 @@ def batch_translate( source_seq_list = [self._source_text_encoder(t) for t in source_texts] - if self.max_src_len: - source_seq_list = [seq[: self.max_src_len] for seq in source_seq_list] + if self._max_source_len: + source_seq_list = [seq[: self._max_source_len] for seq in source_seq_list] source_seqs, source_padding_mask = pad_seqs(source_seq_list, self._pad_idx) diff --git a/tests/integration/models/test_nllb.py b/tests/integration/models/test_nllb.py index d5c9d8a49..acddddce8 100644 --- a/tests/integration/models/test_nllb.py +++ b/tests/integration/models/test_nllb.py @@ -47,6 +47,6 @@ def test_load_dense_distill_600m() -> None: tokenizer, source_lang="eng_Latn", target_lang="deu_Latn", - max_src_len=1024, + max_source_len=1024, ) text, _ = translator(ENG_SENTENCE * 20)