From de2dfa5db514954d528c0ac9e43220a6f79cdcf9 Mon Sep 17 00:00:00 2001 From: Patrik Date: Wed, 27 Nov 2019 00:14:29 +0100 Subject: [PATCH] #18 fix --- src/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index 1c5b0ea..1ce913d 100644 --- a/src/train.py +++ b/src/train.py @@ -306,7 +306,7 @@ def compute_loss(outputs, targets, ignore_idx): correct = (targets_view == preds) & not_ignore correct = correct.float().sum() - acc = (correct / num_targets).item() + acc = correct / num_targets loss = loss / num_targets ppl = torch.exp(loss).item() @@ -555,7 +555,7 @@ def forward_step(batch): # for more accurate logging acc = reduce_tensor(acc) - return loss, acc, ppl + return loss, acc.item(), ppl def train_step(batch): """