diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 1865e471..839d4fea 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -1,4 +1,6 @@ """Common utilities for architecture and model building.""" +from typing import List + import torch from torch import nn from torch.nn import functional as F @@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module: return activations[activation] -def get_children_layers(model: torch.nn.Module): +def get_children_layers(model: torch.nn.Module) -> List[nn.Module]: """Recursively retrieves a flattened list of all children modules and submodules within the given model. Args: diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index b445bd6a..d38d7910 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -167,7 +167,7 @@ def __init__( mixup_p: float = 0.0, random_crop_hw: Tuple[int, int] = (0, 0), random_crop_p: float = 0.0, - ): + ) -> None: """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp self.rotation = rotation @@ -282,7 +282,7 @@ def __init__( same_on_batch=True, ) - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: """Return an example dictionary with the augmented image and instances.""" for ex in self.source_dp: if "instance_image" in ex and "instance" in ex: diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index 42b58310..deae59a4 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -1,5 +1,5 @@ """Generate confidence maps.""" -from typing import Optional +from typing import Dict, Optional import sleap_io as sio import torch @@ -10,7 +10,7 @@ def make_confmaps( points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float -): +) -> torch.Tensor: """Make confidence maps from a set of points from a single instance. Args: @@ -70,7 +70,7 @@ def __init__( output_stride: int = 1, instance_key: str = "instance", image_key: str = "instance_image", - ): + ) -> None: """Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride.""" self.source_dp = source_dp self.sigma = sigma @@ -78,7 +78,7 @@ def __init__( self.instance_key = instance_key self.image_key = image_key - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: """Generate confidence maps for each example.""" for example in self.source_dp: instance = example[self.instance_key] diff --git a/sleap_nn/data/general.py b/sleap_nn/data/general.py new file mode 100644 index 00000000..251b2bd0 --- /dev/null +++ b/sleap_nn/data/general.py @@ -0,0 +1,38 @@ +"""General purpose transformers for common pipeline processing tasks.""" +from typing import Callable, Dict, List, Text + +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +class KeyFilter(IterDataPipe): + """Transformer for filtering example keys.""" + + def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None: + """Initialize KeyFilter with the source `DataPipe.""" + self.dp = source_dp + self.keep_keys = keep_keys + + def __iter__(self): + """Return a dictionary filtered for the relevant outputs. + + The input dictionary includes: + - image: the full frame image + - instances: all keypoints of all instances in the frame image + - centroids: all centroids of all instances in the frame image + - instance: the individual instance's keypoints + - instance_bbox: the individual instance's bbox + - instance_image: the individual instance's cropped image + - confidence_maps: the individual instance's heatmap + """ + for example in self.dp: + if self.keep_keys is None: + # If keep_keys is not provided, yield the entire example. + yield example + else: + # Filter the example dictionary based on keep_keys. + filtered_example = { + key: value + for key, value in example.items() + if key in self.keep_keys + } + yield filtered_example diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index 6a2596db..f32cba52 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -1,5 +1,5 @@ """Handle calculation of instance centroids.""" -from typing import Optional +from typing import Dict, Optional import torch from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -79,12 +79,12 @@ def __init__( self, source_dp: IterDataPipe, anchor_ind: Optional[int] = None, - ): + ) -> None: """Initialize InstanceCentroidFinder with the source `DataPipe.""" self.source_dp = source_dp self.anchor_ind = anchor_ind - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: """Add `"centroids"` key to example.""" for ex in self.source_dp: ex["centroids"] = find_centroids( diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 3682505f..f53d0c73 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -93,8 +93,11 @@ def __iter__(self): center_instance = instance - point instance_example = { - "instance_image": instance_image, # (B, channels, crop_height, crop_width) + "instance_image": instance_image.squeeze( + 0 + ), # (B=1, channels, crop_height, crop_width) "instance_bbox": instance_bbox, # (B, 4, 2) "instance": center_instance, # (num_nodes, 2) } - yield instance_example + ex.update(instance_example) + yield ex diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py index 8c150ebc..1b6f61c7 100644 --- a/sleap_nn/data/normalization.py +++ b/sleap_nn/data/normalization.py @@ -1,4 +1,6 @@ """This module implements data pipeline blocks for normalization operations.""" +from typing import Dict + import torch from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -16,11 +18,11 @@ class Normalizer(IterDataPipe): def __init__( self, source_dp: IterDataPipe, - ): + ) -> None: """Initialize the block.""" self.source_dp = source_dp - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: """Return an example dictionary with the augmented image and instance.""" for ex in self.source_dp: if not torch.is_floating_point(ex["image"]): diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py index b75d7cc3..a443d114 100644 --- a/sleap_nn/data/pipelines.py +++ b/sleap_nn/data/pipelines.py @@ -9,42 +9,10 @@ from sleap_nn.data.augmentation import KorniaAugmenter from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.general import KeyFilter from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.providers import LabelsReader - - -class SleapDataset(IterDataPipe): - """Returns image and corresponding heatmap for the DataLoader. - - This class is to return the image and its corresponding confidence map - to load the dataset with the DataLoader class - - Attributes: - source_dp: The previous `DataPipe` with samples that contain an `instances` key. - """ - - def __init__(self, source_dp: IterDataPipe): - """Initialize SleapDataset with the source `DataPipe.""" - self.dp = source_dp - - def __iter__(self): - """Return a dictionary with the relevant outputs. - - This dictionary includes: - - image: the full frame image - - instances: all keypoints of all instances in the frame image - - centroids: all centroids of all instances in the frame image - - instance: the individual instance's keypoints - - instance_bbox: the individual instance's bbox - - instance_image: the individual instance's cropped image - - confidence_maps: the individual instance's heatmap - """ - for example in self.dp: - if len(example["instance_image"].shape) == 4: - example["instance_image"] = example["instance_image"].squeeze(dim=0) - yield example class TopdownConfmapsPipeline: @@ -94,6 +62,6 @@ def make_training_pipeline( sigma=self.data_config.preprocessing.conf_map_gen.sigma, output_stride=self.data_config.preprocessing.conf_map_gen.output_stride, ) - datapipe = SleapDataset(datapipe) + datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys) return datapipe diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index d090775c..f8ce2121 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -1,4 +1,5 @@ """This module implements pipeline blocks for reading input data such as labels.""" +from typing import Dict import numpy as np import sleap_io as sio import torch @@ -16,17 +17,17 @@ class LabelsReader(IterDataPipe): accessed through a torchdata DataPipe """ - def __init__(self, labels: sio.Labels): + def __init__(self, labels: sio.Labels) -> None: """Initialize labels attribute of the class.""" self.labels = labels @classmethod - def from_filename(cls, filename: str): + def from_filename(cls, filename: str) -> "LabelsReader": """Create LabelsReader from a .slp filename.""" labels = sio.load_slp(filename) return cls(labels) - def __iter__(self): + def __iter__(self) -> Dict[str, torch.Tensor]: """Return an example dictionary containing the following elements. "image": A torch.Tensor containing full raw frame image as a uint8 array diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 78114992..662f709a 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -29,9 +29,21 @@ def test_instance_cropper(minimal_instance): datapipe = InstanceCropper(datapipe, (100, 100)) sample = next(iter(datapipe)) + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + ] + # Test shapes. + assert len(sample.keys()) == 6 + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key assert sample["instance"].shape == (2, 2) - assert sample["instance_image"].shape == (1, 1, 100, 100) + assert sample["instance_image"].shape == (1, 100, 100) assert sample["instance_bbox"].shape == (1, 4, 2) # Test samples. diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py index 572c39c3..15d7ebbf 100644 --- a/tests/data/test_pipelines.py +++ b/tests/data/test_pipelines.py @@ -2,10 +2,11 @@ from omegaconf import OmegaConf from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.general import KeyFilter from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper from sleap_nn.data.normalization import Normalizer -from sleap_nn.data.pipelines import SleapDataset, TopdownConfmapsPipeline +from sleap_nn.data.pipelines import TopdownConfmapsPipeline from sleap_nn.data.providers import LabelsReader @@ -15,17 +16,31 @@ def test_sleap_dataset(minimal_instance): datapipe = InstanceCentroidFinder(datapipe) datapipe = InstanceCropper(datapipe, (160, 160)) datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2) - datapipe = SleapDataset(datapipe) + datapipe = KeyFilter(datapipe, keep_keys=None) + + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ] sample = next(iter(datapipe)) - assert len(sample) == 2 - assert sample[0].shape == (1, 160, 160) - assert sample[1].shape == (2, 80, 80) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80) def test_topdownconfmapspipeline(minimal_instance): base_topdown_data_config = OmegaConf.create( { + "general": {"keep_keys": ["instance_image", "confidence_maps"]}, "preprocessing": { "crop_hw": (160, 160), "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, @@ -62,13 +77,19 @@ def test_topdownconfmapspipeline(minimal_instance): data_provider=LabelsReader, filename=minimal_instance ) + gt_sample_keys = ["instance_image", "confidence_maps"] + sample = next(iter(datapipe)) - assert len(sample) == 2 - assert sample[0].shape == (1, 160, 160) - assert sample[1].shape == (2, 80, 80) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80) base_topdown_data_config = OmegaConf.create( { + "general": {"keep_keys": None}, "preprocessing": { "crop_hw": (160, 160), "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, @@ -105,7 +126,20 @@ def test_topdownconfmapspipeline(minimal_instance): data_provider=LabelsReader, filename=minimal_instance ) + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ] + sample = next(iter(datapipe)) - assert len(sample) == 2 - assert sample[0].shape == (1, 160, 160) - assert sample[1].shape == (2, 80, 80) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80)