From 7facb8c992135e2e4fd45ebedbf61fdc15144bbd Mon Sep 17 00:00:00 2001 From: Ruben Date: Tue, 18 Dec 2018 10:09:57 +0100 Subject: [PATCH 1/2] Fix CrossEntropyLoss input dimension --- char-rnn-generation/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/char-rnn-generation/train.py b/char-rnn-generation/train.py index ab46538..9e41d39 100644 --- a/char-rnn-generation/train.py +++ b/char-rnn-generation/train.py @@ -46,7 +46,7 @@ def train(inp, target): for c in range(args.chunk_len): output, hidden = decoder(inp[c], hidden) - loss += criterion(output, target[c]) + loss += criterion(output, target[c].view([1])) loss.backward() decoder_optimizer.step() From dcf1c21dfb040f07501a95634d25395a59a211dd Mon Sep 17 00:00:00 2001 From: Ruben Date: Tue, 18 Dec 2018 10:27:52 +0100 Subject: [PATCH 2/2] Use .item() to acces element in 0-dim tensor --- char-rnn-generation/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/char-rnn-generation/train.py b/char-rnn-generation/train.py index 9e41d39..db99df2 100644 --- a/char-rnn-generation/train.py +++ b/char-rnn-generation/train.py @@ -51,7 +51,7 @@ def train(inp, target): loss.backward() decoder_optimizer.step() - return loss.data[0] / args.chunk_len + return loss.data.item() / args.chunk_len def save(): save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'