Skip to content

Commit

Permalink
Add more unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
AnanyaKumar committed Aug 23, 2022
1 parent 4398fd8 commit ef4ebb5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Binary file added calibration/.util_test.py.swp
Binary file not shown.
10 changes: 10 additions & 0 deletions calibration/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ def test_ece(self):
true_ece = 4/6.0 * (1 - (0.8+0.85+0.9+1.0)/4)
pred_ece = get_ece(probs, labels, num_bins=3)
self.assertAlmostEqual(pred_ece, true_ece)
probs = [0.6, 0.7, 0.8, 0.9]
labels = [0, 0, 1, 1]
pred_ece = get_ece(probs, labels, num_bins=2)
true_ece = 0.25
self.assertAlmostEqual(pred_ece, true_ece)

@parameterized.expand([
[[0.1], [1], 1, 0.9],
Expand All @@ -266,12 +271,17 @@ def test_ece(self):
[[0.1, 0.1, 0.1, 0.1, 0.7], [0, 1, 0, 0, 1], 2, 0.15*4/5+0.3*1/5],
[[0.1, 0.7, 0.5, 0.9], [0, 1, 0, 1], 2, 0.25],
[[0.1, 0.7, 0.5, 0.9], [0, 1, 0, 1], 4, 0.25],
[[0.6, 0.7, 0.8, 0.9], [0, 0, 1, 1], 2, 0.4],
[[0.1, 0.7, 0.5, 0.9], [0, 1, 1, 1], 2, 0.2],
[[0.1, 0.7, 0.5, 0.9], [0, 1, 1, 1], 4, 0.25],
])
def test_1d_ece_em(self, probs, correct, num_bins, true_ece):
pred_ece = get_ece_em(probs, correct, num_bins=num_bins)
self.assertAlmostEqual(pred_ece, true_ece)
# If number of bins is 1, then test that the regular ece is the same too.
if num_bins == 1:
pred_ece_ew = get_ece(probs, correct, num_bins=num_bins)
self.assertAlmostEqual(pred_ece_ew, true_ece)

def test_missing_classes_ece(self):
pred_ece = get_ece([[0.9,0.1], [0.8,0.2]], [0,0])
Expand Down

0 comments on commit ef4ebb5

Please sign in to comment.