From 1977dd8139e05e7c548da98d428f574b49aba5d9 Mon Sep 17 00:00:00 2001 From: wnma3mz Date: Fri, 25 Feb 2022 12:11:11 +0800 Subject: [PATCH] add forward return output --- flearn/common/Trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flearn/common/Trainer.py b/flearn/common/Trainer.py index 0715789..1026145 100644 --- a/flearn/common/Trainer.py +++ b/flearn/common/Trainer.py @@ -82,7 +82,10 @@ def metrics(output, target): def forward(self, data, target): data, target = data.to(self.device), target.to(self.device) output = self.model(data) - return output + + loss = self.criterion(output, target) + iter_acc = self.metrics(output, target) + return output, loss, iter_acc def batch(self, data, target): """训练/测试每个batch的数据 @@ -101,8 +104,7 @@ def batch(self, data, target): float : iter_acc 对应batch的accuracy """ - output = self.forward(data, target) - loss = self.criterion(output, target) + _, loss, iter_acc = self.forward(data, target) if self.model.training: loss += self.fed_loss() @@ -112,7 +114,6 @@ def batch(self, data, target): self.update_info() - iter_acc = self.metrics(output, target) return loss.data.item(), iter_acc @show_f