add truncation in text translator #472
Conversation
Signed-off-by: David Dale <[email protected]>
Just left three nit comments.
src/fairseq2/generation/text.py (Outdated)

@@ -200,6 +205,9 @@ def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]:
        """
        source_seq = self._source_text_encoder(source_text)

        if self.max_src_len and source_seq.shape[0] > self.max_src_len:
Although it would be a very edge case: if `max_src_len` is 0, then the first condition `if self.max_src_len` will be false, but maybe the user deliberately wants to trim to zero length (although that does not make much sense, it is still expected behavior, since our parameter contract says that this parameter is ignored only if it is `None`). The "right" check would be `if self.max_src_len is not None and ...`.
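A minimal sketch of the distinction the review raises (not the fairseq2 source; function names here are illustrative): in Python, 0 is falsy, so a bare truthiness check silently disables truncation for `max_src_len=0`, while an explicit `is not None` check treats 0 as a real limit.

```python
def should_truncate_truthy(max_src_len, seq_len):
    # 0 is falsy, so max_src_len == 0 never triggers truncation here.
    return bool(max_src_len) and seq_len > max_src_len

def should_truncate_explicit(max_src_len, seq_len):
    # Only None disables truncation; 0 is treated as a real limit.
    return max_src_len is not None and seq_len > max_src_len

assert should_truncate_truthy(0, 5) is False
assert should_truncate_explicit(0, 5) is True
```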
If the source is trimmed to 0 elements, the translation is going to fail, because the encoder would consume an empty tensor. Therefore, I don't want to apply truncation to zero, and instead will change the parameter contract to explicitly truncate only to a positive number of tokens.
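The resolution above could look roughly like the following sketch (a hypothetical standalone helper, not the actual PR code): reject non-positive limits up front so the contract "truncate only to a positive number of tokens" is enforced explicitly.

```python
def maybe_truncate(tokens, max_src_len):
    # Per the revised contract: None disables truncation entirely,
    # and any non-positive limit is rejected rather than producing
    # an empty sequence the encoder could not consume.
    if max_src_len is not None and max_src_len <= 0:
        raise ValueError("max_src_len must be a positive integer or None.")
    if max_src_len is not None and len(tokens) > max_src_len:
        return tokens[:max_src_len]
    return tokens

assert maybe_truncate([1, 2, 3, 4], 2) == [1, 2]
assert maybe_truncate([1, 2], None) == [1, 2]
```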
Signed-off-by: David Dale <[email protected]>
What does this PR do? Please describe:

This PR allows bypassing the failure of `TextTranslator` when the number of source tokens exceeds the maximal length supported by the translation model. Since tokenization happens within the `TextTranslator`, truncation, if implemented, should also happen there. The PR therefore adds a `max_src_len` argument to it and enables truncation if this argument is set.

Does your PR introduce any breaking changes? If yes, please list them:
None
Check list: