Skip to content

Commit

Permalink
feat: support 2d arrays properly in cc3d.statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
William Silversmith committed Jun 22, 2024
1 parent 9d01796 commit 653d53a
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 18 deletions.
45 changes: 44 additions & 1 deletion automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def test_voxel_graph_3d():
assert np.all(gt.T == graph)

@pytest.mark.parametrize("order", ("C", "F"))
def test_statistics(order):
def test_statistics_3d(order):
labels = np.zeros((123,128,125), dtype=np.uint8, order=order)
labels[10:20,10:20,10:20] = 1
labels[40:50,40:50,40:51] = 2
Expand Down Expand Up @@ -978,6 +978,49 @@ def test_statistics(order):
stats = cc3d.statistics(labels)
assert np.all(stats["centroids"][0] == np.array([255.5,255.5,255.5]))

@pytest.mark.parametrize("order", ("C", "F"))
def test_statistics_2d(order):
labels = np.zeros((123,128), dtype=np.uint8, order=order)
labels[10:20,10:20] = 1
labels[40:50,40:50] = 2

stats = cc3d.statistics(labels)
assert stats["voxel_counts"][1] == 100
assert stats["voxel_counts"][2] == 10 * 10

assert np.all(np.isclose(stats["centroids"][1,:], [14.5,14.5]))
assert np.all(np.isclose(stats["centroids"][2,:], [44.5,44.5]))

print(stats["bounding_boxes"])

assert np.all(stats["bounding_boxes"][0] == (slice(0,123), slice(0,128)))
assert np.all(stats["bounding_boxes"][1] == (slice(10,20), slice(10,20)))
assert np.all(stats["bounding_boxes"][2] == (slice(40,50), slice(40,50)))

stats = cc3d.statistics(labels, no_slice_conversion=True)
print(stats["bounding_boxes"])
assert np.all(stats["bounding_boxes"][0] == np.array([ 0, 122, 0, 127 ]))
assert np.all(stats["bounding_boxes"][1] == np.array([ 10, 19, 10, 19 ]))
assert np.all(stats["bounding_boxes"][2] == np.array([ 40, 49, 40, 49 ]))

labels = np.zeros((1,1), dtype=np.uint8, order=order)
stats = cc3d.statistics(labels)
assert len(stats["voxel_counts"]) == 1
assert stats["voxel_counts"][0] == 1

labels = np.zeros((0,1), dtype=np.uint8, order=order)
stats = cc3d.statistics(labels)
assert stats == {
"voxel_counts": None,
"bounding_boxes": None,
"centroids": None
}

labels = np.zeros((512,512), dtype=np.uint8, order=order)
stats = cc3d.statistics(labels)
assert np.all(stats["centroids"][0] == np.array([255.5,255.5]))


