Skip to content

Commit

Permalink
Bin points more evenly
Browse files Browse the repository at this point in the history
  • Loading branch information
AnanyaKumar committed Aug 27, 2022
1 parent ef4ebb5 commit 861d077
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
Binary file removed calibration/.util_test.py.swp
Binary file not shown.
20 changes: 19 additions & 1 deletion calibration/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_get_3_equal_bins_lots_of_1s(self):
def test_get_3_equal_bins_uneven_sizes(self):
    """Seven points in three bins should be grouped as evenly as possible."""
    probs = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    bins = np.array(get_equal_bins(probs, num_bins=3))
    # Expected boundaries correspond to groups of sizes 3, 2, 2:
    # 0.55 lies midway between 0.5 and 0.6, and 0.75 midway between 0.7
    # and 0.8.  The former expectation of 0.85 described an uneven 3/3/1
    # grouping and contradicts this one, so only the even split is asserted.
    self.assertTrue(np.allclose(bins, np.array([0.55, 0.75, 1.0])))

def test_equal_bins_more_bins_points(self):
probs = [0.3]
Expand All @@ -63,6 +63,12 @@ def test_equal_bins_more_bins_points(self):
probs = [0.3, 0.5]
bins = get_equal_bins(probs, num_bins=5)
self.assertEqual(bins, [0.4, 1.0])

def test_equal_bin_num_bins(self):
    """split() must return exactly the requested number of parts."""
    for n in [1, 2, 3, 5, 10, 20]:
        # The upper bound is inclusive: split() only requires
        # parts <= len(sequence), so num_bins == n is a valid request and
        # must be covered (with an exclusive bound, n == 1 tests nothing).
        for num_bins in range(1, n + 1):
            bins = split(np.arange(n) / float(n), num_bins)
            self.assertEqual(len(bins), num_bins)

def test_get_1_equal_prob_bins(self):
probs = [0.3, 0.5, 0.2, 0.3, 0.5, 0.7]
Expand Down Expand Up @@ -102,6 +108,18 @@ def test_get_bin_size_1(self):
self.assertEqual(get_bin(0.5, bins), 0)
self.assertEqual(get_bin(1.0, bins), 0)

def test_bin_all_same(self):
    """Check binning when every data point has the same probability.

    For each data size and bin count tried, the first bin is expected to
    hold the whole data set and the remaining requested bins to be empty.
    """
    for count in range(1, 10):
        data = [(0.5, 1.0)] * count
        probs = [pair[0] for pair in data]
        expected = np.array(data)
        for num_bins in range(1, min(3, count)):
            bins = get_equal_bins(probs, num_bins=num_bins)
            binned_data = bin(data, bins)
            first_bin = np.array(binned_data[0])
            self.assertTrue(np.all(first_bin == expected))
            # Every bin after the first must have received no points.
            for idx in range(1, num_bins):
                self.assertEqual(len(binned_data[idx]), 0)

def test_bin(self):
data = [(0.3, 1.0), (0.5, 0.0), (0.2, 1.0), (0.3, 0.0), (0.5, 1.0), (0.7, 0.0)]
bins = [0.4, 1.0]
Expand Down
9 changes: 4 additions & 5 deletions calibration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

def split(sequence: List[T], parts: int) -> List[List[T]]:
    """Split ``sequence`` into exactly ``parts`` contiguous chunks.

    The chunks are produced by ``np.array_split``, so their sizes differ by
    at most one and the larger chunks come first.

    Args:
        sequence: The items to partition; order is preserved.
        parts: Number of chunks to return; must not exceed ``len(sequence)``.

    Returns:
        A list of ``parts`` lists.  Because the data passes through NumPy,
        the elements are NumPy values rather than the original objects.

    Raises:
        AssertionError: If ``parts`` is greater than ``len(sequence)``.
    """
    assert parts <= len(sequence)
    # A fixed chunk size of ceil(len / parts) can exhaust the sequence early
    # and yield fewer than `parts` chunks (5 items in 4 parts -> 2, 2, 1).
    # np.array_split always returns exactly `parts` sections.
    array_splits = np.array_split(sequence, parts)
    splits = [list(chunk) for chunk in array_splits]
    assert len(splits) == parts
    return splits

def get_equal_bins(probs: List[float], num_bins: int=10) -> Bins:
"""Get bins that contain approximately an equal number of data points."""
Expand Down

0 comments on commit 861d077

Please sign in to comment.