Skip to content

Commit

Permalink
Merge pull request #667 from AXYZdong/master
Browse files Browse the repository at this point in the history
Update batch_eth_mnist.py
  • Loading branch information
Hananel-Hazan authored Mar 21, 2024
2 parents 60c8497 + 3b6f6f4 commit fb32876
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@
# Compute network accuracy according to available classification strategies.
accuracy["all"].append(
100
* torch.sum(label_tensor.long() == all_activity_pred).item()
* torch.sum(label_tensor.long() == all_activity_pred.to(device)).item()
/ len(label_tensor)
)
accuracy["proportion"].append(
100
* torch.sum(label_tensor.long() == proportion_pred).item()
* torch.sum(label_tensor.long() == proportion_pred.to(device)).item()
/ len(label_tensor)
)

Expand Down Expand Up @@ -363,9 +363,9 @@
)

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred.to(device)).item())
accuracy["proportion"] += float(
torch.sum(label_tensor.long() == proportion_pred).item()
torch.sum(label_tensor.long() == proportion_pred.to(device)).item()
)

network.reset_state_variables() # Reset state variables.
Expand Down

0 comments on commit fb32876

Please sign in to comment.