Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add centroid finder block #7

Merged
merged 4 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Handle calculation of instance centroids."""
import torchdata.datapipes.iter as dp
import lightning.pytorch as pl
from typing import Optional
import sleap_io as sio
import numpy as np
import torch


def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor:
"""Find the midpoint of the bounding box of a set of points.

Args:
points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2),
i.e., rank >= 2.

Returns:
The midpoints between the bounds of each set of points. The output will be of
shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the
calculation.

Notes:
The midpoint is calculated as:
xy_mid = xy_min + ((xy_max - xy_min) / 2)
= ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2)
= (2 * xy_min + xy_max - xy_min) / 2
= (xy_min + xy_max) / 2
"""
pts_min = torch.min(
torch.where(torch.isnan(points), torch.inf, points), dim=-2
).values
pts_max = torch.max(
torch.where(torch.isnan(points), -torch.inf, points), dim=-2
).values

return (pts_max + pts_min) * 0.5


def find_centroids(
points: torch.Tensor, anchor_ind: Optional[int] = None
) -> torch.Tensor:
"""Return centroids, falling back to bounding box midpoints.

Args:
points: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2),
i.e., rank >= 2.
anchor_ind: The index of the node to use as the anchor for the centroid. If not
provided or if not present in the instance, the midpoint of the bounding box
is used instead.

Returns:
The centroids of the instances. The output will be of shape (..., 2), reducing
the rank of the input by 1. NaNs will be ignored in the calculation.
"""
if anchor_ind is not None:
centroids = points[..., anchor_ind, :]
else:
centroids = torch.full_like(points[..., 0, :], torch.nan)

missing_anchors = torch.isnan(centroids).any(dim=-1)
if missing_anchors.any():
centroids[missing_anchors] = find_points_bbox_midpoint(points[missing_anchors])

return centroids


class InstanceCentroidFinder(dp.IterDataPipe):
"""Datapipe for finding centroids of instances.

This DataPipe will produce examples that contain a 'centroids' key.

Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
anchor_ind: The index of the node to use as the anchor for the centroid. If not
provided or if not present in the instance, the midpoint of the bounding box
is used instead.
"""

def __init__(
self,
source_dp: dp.IterDataPipe,
anchor_ind: Optional[int] = None,
):
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp
self.anchor_ind = anchor_ind

def __iter__(self):
"""Add `"centroids"` key to example."""
for example in self.source_dp:
example["centroids"] = find_centroids(
example["instances"], anchor_ind=self.anchor_ind
)
yield example
28 changes: 22 additions & 6 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import lightning.pytorch as pl
import torch
import sleap_io as sio
import numpy as np


class LabelsReader(dp.IterDataPipe):
Expand All @@ -27,13 +28,28 @@ def from_filename(cls, filename: str):
return cls(labels)

def __iter__(self):
"""Return a sample containing the following elements.
"""Return an example dictionary containing the following elements.

- a torch.Tensor representing an instance
- a torch.Tensor representing the corresponding image
"image": A torch.Tensor containing full raw frame image as a uint8 array
of shape (1, channels, height, width).
"instances": Keypoint coordinates for all instances in the frame as a
float32 torch.Tensor of shape (1, num_instances, num_nodes, 2).
"""
for lf in self.labels:
image = np.transpose(lf.image, (2, 0, 1)) # HWC -> CHW

instances = []
for inst in lf:
instance = torch.from_numpy(inst.numpy())
image = torch.from_numpy(lf.image)
yield instance, image
instances.append(inst.numpy())
instances = np.stack(instances, axis=0)

# Add singleton time dimension for single frames.
image = np.expand_dims(image, axis=0) # (1, C, H, W)
instances = np.expand_dims(
instances, axis=0
) # (1, num_instances, num_nodes, 2)

yield {
"image": torch.from_numpy(image),
"instances": torch.from_numpy(instances),
}
43 changes: 43 additions & 0 deletions tests/test_instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from sleap_nn.data.providers import LabelsReader
import torch
from sleap_nn.data.instance_centroids import (
InstanceCentroidFinder,
find_points_bbox_midpoint,
find_centroids,
)


def test_instance_centroids(minimal_instance):
"""Test InstanceCentroidFinder

Args:
minimal_instance: minimal_instance testing fixture
"""

# Undefined anchor_ind
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
sample = next(iter(datapipe))
instances = sample["instances"]
centroids = sample["centroids"]
centroids = centroids.int()
gt = torch.Tensor([[[122, 180], [242, 195]]]).int()
assert torch.equal(centroids, gt)

# Defined anchor_ind
centroids = find_centroids(instances, 1).int()
gt = torch.Tensor([[[152, 158], [278, 203]]])
assert torch.equal(centroids, gt)

# Defined anchor_ind, but missing one
partial_instance = torch.Tensor(
[
[
[[92.6522, 202.7260], [152.3419, 158.4236], [97.2618, 53.5834]],
[[205.9301, 187.8896], [torch.nan, torch.nan], [201.4264, 75.2373]],
]
]
)
centroids = find_centroids(partial_instance, 1).int()
gt = torch.Tensor([[[152, 158], [203, 131]]])
assert torch.equal(centroids, gt)
9 changes: 5 additions & 4 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@


def test_providers(minimal_instance):
"""Test sleap dataset
"""Test LabelsReader

Args:
minimal_instance: minimal_instance testing fixture
"""
l = LabelsReader.from_filename(minimal_instance)
instance, image = next(iter(l))
assert image.shape == torch.Size([384, 384, 1])
assert instance.shape == torch.Size([2, 2])
sample = next(iter(l))
instances, image = sample["instances"], sample["image"]
assert image.shape == torch.Size([1, 1, 384, 384])
assert instances.shape == torch.Size([1, 2, 2, 2])