Skip to content

Commit

Permalink
added key filter transformer block, removed sleap dataset, added type…
Browse files Browse the repository at this point in the history
… hinting
  • Loading branch information
alckasoc committed Sep 12, 2023
1 parent ae523d1 commit b421937
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 63 deletions.
4 changes: 3 additions & 1 deletion sleap_nn/architectures/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Common utilities for architecture and model building."""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F
Expand Down Expand Up @@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module:
return activations[activation]


def get_children_layers(model: torch.nn.Module):
def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
"""Recursively retrieves a flattened list of all children modules and submodules within the given model.
Args:
Expand Down
4 changes: 2 additions & 2 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
mixup_p: float = 0.0,
random_crop_hw: Tuple[int, int] = (0, 0),
random_crop_p: float = 0.0,
):
) -> None:
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
self.rotation = rotation
Expand Down Expand Up @@ -282,7 +282,7 @@ def __init__(
same_on_batch=True,
)

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
if "instance_image" in ex and "instance" in ex:
Expand Down
8 changes: 4 additions & 4 deletions sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Generate confidence maps."""
from typing import Optional
from typing import Dict, Optional

import sleap_io as sio
import torch
Expand All @@ -10,7 +10,7 @@

def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
):
) -> torch.Tensor:
"""Make confidence maps from a set of points from a single instance.
Args:
Expand Down Expand Up @@ -70,15 +70,15 @@ def __init__(
output_stride: int = 1,
instance_key: str = "instance",
image_key: str = "instance_image",
):
) -> None:
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride
self.instance_key = instance_key
self.image_key = image_key

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example[self.instance_key]
Expand Down
38 changes: 38 additions & 0 deletions sleap_nn/data/general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""General purpose transformers for common pipeline processing tasks."""
from typing import Callable, Dict, List, Text

from torch.utils.data.datapipes.datapipe import IterDataPipe


class KeyFilter(IterDataPipe):
"""Transformer for filtering example keys."""

def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None:
"""Initialize KeyFilter with the source `DataPipe."""
self.dp = source_dp
self.keep_keys = keep_keys

def __iter__(self):
"""Return a dictionary filtered for the relevant outputs.
The input dictionary includes:
- image: the full frame image
- instances: all keypoints of all instances in the frame image
- centroids: all centroids of all instances in the frame image
- instance: the individual instance's keypoints
- instance_bbox: the individual instance's bbox
- instance_image: the individual instance's cropped image
- confidence_maps: the individual instance's heatmap
"""
for example in self.dp:
if self.keep_keys is None:
# If keep_keys is not provided, yield the entire example.
yield example
else:
# Filter the example dictionary based on keep_keys.
filtered_example = {
key: value
for key, value in example.items()
if key in self.keep_keys
}
yield filtered_example
6 changes: 3 additions & 3 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle calculation of instance centroids."""
from typing import Optional
from typing import Dict, Optional

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe
Expand Down Expand Up @@ -79,12 +79,12 @@ def __init__(
self,
source_dp: IterDataPipe,
anchor_ind: Optional[int] = None,
):
) -> None:
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp
self.anchor_ind = anchor_ind

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Add `"centroids"` key to example."""
for ex in self.source_dp:
ex["centroids"] = find_centroids(
Expand Down
7 changes: 5 additions & 2 deletions sleap_nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ def __iter__(self):
center_instance = instance - point

instance_example = {
"instance_image": instance_image, # (B, channels, crop_height, crop_width)
"instance_image": instance_image.squeeze(
0
), # (B=1, channels, crop_height, crop_width)
"instance_bbox": instance_bbox, # (B, 4, 2)
"instance": center_instance, # (num_nodes, 2)
}
yield instance_example
ex.update(instance_example)
yield ex
6 changes: 4 additions & 2 deletions sleap_nn/data/normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This module implements data pipeline blocks for normalization operations."""
from typing import Dict

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

Expand All @@ -16,11 +18,11 @@ class Normalizer(IterDataPipe):
def __init__(
self,
source_dp: IterDataPipe,
):
) -> None:
"""Initialize the block."""
self.source_dp = source_dp

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Return an example dictionary with the augmented image and instance."""
for ex in self.source_dp:
if not torch.is_floating_point(ex["image"]):
Expand Down
36 changes: 2 additions & 34 deletions sleap_nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,10 @@

from sleap_nn.data.augmentation import KorniaAugmenter
from sleap_nn.data.confidence_maps import ConfidenceMapGenerator
from sleap_nn.data.general import KeyFilter
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.instance_cropping import InstanceCropper
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.providers import LabelsReader


class SleapDataset(IterDataPipe):
"""Returns image and corresponding heatmap for the DataLoader.
This class is to return the image and its corresponding confidence map
to load the dataset with the DataLoader class
Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
"""

def __init__(self, source_dp: IterDataPipe):
"""Initialize SleapDataset with the source `DataPipe."""
self.dp = source_dp

def __iter__(self):
"""Return a dictionary with the relevant outputs.
This dictionary includes:
- image: the full frame image
- instances: all keypoints of all instances in the frame image
- centroids: all centroids of all instances in the frame image
- instance: the individual instance's keypoints
- instance_bbox: the individual instance's bbox
- instance_image: the individual instance's cropped image
- confidence_maps: the individual instance's heatmap
"""
for example in self.dp:
if len(example["instance_image"].shape) == 4:
example["instance_image"] = example["instance_image"].squeeze(dim=0)
yield example


class TopdownConfmapsPipeline:
Expand Down Expand Up @@ -94,6 +62,6 @@ def make_training_pipeline(
sigma=self.data_config.preprocessing.conf_map_gen.sigma,
output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
)
datapipe = SleapDataset(datapipe)
datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys)

return datapipe
7 changes: 4 additions & 3 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module implements pipeline blocks for reading input data such as labels."""
from typing import Dict
import numpy as np
import sleap_io as sio
import torch
Expand All @@ -16,17 +17,17 @@ class LabelsReader(IterDataPipe):
accessed through a torchdata DataPipe
"""

def __init__(self, labels: sio.Labels):
def __init__(self, labels: sio.Labels) -> None:
"""Initialize labels attribute of the class."""
self.labels = labels

@classmethod
def from_filename(cls, filename: str):
def from_filename(cls, filename: str) -> "LabelsReader":
"""Create LabelsReader from a .slp filename."""
labels = sio.load_slp(filename)
return cls(labels)

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Return an example dictionary containing the following elements.
"image": A torch.Tensor containing full raw frame image as a uint8 array
Expand Down
14 changes: 13 additions & 1 deletion tests/data/test_instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@ def test_instance_cropper(minimal_instance):
datapipe = InstanceCropper(datapipe, (100, 100))
sample = next(iter(datapipe))

gt_sample_keys = [
"image",
"instances",
"centroids",
"instance",
"instance_bbox",
"instance_image",
]

# Test shapes.
assert len(sample.keys()) == 6
for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
assert gt_key == key
assert sample["instance"].shape == (2, 2)
assert sample["instance_image"].shape == (1, 1, 100, 100)
assert sample["instance_image"].shape == (1, 100, 100)
assert sample["instance_bbox"].shape == (1, 4, 2)

# Test samples.
Expand Down
56 changes: 45 additions & 11 deletions tests/data/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from omegaconf import OmegaConf

from sleap_nn.data.confidence_maps import ConfidenceMapGenerator
from sleap_nn.data.general import KeyFilter
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.instance_cropping import InstanceCropper
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.pipelines import SleapDataset, TopdownConfmapsPipeline
from sleap_nn.data.pipelines import TopdownConfmapsPipeline
from sleap_nn.data.providers import LabelsReader


Expand All @@ -15,17 +16,31 @@ def test_sleap_dataset(minimal_instance):
datapipe = InstanceCentroidFinder(datapipe)
datapipe = InstanceCropper(datapipe, (160, 160))
datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2)
datapipe = SleapDataset(datapipe)
datapipe = KeyFilter(datapipe, keep_keys=None)

gt_sample_keys = [
"image",
"instances",
"centroids",
"instance",
"instance_bbox",
"instance_image",
"confidence_maps",
]

sample = next(iter(datapipe))
assert len(sample) == 2
assert sample[0].shape == (1, 160, 160)
assert sample[1].shape == (2, 80, 80)
assert len(sample.keys()) == len(gt_sample_keys)

for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
assert gt_key == key
assert sample["instance_image"].shape == (1, 160, 160)
assert sample["confidence_maps"].shape == (2, 80, 80)


def test_topdownconfmapspipeline(minimal_instance):
base_topdown_data_config = OmegaConf.create(
{
"general": {"keep_keys": ["instance_image", "confidence_maps"]},
"preprocessing": {
"crop_hw": (160, 160),
"conf_map_gen": {"sigma": 1.5, "output_stride": 2},
Expand Down Expand Up @@ -62,13 +77,19 @@ def test_topdownconfmapspipeline(minimal_instance):
data_provider=LabelsReader, filename=minimal_instance
)

gt_sample_keys = ["instance_image", "confidence_maps"]

sample = next(iter(datapipe))
assert len(sample) == 2
assert sample[0].shape == (1, 160, 160)
assert sample[1].shape == (2, 80, 80)
assert len(sample.keys()) == len(gt_sample_keys)

for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
assert gt_key == key
assert sample["instance_image"].shape == (1, 160, 160)
assert sample["confidence_maps"].shape == (2, 80, 80)

base_topdown_data_config = OmegaConf.create(
{
"general": {"keep_keys": None},
"preprocessing": {
"crop_hw": (160, 160),
"conf_map_gen": {"sigma": 1.5, "output_stride": 2},
Expand Down Expand Up @@ -105,7 +126,20 @@ def test_topdownconfmapspipeline(minimal_instance):
data_provider=LabelsReader, filename=minimal_instance
)

gt_sample_keys = [
"image",
"instances",
"centroids",
"instance",
"instance_bbox",
"instance_image",
"confidence_maps",
]

sample = next(iter(datapipe))
assert len(sample) == 2
assert sample[0].shape == (1, 160, 160)
assert sample[1].shape == (2, 80, 80)
assert len(sample.keys()) == len(gt_sample_keys)

for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())):
assert gt_key == key
assert sample["instance_image"].shape == (1, 160, 160)
assert sample["confidence_maps"].shape == (2, 80, 80)

0 comments on commit b421937

Please sign in to comment.