diff --git a/automated_test.py b/automated_test.py index 15d51e6..de3ed34 100644 --- a/automated_test.py +++ b/automated_test.py @@ -9,7 +9,7 @@ np.uint8, np.uint16, np.uint32, np.uint64, ] -TEST_TYPES = INT_TYPES + [ np.float32, np.float64 ] +TEST_TYPES = INT_TYPES + [ np.float16, np.float32, np.float64 ] OUT_TYPES = [ np.uint16, np.uint32, np.uint64 ] @@ -1006,6 +1006,11 @@ def test_continuous_ccl_diagonal(order, dtype, connectivity): labels[0,1] = 3 labels[1,1] = 4 + if dtype == np.float16: + with pytest.raises(TypeError) as e_info: + out = cc3d.connected_components(labels, delta=1, connectivity=connectivity) + return + out = cc3d.connected_components(labels, delta=0, connectivity=connectivity) assert np.all(np.unique(labels) == [1,2,3,4]) @@ -1025,6 +1030,11 @@ def test_continuous_ccl_4_6(order, dtype, connectivity): out = cc3d.connected_components(labels, delta=0, connectivity=connectivity) assert np.all(np.unique(labels) == [1,2,3,4]) + if dtype == np.float16: + with pytest.raises(TypeError) as e_info: + out = cc3d.connected_components(labels, delta=1, connectivity=connectivity) + return + out = cc3d.connected_components(labels, delta=1, connectivity=connectivity) assert np.all(out == np.array([ [1, 2], @@ -1049,6 +1059,11 @@ def test_continuous_blocks(dtype, connectivity, order): ) assert np.unique(out).size > 1000 + if dtype == np.float16: + with pytest.raises(TypeError) as e_info: + out = cc3d.connected_components(img, delta=1, connectivity=connectivity) + return + out = cc3d.connected_components( img, connectivity=connectivity, delta=1 ) diff --git a/cc3d.pyx b/cc3d.pyx index 53647dd..88ff769 100644 --- a/cc3d.pyx +++ b/cc3d.pyx @@ -222,7 +222,10 @@ def estimate_provisional_labels(data:np.ndarray) -> Tuple[int,int,int]: first_foreground_row, last_foreground_row ) else: - raise TypeError("Type {} not currently supported.".format(dtype)) + raise TypeError( + f"Type {dtype} is not currently supported. " + f"Supported: bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64" + ) finally: if data.flags.owndata: data.setflags(write=writable) @@ -312,6 +315,12 @@ def connected_components( if not data.flags.c_contiguous and not data.flags.f_contiguous: data = np.copy(data, order=order) + if data.dtype == np.float16: + if delta == 0: + data = data.view(np.uint16) + else: + raise TypeError("float16 is not supported for continuous images (delta != 0).") + shape = list(data.shape) if order == 'C': @@ -541,7 +550,10 @@ def connected_components( &out_labels64[0], N, periodic_boundary ) else: - raise TypeError("Type {} not currently supported.".format(dtype)) + raise TypeError( + f"Type {dtype} is not currently supported. " + f"Supported: bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64" + ) finally: if data.flags.owndata: data.setflags(write=writable)