diff --git a/char-rnn-generation/train.py b/char-rnn-generation/train.py index ab46538..db99df2 100644 --- a/char-rnn-generation/train.py +++ b/char-rnn-generation/train.py @@ -46,12 +46,12 @@ def train(inp, target): for c in range(args.chunk_len): output, hidden = decoder(inp[c], hidden) - loss += criterion(output, target[c]) + loss += criterion(output, target[c].view([1])) loss.backward() decoder_optimizer.step() - return loss.data[0] / args.chunk_len + return loss.data.item() / args.chunk_len def save(): save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'