Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add centroid finder block #7

Merged
merged 4 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Handle calculation of instance centroids."""
import torchdata.datapipes.iter as dp
import lightning.pytorch as pl
from typing import Optional
import sleap_io as sio
import numpy as np
import torch


def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor:
"""Find the midpoint of the bounding box of a set of points.
Args:
instances: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2),
i.e., rank >= 2.
Returns:
The midpoints between the bounds of each set of points. The output will be of
shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the
calculation.
Notes:
The midpoint is calculated as:
xy_mid = xy_min + ((xy_max - xy_min) / 2)
= ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2)
= (2 * xy_min + xy_max - xy_min) / 2
= (xy_min + xy_max) / 2
"""
pts_min = torch.min(torch.where(torch.isnan(points), np.inf, points), dim=-2).values
pts_max = torch.max(
torch.where(torch.isnan(points), -np.inf, points), dim=-2
).values

return (pts_max + pts_min) * 0.5


class InstanceCentroidFinder(dp.IterDataPipe):
"""Datapipe for finding centroids of instances.
This DataPipe will produce examples that have been containing a 'centroid' key.
Attributes:
source_dp: the previous `DataPipe` with samples that contain an `instance`
"""

def __init__(
self,
source_dp: dp.IterDataPipe,
):
davidasamy marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp

def __iter__(self):
"""Add 'centroid' key to sample."""

def find_centroids(sample):
mid_pts = find_points_bbox_midpoint(sample["instance"])
sample["centroid"] = mid_pts

return sample
talmo marked this conversation as resolved.
Show resolved Hide resolved

for sample in self.source_dp:
find_centroids(sample)
yield sample
6 changes: 3 additions & 3 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def from_filename(cls, filename: str):
def __iter__(self):
"""Return a sample containing the following elements.
- a torch.Tensor representing an instance
- a torch.Tensor representing the corresponding image
"image": a torch.Tensor containing full raw frame image.
"instance": A single instance of the corresponding image.
"""
for lf in self.labels:
for inst in lf:
instance = torch.from_numpy(inst.numpy())
image = torch.from_numpy(lf.image)
yield instance, image
yield {"image": image, "instance": instance}
21 changes: 21 additions & 0 deletions tests/test_instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sleap_nn.data.providers import LabelsReader
import torch
from sleap_nn.data.instance_centroids import (
InstanceCentroidFinder,
find_points_bbox_midpoint,
)


def test_instance_centroids(minimal_instance):
"""Test InstanceCentroidFinder
Args:
minimal_instance: minimal_instance testing fixture
"""
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
sample = next(iter(datapipe))
centroid = sample["centroid"]
centroid = centroid.int()
gt = torch.Tensor([122, 180]).int()
assert torch.equal(centroid, gt)
5 changes: 3 additions & 2 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@


def test_providers(minimal_instance):
"""Test sleap dataset
"""Test LabelsReader
Args:
minimal_instance: minimal_instance testing fixture
"""
l = LabelsReader.from_filename(minimal_instance)
instance, image = next(iter(l))
sample = next(iter(l))
instance, image = sample["instance"], sample["image"]
assert image.shape == torch.Size([384, 384, 1])
assert instance.shape == torch.Size([2, 2])