From cec19515dfa62a5b4df953be13c8025bb5230920 Mon Sep 17 00:00:00 2001 From: pavanchhatpar Date: Sun, 3 May 2020 03:54:06 -0400 Subject: [PATCH] add option to specify length in transform --- copynet_tf/search/beam_search.py | 2 +- copynet_tf/vocab.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/copynet_tf/search/beam_search.py b/copynet_tf/search/beam_search.py index 4a22fec..1aa62f8 100644 --- a/copynet_tf/search/beam_search.py +++ b/copynet_tf/search/beam_search.py @@ -15,7 +15,7 @@ def __init__(self, self._end_index = end_index self._max_decoding_steps = max_decoding_steps - def debug(self, name, value): + def debug(self, name, value): self.logger.debug(f"Debug {name} {value}") def _first_token(self, diff --git a/copynet_tf/vocab.py b/copynet_tf/vocab.py index d5d9935..8310834 100644 --- a/copynet_tf/vocab.py +++ b/copynet_tf/vocab.py @@ -111,13 +111,17 @@ def _transform(self, tokenized, token2index, seq_len): res[i, j+1] = token2index[self._end_token] return res - def transform(self, tokenized, namespace): + def transform(self, tokenized, namespace, seq_len=None): if namespace == 'source': + if seq_len is None: + seq_len = self._source_seq_len return self._transform( - tokenized, self._source, self._source_seq_len) + tokenized, self._source, seq_len) elif namespace == 'target': + if seq_len is None: + seq_len = self._target_seq_len return self._transform( - tokenized, self._target, self._target_seq_len) + tokenized, self._target, seq_len) else: raise ValueError(f"Unknown namespace: {namespace}")