Skip to content

Commit

Permalink
add forward return output
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 25, 2022
1 parent 7b5e1e5 commit 264281d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flearn/common/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(self, data, target):

loss = self.criterion(output, target)
iter_acc = self.metrics(output, target)
return loss, iter_acc
return output, loss, iter_acc

def batch(self, data, target):
"""训练/测试每个batch的数据
Expand All @@ -104,7 +104,7 @@ def batch(self, data, target):
float : iter_acc
对应batch的accuracy
"""
loss, iter_acc = self.forward(data, target)
_, loss, iter_acc = self.forward(data, target)
if self.model.training:
loss += self.fed_loss()
self.optimizer.zero_grad()
Expand Down

0 comments on commit 264281d

Please sign in to comment.