From 07f4a610233a4830aa4a559d4367b851b92ed32a Mon Sep 17 00:00:00 2001 From: David Samy Date: Thu, 20 Jul 2023 14:01:05 -0700 Subject: [PATCH 1/4] Add InstanceCentroid class and corresponding functions index on david/instance-centroids: 8daa796 Add InstanceCentroid class and corresponding functions remove debug code --- sleap_nn/data/instance_centroids.py | 64 +++++++++++++++++++++++++++++ sleap_nn/data/providers.py | 6 +-- tests/test_instance_centroids.py | 21 ++++++++++ tests/test_providers.py | 5 ++- 4 files changed, 91 insertions(+), 5 deletions(-) create mode 100644 sleap_nn/data/instance_centroids.py create mode 100644 tests/test_instance_centroids.py diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py new file mode 100644 index 00000000..17f87d91 --- /dev/null +++ b/sleap_nn/data/instance_centroids.py @@ -0,0 +1,64 @@ +"""Handle calculation of instance centroids.""" +import torchdata.datapipes.iter as dp +import lightning.pytorch as pl +from typing import Optional +import sleap_io as sio +import numpy as np +import torch + + +def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: + """Find the midpoint of the bounding box of a set of points. + + Args: + instances: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + i.e., rank >= 2. + + Returns: + The midpoints between the bounds of each set of points. The output will be of + shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the + calculation. + + Notes: + The midpoint is calculated as: + xy_mid = xy_min + ((xy_max - xy_min) / 2) + = ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2) + = (2 * xy_min + xy_max - xy_min) / 2 + = (xy_min + xy_max) / 2 + """ + pts_min = torch.min(torch.where(torch.isnan(points), np.inf, points), dim=-2).values + pts_max = torch.max( + torch.where(torch.isnan(points), -np.inf, points), dim=-2 + ).values + + return (pts_max + pts_min) * 0.5 + + +class InstanceCentroidFinder(dp.IterDataPipe): + """Datapipe for finding centroids of instances. + + This DataPipe will produce examples that have been containing a 'centroid' key. + + Attributes: + source_dp: the previous `DataPipe` with samples that contain an `instance` + """ + + def __init__( + self, + source_dp: dp.IterDataPipe, + ): + """Initialize InstanceCentroidFinder with the source `DataPipe.""" + self.source_dp = source_dp + + def __iter__(self): + """Add 'centroid' key to sample.""" + + def find_centroids(sample): + mid_pts = find_points_bbox_midpoint(sample["instance"]) + sample["centroid"] = mid_pts + + return sample + + for sample in self.source_dp: + find_centroids(sample) + yield sample diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 847ade91..6bf0fa6c 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -29,11 +29,11 @@ def from_filename(cls, filename: str): def __iter__(self): """Return a sample containing the following elements. - - a torch.Tensor representing an instance - - a torch.Tensor representing the corresponding image + "image": a torch.Tensor containing full raw frame image. + "instance": A single instance of the corresponding image. """ for lf in self.labels: for inst in lf: instance = torch.from_numpy(inst.numpy()) image = torch.from_numpy(lf.image) - yield instance, image + yield {"image": image, "instance": instance} diff --git a/tests/test_instance_centroids.py b/tests/test_instance_centroids.py new file mode 100644 index 00000000..8f706c9d --- /dev/null +++ b/tests/test_instance_centroids.py @@ -0,0 +1,21 @@ +from sleap_nn.data.providers import LabelsReader +import torch +from sleap_nn.data.instance_centroids import ( + InstanceCentroidFinder, + find_points_bbox_midpoint, +) + + +def test_instance_centroids(minimal_instance): + """Test InstanceCentroidFinder + + Args: + minimal_instance: minimal_instance testing fixture + """ + datapipe = LabelsReader.from_filename(minimal_instance) + datapipe = InstanceCentroidFinder(datapipe) + sample = next(iter(datapipe)) + centroid = sample["centroid"] + centroid = centroid.int() + gt = torch.Tensor([122, 180]).int() + assert torch.equal(centroid, gt) diff --git a/tests/test_providers.py b/tests/test_providers.py index 8a40ebbc..4113b664 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -5,12 +5,13 @@ def test_providers(minimal_instance): - """Test sleap dataset + """Test LabelsReader Args: minimal_instance: minimal_instance testing fixture """ l = LabelsReader.from_filename(minimal_instance) - instance, image = next(iter(l)) + sample = next(iter(l)) + instance, image = sample["instance"], sample["image"] assert image.shape == torch.Size([384, 384, 1]) assert instance.shape == torch.Size([2, 2]) From a3a4dcb2239067533898d2da0b40277ad68033bc Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 12:48:58 -0700 Subject: [PATCH 2/4] Revise labels provider to yield frame-level examples --- sleap_nn/data/providers.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 6bf0fa6c..6e0a360b 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -3,6 +3,7 @@ import lightning.pytorch as pl import torch import sleap_io as sio +import numpy as np class LabelsReader(dp.IterDataPipe): @@ -27,13 +28,28 @@ def from_filename(cls, filename: str): return cls(labels) def __iter__(self): - """Return a sample containing the following elements. + """Return an example dictionary containing the following elements: - "image": a torch.Tensor containing full raw frame image. - "instance": A single instance of the corresponding image. + "image": A torch.Tensor containing full raw frame image as a uint8 array + of shape (1, channels, height, width). + "instances": Keypoint coordinates for all instances in the frame as a + float32 torch.Tensor of shape (1, num_instances, num_nodes, 2). """ for lf in self.labels: + image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW + + instances = [] for inst in lf: - instance = torch.from_numpy(inst.numpy()) - image = torch.from_numpy(lf.image) - yield {"image": image, "instance": instance} + instances.append(inst.numpy()) + instances = np.stack(instances, axis=0) + + # Add singleton time dimension for single frames. + image = np.expand_dims(image, axis=0) # (1, C, H, W) + instances = np.expand_dims( + instances, axis=0 + ) # (1, num_instances, num_nodes, 2) + + yield { + "image": torch.from_numpy(image), + "instances": torch.from_numpy(instances), + } From 21b59ca140fb37a1d4cde896ce5d4608dc0d3761 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 12:49:23 -0700 Subject: [PATCH 3/4] Add anchor preference logic in centroid finder --- sleap_nn/data/instance_centroids.py | 62 +++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index 17f87d91..a084770c 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -11,7 +11,7 @@ def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: """Find the midpoint of the bounding box of a set of points. Args: - instances: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), i.e., rank >= 2. Returns: @@ -26,39 +26,69 @@ def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor: = (2 * xy_min + xy_max - xy_min) / 2 = (xy_min + xy_max) / 2 """ - pts_min = torch.min(torch.where(torch.isnan(points), np.inf, points), dim=-2).values + pts_min = torch.min( + torch.where(torch.isnan(points), torch.inf, points), dim=-2 + ).values pts_max = torch.max( - torch.where(torch.isnan(points), -np.inf, points), dim=-2 + torch.where(torch.isnan(points), -torch.inf, points), dim=-2 ).values return (pts_max + pts_min) * 0.5 +def find_centroids( + points: torch.Tensor, anchor_ind: Optional[int] = None +) -> torch.Tensor: + """Return centroids, falling back to bounding box midpoints. + + Args: + points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2), + i.e., rank >= 2. + anchor_ind: The index of the node to use as the anchor for the centroid. If not + provided or if not present in the instance, the midpoint of the bounding box + is used instead. + + Returns: + The centroids of the instances. The output will be of shape (..., 2), reducing + the rank of the input by 1. NaNs will be ignored in the calculation. + """ + if anchor_ind is not None: + centroids = points[..., anchor_ind, :] + else: + centroids = torch.full_like(points[..., 0, :], torch.nan) + + missing_anchors = torch.isnan(centroids).any(dim=-1) + if missing_anchors.any(): + centroids[missing_anchors] = find_points_bbox_midpoint(points[missing_anchors]) + + return centroids + + class InstanceCentroidFinder(dp.IterDataPipe): """Datapipe for finding centroids of instances. - This DataPipe will produce examples that have been containing a 'centroid' key. + This DataPipe will produce examples that contain a 'centroids' key. Attributes: - source_dp: the previous `DataPipe` with samples that contain an `instance` + source_dp: The previous `DataPipe` with samples that contain an `instances` key. + anchor_ind: The index of the node to use as the anchor for the centroid. If not + provided or if not present in the instance, the midpoint of the bounding box + is used instead. """ def __init__( self, source_dp: dp.IterDataPipe, + anchor_ind: Optional[int] = None, ): """Initialize InstanceCentroidFinder with the source `DataPipe.""" self.source_dp = source_dp + self.anchor_ind = anchor_ind def __iter__(self): - """Add 'centroid' key to sample.""" - - def find_centroids(sample): - mid_pts = find_points_bbox_midpoint(sample["instance"]) - sample["centroid"] = mid_pts - - return sample - - for sample in self.source_dp: - find_centroids(sample) - yield sample + """Add `"centroids"` key to sample.""" + for example in self.source_dp: + example["centroids"] = find_centroids( + example["instances"], anchor_ind=self.anchor_ind + ) + yield example From 28b8f6dd6c24c44dff8b5a3adf8a59b8e0726908 Mon Sep 17 00:00:00 2001 From: David Samy Date: Fri, 21 Jul 2023 14:46:56 -0700 Subject: [PATCH 4/4] Fix tests --- sleap_nn/data/instance_centroids.py | 2 +- sleap_nn/data/providers.py | 2 +- tests/test_instance_centroids.py | 30 +++++++++++++++++++++++++---- tests/test_providers.py | 6 +++--- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index a084770c..ce6e461b 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -86,7 +86,7 @@ def __init__( self.anchor_ind = anchor_ind def __iter__(self): - """Add `"centroids"` key to sample.""" + """Add `"centroids"` key to example.""" for example in self.source_dp: example["centroids"] = find_centroids( example["instances"], anchor_ind=self.anchor_ind diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 6e0a360b..60ea4999 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -28,7 +28,7 @@ def from_filename(cls, filename: str): return cls(labels) def __iter__(self): - """Return an example dictionary containing the following elements: + """Return an example dictionary containing the following elements. "image": A torch.Tensor containing full raw frame image as a uint8 array of shape (1, channels, height, width). diff --git a/tests/test_instance_centroids.py b/tests/test_instance_centroids.py index 8f706c9d..c9de5e03 100644 --- a/tests/test_instance_centroids.py +++ b/tests/test_instance_centroids.py @@ -3,6 +3,7 @@ from sleap_nn.data.instance_centroids import ( InstanceCentroidFinder, find_points_bbox_midpoint, + find_centroids, ) @@ -12,10 +13,31 @@ def test_instance_centroids(minimal_instance): Args: minimal_instance: minimal_instance testing fixture """ + + # Undefined anchor_ind datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) sample = next(iter(datapipe)) - centroid = sample["centroid"] - centroid = centroid.int() - gt = torch.Tensor([122, 180]).int() - assert torch.equal(centroid, gt) + instances = sample["instances"] + centroids = sample["centroids"] + centroids = centroids.int() + gt = torch.Tensor([[[122, 180], [242, 195]]]).int() + assert torch.equal(centroids, gt) + + # Defined anchor_ind + centroids = find_centroids(instances, 1).int() + gt = torch.Tensor([[[152, 158], [278, 203]]]) + assert torch.equal(centroids, gt) + + # Defined anchor_ind, but missing one + partial_instance = torch.Tensor( + [ + [ + [[92.6522, 202.7260], [152.3419, 158.4236], [97.2618, 53.5834]], + [[205.9301, 187.8896], [torch.nan, torch.nan], [201.4264, 75.2373]], + ] + ] + ) + centroids = find_centroids(partial_instance, 1).int() + gt = torch.Tensor([[[152, 158], [203, 131]]]) + assert torch.equal(centroids, gt) diff --git a/tests/test_providers.py b/tests/test_providers.py index 4113b664..cf5dfeb1 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -12,6 +12,6 @@ def test_providers(minimal_instance): """ l = LabelsReader.from_filename(minimal_instance) sample = next(iter(l)) - instance, image = sample["instance"], sample["image"] - assert image.shape == torch.Size([384, 384, 1]) - assert instance.shape == torch.Size([2, 2]) + instances, image = sample["instances"], sample["image"] + assert image.shape == torch.Size([1, 1, 384, 384]) + assert instances.shape == torch.Size([1, 2, 2, 2])