Skip to content

Commit

Permalink
Fix accuracy calculation in batch_eth_mnist.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Hananel-Hazan committed Mar 22, 2024
1 parent 775abcf commit d9fed4a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred.to(device)).item())
accuracy["all"] += float(
torch.sum(label_tensor.long() == all_activity_pred.to(device)).item()
)
accuracy["proportion"] += float(
torch.sum(label_tensor.long() == proportion_pred.to(device)).item()
)
Expand Down

0 comments on commit d9fed4a

Please sign in to comment.