Skip to content

Commit

Permalink
Fix bug when num bins greater than sequence length
Browse files Browse the repository at this point in the history
  • Loading branch information
AnanyaKumar committed Aug 23, 2022
1 parent 79fc0f9 commit 4398fd8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
10 changes: 10 additions & 0 deletions calibration/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def test_get_3_equal_bins_uneven_sizes(self):
bins = np.array(get_equal_bins(probs, num_bins=3))
self.assertTrue(np.allclose(bins, np.array([0.55, 0.85, 1.0])))

def test_equal_bins_more_bins_points(self):
probs = [0.3]
bins = get_equal_bins(probs, num_bins=2)
self.assertEqual(bins, [1.0])
bins = get_equal_bins(probs, num_bins=5)
self.assertEqual(bins, [1.0])
probs = [0.3, 0.5]
bins = get_equal_bins(probs, num_bins=5)
self.assertEqual(bins, [0.4, 1.0])

def test_get_1_equal_prob_bins(self):
probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7]
bins = get_equal_prob_bins(probs, num_bins=1)
Expand Down
2 changes: 2 additions & 0 deletions calibration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def split(sequence: List[T], parts: int) -> List[List[T]]:
def get_equal_bins(probs: List[float], num_bins: int=10) -> Bins:
"""Get bins that contain approximately an equal number of data points."""
sorted_probs = sorted(probs)
if num_bins > len(sorted_probs):
num_bins = len(sorted_probs)
binned_data = split(sorted_probs, num_bins)
bins: Bins = []
for i in range(len(binned_data) - 1):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="uncertainty-calibration",
version="0.1.0",
version="0.1.1",
author="Ananya Kumar",
author_email="[email protected]",
description="Utilities to calibrate model uncertainties and measure calibration.",
Expand Down

0 comments on commit 4398fd8

Please sign in to comment.