Skip to content

Commit

Permalink
add forward return output
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 25, 2022
1 parent c3d644a commit 1977dd8
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions flearn/common/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def metrics(output, target):
def forward(self, data, target):
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
return output

loss = self.criterion(output, target)
iter_acc = self.metrics(output, target)
return output, loss, iter_acc

def batch(self, data, target):
"""训练/测试每个batch的数据
Expand All @@ -101,8 +104,7 @@ def batch(self, data, target):
float : iter_acc
对应batch的accuracy
"""
output = self.forward(data, target)
loss = self.criterion(output, target)
_, loss, iter_acc = self.forward(data, target)

if self.model.training:
loss += self.fed_loss()
Expand All @@ -112,7 +114,6 @@ def batch(self, data, target):

self.update_info()

iter_acc = self.metrics(output, target)
return loss.data.item(), iter_acc

@show_f
Expand Down

0 comments on commit 1977dd8

Please sign in to comment.