diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py new file mode 100644 index 00000000..17f87d91 --- /dev/null +++ b/sleap_nn/data/instance_centroids.py @@ -0,0 +1,64 @@ +"""Handle calculation of instance centroids.""" +import torchdata.datapipes.iter as dp +import lightning.pytorch as pl +from typing import Optional +import sleap_io as sio +import numpy as np +import torch + + +def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: + """Find the midpoint of the bounding box of a set of points. + + Args: + points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + i.e., rank >= 2. + + Returns: + The midpoints between the bounds of each set of points. The output will be of + shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the + calculation; a set whose points are all NaN yields a NaN midpoint. + + Notes: + The midpoint is calculated as: + xy_mid = xy_min + ((xy_max - xy_min) / 2) + = ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2) + = (2 * xy_min + xy_max - xy_min) / 2 + = (xy_min + xy_max) / 2 + """ + pts_min = torch.min(torch.where(torch.isnan(points), np.inf, points), dim=-2).values + pts_max = torch.max( + torch.where(torch.isnan(points), -np.inf, points), dim=-2 + ).values + + return (pts_max + pts_min) * 0.5 + + +class InstanceCentroidFinder(dp.IterDataPipe): + """Datapipe for finding centroids of instances. + + This DataPipe will produce examples that contain a 'centroid' key. 
+ + Attributes: + source_dp: the previous `DataPipe` with samples that contain an `instance` + """ + + def __init__( + self, + source_dp: dp.IterDataPipe, + ): + """Initialize InstanceCentroidFinder with the source `DataPipe`.""" + self.source_dp = source_dp + + def __iter__(self): + """Add 'centroid' key to sample.""" + + def find_centroids(sample): + mid_pts = find_points_bbox_midpoint(sample["instance"]) + sample["centroid"] = mid_pts + + return sample + + for sample in self.source_dp: + find_centroids(sample) + yield sample diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 847ade91..6bf0fa6c 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -29,11 +29,11 @@ def from_filename(cls, filename: str): def __iter__(self): """Return a sample containing the following elements. - - a torch.Tensor representing an instance - - a torch.Tensor representing the corresponding image + "image": a torch.Tensor containing the full raw frame image. + "instance": A single instance of the corresponding image. 
+ """ for lf in self.labels: for inst in lf: instance = torch.from_numpy(inst.numpy()) image = torch.from_numpy(lf.image) - yield instance, image + yield {"image": image, "instance": instance} diff --git a/tests/test_instance_centroids.py b/tests/test_instance_centroids.py new file mode 100644 index 00000000..8f706c9d --- /dev/null +++ b/tests/test_instance_centroids.py @@ -0,0 +1,21 @@ +from sleap_nn.data.providers import LabelsReader +import torch +from sleap_nn.data.instance_centroids import ( + InstanceCentroidFinder, + find_points_bbox_midpoint, +) + + +def test_instance_centroids(minimal_instance): + """Test InstanceCentroidFinder. + + Args: + minimal_instance: minimal_instance testing fixture + """ + datapipe = LabelsReader.from_filename(minimal_instance) + datapipe = InstanceCentroidFinder(datapipe) + sample = next(iter(datapipe)) + centroid = sample["centroid"] + centroid = centroid.int() + gt = torch.Tensor([122, 180]).int() + assert torch.equal(centroid, gt) diff --git a/tests/test_providers.py b/tests/test_providers.py index 8a40ebbc..4113b664 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -5,12 +5,13 @@ def test_providers(minimal_instance): - """Test sleap dataset + """Test LabelsReader. Args: minimal_instance: minimal_instance testing fixture """ l = LabelsReader.from_filename(minimal_instance) - instance, image = next(iter(l)) + sample = next(iter(l)) + instance, image = sample["instance"], sample["image"] assert image.shape == torch.Size([384, 384, 1]) assert instance.shape == torch.Size([2, 2])