Skip to content

Commit

Permalink
remove criterion and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 25, 2022
1 parent 264281d commit c3d644a
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions flearn/common/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ def metrics(output, target):
def forward(self, data, target):
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)

loss = self.criterion(output, target)
iter_acc = self.metrics(output, target)
return output, loss, iter_acc
return output

def batch(self, data, target):
"""训练/测试每个batch的数据
Expand All @@ -104,14 +101,18 @@ def batch(self, data, target):
float : iter_acc
对应batch的accuracy
"""
_, loss, iter_acc = self.forward(data, target)
output = self.forward(data, target)
loss = self.criterion(output, target)

if self.model.training:
loss += self.fed_loss()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

self.update_info()

iter_acc = self.metrics(output, target)
return loss.data.item(), iter_acc

@show_f
Expand Down

0 comments on commit c3d644a

Please sign in to comment.