diff --git a/flearn/common/Trainer.py b/flearn/common/Trainer.py index 0715789..1026145 100644 --- a/flearn/common/Trainer.py +++ b/flearn/common/Trainer.py @@ -82,7 +82,10 @@ def metrics(output, target): def forward(self, data, target): data, target = data.to(self.device), target.to(self.device) output = self.model(data) - return output + + loss = self.criterion(output, target) + iter_acc = self.metrics(output, target) + return output, loss, iter_acc def batch(self, data, target): """训练/测试每个batch的数据 @@ -101,8 +104,7 @@ def batch(self, data, target): float : iter_acc 对应batch的accuracy """ - output = self.forward(data, target) - loss = self.criterion(output, target) + _, loss, iter_acc = self.forward(data, target) if self.model.training: loss += self.fed_loss() @@ -112,7 +114,6 @@ def batch(self, data, target): self.update_info() - iter_acc = self.metrics(output, target) return loss.data.item(), iter_acc @show_f