Skip to content

Commit

Permalink
feat: support float16 for labeled images
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed May 10, 2024
1 parent 24aee1e commit bd4f1ff
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
17 changes: 16 additions & 1 deletion automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
np.uint8, np.uint16, np.uint32, np.uint64,
]

TEST_TYPES = INT_TYPES + [ np.float32, np.float64 ]
TEST_TYPES = INT_TYPES + [ np.float16, np.float32, np.float64 ]

OUT_TYPES = [ np.uint16, np.uint32, np.uint64 ]

Expand Down Expand Up @@ -1006,6 +1006,11 @@ def test_continuous_ccl_diagonal(order, dtype, connectivity):
labels[0,1] = 3
labels[1,1] = 4

if dtype == np.float16:
with pytest.raises(TypeError) as e_info:
out = cc3d.connected_components(labels, delta=1, connectivity=connectivity)
return

out = cc3d.connected_components(labels, delta=0, connectivity=connectivity)
assert np.all(np.unique(labels) == [1,2,3,4])

Expand All @@ -1025,6 +1030,11 @@ def test_continuous_ccl_4_6(order, dtype, connectivity):
out = cc3d.connected_components(labels, delta=0, connectivity=connectivity)
assert np.all(np.unique(labels) == [1,2,3,4])

if dtype == np.float16:
with pytest.raises(TypeError) as e_info:
out = cc3d.connected_components(labels, delta=1, connectivity=connectivity)
return

out = cc3d.connected_components(labels, delta=1, connectivity=connectivity)
assert np.all(out == np.array([
[1, 2],
Expand All @@ -1049,6 +1059,11 @@ def test_continuous_blocks(dtype, connectivity, order):
)
assert np.unique(out).size > 1000

if dtype == np.float16:
with pytest.raises(TypeError) as e_info:
out = cc3d.connected_components(img, delta=1, connectivity=connectivity)
return

out = cc3d.connected_components(
img, connectivity=connectivity, delta=1
)
Expand Down
16 changes: 14 additions & 2 deletions cc3d.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def estimate_provisional_labels(data:np.ndarray) -> Tuple[int,int,int]:
first_foreground_row, last_foreground_row
)
else:
raise TypeError("Type {} not currently supported.".format(dtype))
raise TypeError(
f"Type {dtype} is not currently supported. "
f"Supported: bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64"
)
finally:
if data.flags.owndata:
data.setflags(write=writable)
Expand Down Expand Up @@ -312,6 +315,12 @@ def connected_components(
if not data.flags.c_contiguous and not data.flags.f_contiguous:
data = np.copy(data, order=order)
if data.dtype == np.float16:
if delta == 0:
data = data.view(np.uint16)
else:
raise TypeError("float16 is not supported for continuous images (delta != 0).")
shape = list(data.shape)
if order == 'C':
Expand Down Expand Up @@ -541,7 +550,10 @@ def connected_components(
<uint64_t*>&out_labels64[0], N, periodic_boundary
)
else:
raise TypeError("Type {} not currently supported.".format(dtype))
raise TypeError(
f"Type {dtype} is not currently supported. "
f"Supported: bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64"
)
finally:
if data.flags.owndata:
data.setflags(write=writable)
Expand Down

0 comments on commit bd4f1ff

Please sign in to comment.