diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py new file mode 100644 index 00000000..ce6e461b --- /dev/null +++ b/sleap_nn/data/instance_centroids.py @@ -0,0 +1,94 @@ +"""Handle calculation of instance centroids.""" +import torchdata.datapipes.iter as dp +import lightning.pytorch as pl +from typing import Optional +import sleap_io as sio +import numpy as np +import torch + + +def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: + """Find the midpoint of the bounding box of a set of points. + + Args: + points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + i.e., rank >= 2. + + Returns: + The midpoints between the bounds of each set of points. The output will be of + shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the + calculation. + + Notes: + The midpoint is calculated as: + xy_mid = xy_min + ((xy_max - xy_min) / 2) + = ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2) + = (2 * xy_min + xy_max - xy_min) / 2 + = (xy_min + xy_max) / 2 + """ + pts_min = torch.min( + torch.where(torch.isnan(points), torch.inf, points), dim=-2 + ).values + pts_max = torch.max( + torch.where(torch.isnan(points), -torch.inf, points), dim=-2 + ).values + + return (pts_max + pts_min) * 0.5 + + +def find_centroids( + points: torch.Tensor, anchor_ind: Optional[int] = None +) -> torch.Tensor: + """Return centroids, falling back to bounding box midpoints. + + Args: + points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + i.e., rank >= 2. + anchor_ind: The index of the node to use as the anchor for the centroid. If not + provided or if not present in the instance, the midpoint of the bounding box + is used instead. + + Returns: + The centroids of the instances. The output will be of shape (..., 2), reducing + the rank of the input by 1. NaNs will be ignored in the calculation. + """ + if anchor_ind is not None: + centroids = points[..., anchor_ind, :] + else: + centroids = torch.full_like(points[..., 0, :], torch.nan) + + missing_anchors = torch.isnan(centroids).any(dim=-1) + if missing_anchors.any(): + centroids[missing_anchors] = find_points_bbox_midpoint(points[missing_anchors]) + + return centroids + + +class InstanceCentroidFinder(dp.IterDataPipe): + """Datapipe for finding centroids of instances. + + This DataPipe will produce examples that contain a 'centroids' key. + + Attributes: + source_dp: The previous `DataPipe` with samples that contain an `instances` key. + anchor_ind: The index of the node to use as the anchor for the centroid. If not + provided or if not present in the instance, the midpoint of the bounding box + is used instead. + """ + + def __init__( + self, + source_dp: dp.IterDataPipe, + anchor_ind: Optional[int] = None, + ): + """Initialize InstanceCentroidFinder with the source `DataPipe.""" + self.source_dp = source_dp + self.anchor_ind = anchor_ind + + def __iter__(self): + """Add `"centroids"` key to example.""" + for example in self.source_dp: + example["centroids"] = find_centroids( + example["instances"], anchor_ind=self.anchor_ind + ) + yield example diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 847ade91..60ea4999 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -3,6 +3,7 @@ import lightning.pytorch as pl import torch import sleap_io as sio +import numpy as np class LabelsReader(dp.IterDataPipe): @@ -27,13 +28,28 @@ def from_filename(cls, filename: str): return cls(labels) def __iter__(self): - """Return a sample containing the following elements. + """Return an example dictionary containing the following elements. - - a torch.Tensor representing an instance - - a torch.Tensor representing the corresponding image + "image": A torch.Tensor containing full raw frame image as a uint8 array + of shape (1, channels, height, width). + "instances": Keypoint coordinates for all instances in the frame as a + float32 torch.Tensor of shape (1, num_instances, num_nodes, 2). """ for lf in self.labels: + image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW + + instances = [] for inst in lf: - instance = torch.from_numpy(inst.numpy()) - image = torch.from_numpy(lf.image) - yield instance, image + instances.append(inst.numpy()) + instances = np.stack(instances, axis=0) + + # Add singleton time dimension for single frames. + image = np.expand_dims(image, axis=0) # (1, C, H, W) + instances = np.expand_dims( + instances, axis=0 + ) # (1, num_instances, num_nodes, 2) + + yield { + "image": torch.from_numpy(image), + "instances": torch.from_numpy(instances), + } diff --git a/tests/test_instance_centroids.py b/tests/test_instance_centroids.py new file mode 100644 index 00000000..c9de5e03 --- /dev/null +++ b/tests/test_instance_centroids.py @@ -0,0 +1,43 @@ +from sleap_nn.data.providers import LabelsReader +import torch +from sleap_nn.data.instance_centroids import ( + InstanceCentroidFinder, + find_points_bbox_midpoint, + find_centroids, +) + + +def test_instance_centroids(minimal_instance): + """Test InstanceCentroidFinder + + Args: + minimal_instance: minimal_instance testing fixture + """ + + # Undefined anchor_ind + datapipe = LabelsReader.from_filename(minimal_instance) + datapipe = InstanceCentroidFinder(datapipe) + sample = next(iter(datapipe)) + instances = sample["instances"] + centroids = sample["centroids"] + centroids = centroids.int() + gt = torch.Tensor([[[122, 180], [242, 195]]]).int() + assert torch.equal(centroids, gt) + + # Defined anchor_ind + centroids = find_centroids(instances, 1).int() + gt = torch.Tensor([[[152, 158], [278, 203]]]) + assert torch.equal(centroids, gt) + + # Defined anchor_ind, but missing one + partial_instance = torch.Tensor( + [ + [ + [[92.6522, 202.7260], [152.3419, 158.4236], [97.2618, 53.5834]], + [[205.9301, 187.8896], [torch.nan, torch.nan], [201.4264, 75.2373]], + ] + ] + ) + centroids = find_centroids(partial_instance, 1).int() + gt = torch.Tensor([[[152, 158], [203, 131]]]) + assert torch.equal(centroids, gt) diff --git a/tests/test_providers.py b/tests/test_providers.py index 8a40ebbc..cf5dfeb1 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -5,12 +5,13 @@ def test_providers(minimal_instance): - """Test sleap dataset + """Test LabelsReader Args: minimal_instance: minimal_instance testing fixture """ l = LabelsReader.from_filename(minimal_instance) - instance, image = next(iter(l)) - assert image.shape == torch.Size([384, 384, 1]) - assert instance.shape == torch.Size([2, 2]) + sample = next(iter(l)) + instances, image = sample["instances"], sample["image"] + assert image.shape == torch.Size([1, 1, 384, 384]) + assert instances.shape == torch.Size([1, 2, 2, 2])