Skip to content

Commit

Permalink
add option to specify length in transform
Browse files Browse the repository at this point in the history
  • Loading branch information
pavanchhatpar committed May 3, 2020
1 parent 80ebd98 commit cec1951
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion copynet_tf/search/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self,
self._end_index = end_index
self._max_decoding_steps = max_decoding_steps

def debug(self, name, value):
def debug(self, name, value):
self.logger.debug(f"Debug {name} {value}")

def _first_token(self,
Expand Down
10 changes: 7 additions & 3 deletions copynet_tf/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ def _transform(self, tokenized, token2index, seq_len):
res[i, j+1] = token2index[self._end_token]
return res

def transform(self, tokenized, namespace):
def transform(self, tokenized, namespace, seq_len=None):
if namespace == 'source':
if seq_len is None:
seq_len = self._source_seq_len
return self._transform(
tokenized, self._source, self._source_seq_len)
tokenized, self._source, seq_len)
elif namespace == 'target':
if seq_len is None:
seq_len = self._target_seq_len
return self._transform(
tokenized, self._target, self._target_seq_len)
tokenized, self._target, seq_len)
else:
raise ValueError(f"Unknown namespace: {namespace}")

Expand Down

0 comments on commit cec1951

Please sign in to comment.