Skip to content

Commit

Permalink
Add InstanceCentroid class and corresponding functions
Browse files Browse the repository at this point in the history
index on david/instance-centroids: 8daa796 Add InstanceCentroid class and corresponding functions

remove debug code
  • Loading branch information
davidasamy committed Jul 20, 2023
1 parent 5903df4 commit 07f4a61
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
64 changes: 64 additions & 0 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Handle calculation of instance centroids."""
import torchdata.datapipes.iter as dp
import lightning.pytorch as pl
from typing import Optional
import sleap_io as sio
import numpy as np
import torch


def find_points_bbox_midpoint(points: torch.Tensor) -> torch.Tensor:
"""Find the midpoint of the bounding box of a set of points.
Args:
instances: A torch.Tensor of dtype torch.float32 and of shape (..., n_points, 2),
i.e., rank >= 2.
Returns:
The midpoints between the bounds of each set of points. The output will be of
shape (..., 2), reducing the rank of the input by 1. NaNs will be ignored in the
calculation.
Notes:
The midpoint is calculated as:
xy_mid = xy_min + ((xy_max - xy_min) / 2)
= ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2)
= (2 * xy_min + xy_max - xy_min) / 2
= (xy_min + xy_max) / 2
"""
pts_min = torch.min(torch.where(torch.isnan(points), np.inf, points), dim=-2).values
pts_max = torch.max(
torch.where(torch.isnan(points), -np.inf, points), dim=-2
).values

return (pts_max + pts_min) * 0.5


class InstanceCentroidFinder(dp.IterDataPipe):
"""Datapipe for finding centroids of instances.
This DataPipe will produce examples that have been containing a 'centroid' key.
Attributes:
source_dp: the previous `DataPipe` with samples that contain an `instance`
"""

def __init__(
self,
source_dp: dp.IterDataPipe,
):
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp

def __iter__(self):
"""Add 'centroid' key to sample."""

def find_centroids(sample):
mid_pts = find_points_bbox_midpoint(sample["instance"])
sample["centroid"] = mid_pts

return sample

for sample in self.source_dp:
find_centroids(sample)
yield sample
6 changes: 3 additions & 3 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def from_filename(cls, filename: str):
def __iter__(self):
"""Return a sample containing the following elements.
- a torch.Tensor representing an instance
- a torch.Tensor representing the corresponding image
"image": a torch.Tensor containing full raw frame image.
"instance": A single instance of the corresponding image.
"""
for lf in self.labels:
for inst in lf:
instance = torch.from_numpy(inst.numpy())
image = torch.from_numpy(lf.image)
yield instance, image
yield {"image": image, "instance": instance}
21 changes: 21 additions & 0 deletions tests/test_instance_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sleap_nn.data.providers import LabelsReader
import torch
from sleap_nn.data.instance_centroids import (
InstanceCentroidFinder,
find_points_bbox_midpoint,
)


def test_instance_centroids(minimal_instance):
"""Test InstanceCentroidFinder
Args:
minimal_instance: minimal_instance testing fixture
"""
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
sample = next(iter(datapipe))
centroid = sample["centroid"]
centroid = centroid.int()
gt = torch.Tensor([122, 180]).int()
assert torch.equal(centroid, gt)
5 changes: 3 additions & 2 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@


def test_providers(minimal_instance):
"""Test sleap dataset
"""Test LabelsReader
Args:
minimal_instance: minimal_instance testing fixture
"""
l = LabelsReader.from_filename(minimal_instance)
instance, image = next(iter(l))
sample = next(iter(l))
instance, image = sample["instance"], sample["image"]
assert image.shape == torch.Size([384, 384, 1])
assert instance.shape == torch.Size([2, 2])

0 comments on commit 07f4a61

Please sign in to comment.