@pytest.mark.parametrize("order", ["C", "F"])
def test_statistics_big(order):
labels = np.zeros((50,66000,1), dtype=np.uint8, order=order)
Expand Down
109 changes: 92 additions & 17 deletions cc3d.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -665,26 +665,16 @@ def statistics(
centroids: np.ndarray[float64] (N+1,3)
}
"""
while out_labels.ndim < 3:
while out_labels.ndim < 2:
out_labels = out_labels[..., np.newaxis]
if out_labels.dtype == bool:
out_labels = out_labels.view(np.uint8)
return _statistics(out_labels, no_slice_conversion)
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def _statistics(
cnp.ndarray[UINT, ndim=3] out_labels,
native_bool no_slice_conversion
):
cdef uint64_t voxels = out_labels.size;
cdef uint64_t sx = out_labels.shape[0]
cdef uint64_t sy = out_labels.shape[1]
cdef uint64_t sz = out_labels.shape[2]
cdef uint64_t sz = (out_labels.shape[2] if out_labels.ndim > 2 else 1)
if voxels == 0:
return {
Expand All @@ -704,17 +694,102 @@ def _statistics(
cdef cnp.ndarray[uint32_t] bounding_boxes32
if np.any(np.array([sx,sy,sz]) > np.iinfo(np.uint16).max):
bounding_boxes32 = np.zeros(6 * (N + 1), dtype=np.uint32)
return _statistics_helper(out_labels, no_slice_conversion, bounding_boxes32, N)
if out_labels.ndim == 2:
bounding_boxes32 = np.zeros(4 * (N + 1), dtype=np.uint32)
return _statistics_helper2d(out_labels, no_slice_conversion, bounding_boxes32, N)
else:
bounding_boxes32 = np.zeros(6 * (N + 1), dtype=np.uint32)
return _statistics_helper3d(out_labels, no_slice_conversion, bounding_boxes32, N)
else:
bounding_boxes16 = np.zeros(6 * (N + 1), dtype=np.uint16)
return _statistics_helper(out_labels, no_slice_conversion, bounding_boxes16, N)
if out_labels.ndim == 2:
bounding_boxes32 = np.zeros(4 * (N + 1), dtype=np.uint32)
return _statistics_helper2d(out_labels, no_slice_conversion, bounding_boxes32, N)
else:
bounding_boxes16 = np.zeros(6 * (N + 1), dtype=np.uint16)
return _statistics_helper3d(out_labels, no_slice_conversion, bounding_boxes16, N)
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def _statistics_helper2d(
cnp.ndarray[UINT, ndim=2] out_labels,
native_bool no_slice_conversion,
cnp.ndarray[BBOX_T, ndim=1] bounding_boxes,
uint64_t N
):
cdef uint64_t voxels = out_labels.size;
cdef uint64_t sx = out_labels.shape[0]
cdef uint64_t sy = out_labels.shape[1]
cdef cnp.ndarray[uint32_t] counts = np.zeros(N + 1, dtype=np.uint32)
cdef cnp.ndarray[double] centroids = np.zeros(2 * (N + 1), dtype=np.float64)
cdef BBOX_T x = 0
cdef BBOX_T y = 0
cdef uint64_t label = 0
bounding_boxes[::2] = np.iinfo(bounding_boxes.dtype).max
if out_labels.flags.f_contiguous:
for y in range(sy):
for x in range(sx):
label = <uint64_t>out_labels[x,y]
counts[label] += 1
bounding_boxes[4 * label + 0] = <BBOX_T>min(bounding_boxes[4 * label + 0], x)
bounding_boxes[4 * label + 1] = <BBOX_T>max(bounding_boxes[4 * label + 1], x)
bounding_boxes[4 * label + 2] = <BBOX_T>min(bounding_boxes[4 * label + 2], y)
bounding_boxes[4 * label + 3] = <BBOX_T>max(bounding_boxes[4 * label + 3], y)
centroids[2 * label + 0] += <double>x
centroids[2 * label + 1] += <double>y
else:
for x in range(sx):
for y in range(sy):
label = <uint64_t>out_labels[x,y]
counts[label] += 1
bounding_boxes[4 * label + 0] = <BBOX_T>min(bounding_boxes[4 * label + 0], x)
bounding_boxes[4 * label + 1] = <BBOX_T>max(bounding_boxes[4 * label + 1], x)
bounding_boxes[4 * label + 2] = <BBOX_T>min(bounding_boxes[4 * label + 2], y)
bounding_boxes[4 * label + 3] = <BBOX_T>max(bounding_boxes[4 * label + 3], y)
centroids[2 * label + 0] += <double>x
centroids[2 * label + 1] += <double>y
for label in range(N+1):
if <double>counts[label] == 0:
centroids[2 * label + 0] = float('NaN')
centroids[2 * label + 1] = float('NaN')
else:
centroids[2 * label + 0] /= <double>counts[label]
centroids[2 * label + 1] /= <double>counts[label]
bbxes = bounding_boxes.reshape((N+1,4))
output = {
"voxel_counts": counts,
"bounding_boxes": bbxes,
"centroids": centroids.reshape((N+1,2)),
}
if no_slice_conversion:
return output
slices = []
for xs, xe, ys, ye in bbxes:
if xs < voxels and ys < voxels:
slices.append((slice(xs, int(xe+1)), slice(ys, int(ye+1))))
else:
slices.append(None)
output["bounding_boxes"] = slices
return output
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def _statistics_helper(
def _statistics_helper3d(
cnp.ndarray[UINT, ndim=3] out_labels,
native_bool no_slice_conversion,
cnp.ndarray[BBOX_T, ndim=1] bounding_boxes,
Expand Down

0 comments on commit 653d53a

Please sign in to comment.