diff --git a/flearn/common/Trainer.py b/flearn/common/Trainer.py index 348034f..ea181e1 100644 --- a/flearn/common/Trainer.py +++ b/flearn/common/Trainer.py @@ -85,7 +85,7 @@ def forward(self, data, target): loss = self.criterion(output, target) iter_acc = self.metrics(output, target) - return loss, iter_acc + return output, loss, iter_acc def batch(self, data, target): """训练/测试每个batch的数据 @@ -104,7 +104,7 @@ def batch(self, data, target): float : iter_acc 对应batch的accuracy """ - loss, iter_acc = self.forward(data, target) + _, loss, iter_acc = self.forward(data, target) if self.model.training: loss += self.fed_loss() self.optimizer.zero_grad()