I see this description of the input tensor in the `__call__` function: "inputs: input Tensor, 2D, 1 x input_size."
Shouldn't it rather be "inputs: input Tensor, 2D, batch x input_size"?
The call returns 2D tensors with a batch dimension. Training is so fast on my laptop compared to a normal LSTM that I am starting to doubt whether it actually processes the full batch I am feeding to the cell. I assume it accepts an input of shape batch x output_dim, because the output of the call contains 2D tensors with a batch dimension.
def __call__(self, input_, state=None, scope=None):
    """Run one step of NTM.

    Args:
        inputs: input Tensor, 2D, 1 x input_size.
        state: state Dictionary which contains M, read_w, write_w, read,
            output, hidden.
        scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
        A tuple containing:
        - A 2D, batch x output_dim, Tensor representing the output of the LSTM
          after reading "input_" when previous state was "state".
          Here output_dim is:
              num_proj if num_proj was set,
              num_units otherwise.
        - A 2D, batch x state_size, Tensor representing the new state of LSTM
          after reading "input_" when previous state was "state".
    """
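To illustrate the shape contract being questioned here, a minimal sketch (this is a hypothetical toy cell in NumPy, not the actual NTMCell): when a cell step is implemented as a batched matmul, a (batch, input_size) input yields a (batch, output_dim) output in a single call, so the whole batch really is processed at once, which is consistent with the "batch x output_dim" Returns description rather than the "1 x input_size" Args description.

```python
import numpy as np

# Hypothetical minimal cell, for illustration only.
batch, input_size, output_dim = 4, 10, 7

rng = np.random.default_rng(0)
W = rng.standard_normal((input_size, output_dim))
b = np.zeros(output_dim)

def cell_step(inputs):
    """One step of a toy cell: inputs has shape (batch, input_size)."""
    # The matmul broadcasts over the leading batch dimension,
    # so the whole batch is processed in one call.
    return np.tanh(inputs @ W + b)

x = rng.standard_normal((batch, input_size))
out = cell_step(x)
print(out.shape)  # (4, 7), i.e. batch x output_dim
```

If the cell only accepted a 1 x input_size input, the output here would have a leading dimension of 1, not 4.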
Found in:
https://github.com/carpedm20/NTM-tensorflow/blob/master/ntm_cell.py