diff --git a/sleap_nn/architectures/common.py b/sleap_nn/architectures/common.py index 1865e471..839d4fea 100644 --- a/sleap_nn/architectures/common.py +++ b/sleap_nn/architectures/common.py @@ -1,4 +1,6 @@ """Common utilities for architecture and model building.""" +from typing import List + import torch from torch import nn from torch.nn import functional as F @@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module: return activations[activation] -def get_children_layers(model: torch.nn.Module): +def get_children_layers(model: torch.nn.Module) -> List[nn.Module]: """Recursively retrieves a flattened list of all children modules and submodules within the given model. Args: diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index 7602981c..6db595c1 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -1,5 +1,5 @@ """This module implements data pipeline blocks for augmentation operations.""" -from typing import Any, Dict, Optional, Text, Tuple, Union +from typing import Any, Dict, Iterator, Optional, Text, Tuple, Union import kornia as K import torch @@ -100,9 +100,6 @@ class KorniaAugmenter(IterDataPipe): Attributes: source_dp: The input `IterDataPipe` with examples that contain `"instances"` and `"image"` keys. - crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], - then out_h = size[0], out_w = size[1]. - crop_p: Probability of applying random crop. rotation: Angles in degrees as a scalar float of the amount of rotation. A random angle in `(-rotation, rotation)` will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. @@ -125,11 +122,14 @@ class KorniaAugmenter(IterDataPipe): contrast_p: Probability of applying random contrast. brightness: The brightness factor to apply Default: `(1.0, 1.0)`. brightness_p: Probability of applying random brightness. - erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`. - erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`. + erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`. + erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`. erase_p: Probability of applying random erase. mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`. mixup_p: Probability of applying random mixup v2. + random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], + then out_h = size[0], out_w = size[1]. + random_crop_p: Probability of applying random crop. Notes: This block expects the "image" and "instances" keys to be present in the input @@ -150,24 +150,24 @@ def __init__( rotation: Optional[float] = 15.0, scale: Optional[float] = 0.05, translate: Optional[Tuple[float, float]] = (0.02, 0.02), - affine_p: float = 0.5, + affine_p: float = 0.0, uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04), - uniform_noise_p: float = 0.5, + uniform_noise_p: float = 0.0, gaussian_noise_mean: Optional[float] = 0.02, gaussian_noise_std: Optional[float] = 0.004, - gaussian_noise_p: float = 0.5, + gaussian_noise_p: float = 0.0, contrast: Optional[Tuple[float, float]] = (0.5, 2.0), - contrast_p: float = 0.5, + contrast_p: float = 0.0, brightness: Optional[float] = 0.0, - brightness_p: float = 0.5, + brightness_p: float = 0.0, erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01), erase_ratio: Optional[Tuple[float, float]] = (1, 1), - erase_p: float = 0.5, + erase_p: float = 0.0, mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None, - mixup_p: float = 0.5, - crop_hw: Tuple[int, int] = (0, 0), - crop_p: float = 0.0, - ): + mixup_p: float = 0.0, + random_crop_hw: Tuple[int, int] = (0, 0), + random_crop_p: float = 0.0, + ) -> None: """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp self.rotation = rotation @@ -188,8 +188,8 @@ def __init__( self.erase_p = erase_p self.mixup_lambda = mixup_lambda self.mixup_p = mixup_p - self.crop_hw = crop_hw - self.crop_p = crop_p + self.random_crop_hw = random_crop_hw + self.random_crop_p = random_crop_p aug_stack = [] if self.affine_p > 0: @@ -259,19 +259,21 @@ def __init__( same_on_batch=True, ) ) - if self.crop_p > 0: - if self.crop_hw[0] > 0 and self.crop_hw[1] > 0: + if self.random_crop_p > 0: + if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0: aug_stack.append( K.augmentation.RandomCrop( - size=self.crop_hw, + size=self.random_crop_hw, pad_if_needed=True, - p=self.crop_p, + p=self.random_crop_p, keepdim=True, same_on_batch=True, ) ) else: - raise ValueError(f"crop_hw height and width must be greater than 0.") + raise ValueError( + f"random_crop_hw height and width must be greater than 0." + ) self.augmenter = AugmentationSequential( *aug_stack, @@ -280,12 +282,31 @@ def __init__( same_on_batch=True, ) - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary with the augmented image and instances.""" for ex in self.source_dp: - inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2) - image, instances = ex["image"], ex["instances"].reshape( - inst_shape[0], -1, 2 - ) - aug_image, aug_instances = self.augmenter(image, instances) - yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)} + if "instance_image" in ex and "instance" in ex: + inst_shape = ex["instance"].shape + # (B, channels, height, width), (1, num_nodes, 2) + image, instances = ex["instance_image"], ex["instance"].unsqueeze(0) + aug_image, aug_instances = self.augmenter(image, instances) + ex.update( + { + "instance_image": aug_image, + "instance": aug_instances.reshape(*inst_shape), + } + ) + elif "image" in ex and "instances" in ex: + inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2) + image, instances = ex["image"], ex["instances"].reshape( + inst_shape[0], -1, 2 + ) # (B, channels, height, width), (B, num_instances x num_nodes, 2) + + aug_image, aug_instances = self.augmenter(image, instances) + ex.update( + { + "image": aug_image, + "instances": aug_instances.reshape(*inst_shape), + } + ) + yield ex diff --git a/sleap_nn/data/confidence_maps.py b/sleap_nn/data/confidence_maps.py index 42b58310..dbdbf722 100644 --- a/sleap_nn/data/confidence_maps.py +++ b/sleap_nn/data/confidence_maps.py @@ -1,5 +1,5 @@ """Generate confidence maps.""" -from typing import Optional +from typing import Dict, Iterator, Optional import sleap_io as sio import torch @@ -10,7 +10,7 @@ def make_confmaps( points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float -): +) -> torch.Tensor: """Make confidence maps from a set of points from a single instance. Args: @@ -70,7 +70,7 @@ def __init__( output_stride: int = 1, instance_key: str = "instance", image_key: str = "instance_image", - ): + ) -> None: """Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride.""" self.source_dp = source_dp self.sigma = sigma @@ -78,7 +78,7 @@ def __init__( self.instance_key = instance_key self.image_key = image_key - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Generate confidence maps for each example.""" for example in self.source_dp: instance = example[self.instance_key] diff --git a/sleap_nn/data/general.py b/sleap_nn/data/general.py new file mode 100644 index 00000000..06c9b7de --- /dev/null +++ b/sleap_nn/data/general.py @@ -0,0 +1,39 @@ +"""General purpose transformers for common pipeline processing tasks.""" +from typing import Callable, Dict, Iterator, List, Text + +import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +class KeyFilter(IterDataPipe): + """Transformer for filtering example keys.""" + + def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None: + """Initialize KeyFilter with the source `DataPipe.""" + self.dp = source_dp + self.keep_keys = set(keep_keys) if keep_keys else None + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + """Return a dictionary filtered for the relevant outputs. + + The input dictionary includes: + - image: the full frame image + - instances: all keypoints of all instances in the frame image + - centroids: all centroids of all instances in the frame image + - instance: the individual instance's keypoints + - instance_bbox: the individual instance's bbox + - instance_image: the individual instance's cropped image + - confidence_maps: the individual instance's heatmap + """ + for example in self.dp: + if self.keep_keys is None: + # If keep_keys is not provided, yield the entire example. + yield example + else: + # Filter the example dictionary based on keep_keys. + filtered_example = { + key: value + for key, value in example.items() + if key in self.keep_keys + } + yield filtered_example diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index b15deb91..6240fd35 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -1,5 +1,5 @@ """Handle calculation of instance centroids.""" -from typing import Optional +from typing import Dict, Iterator, Optional import torch from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -79,15 +79,15 @@ def __init__( self, source_dp: IterDataPipe, anchor_ind: Optional[int] = None, - ): + ) -> None: """Initialize InstanceCentroidFinder with the source `DataPipe.""" self.source_dp = source_dp self.anchor_ind = anchor_ind - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Add `"centroids"` key to example.""" - for example in self.source_dp: - example["centroids"] = find_centroids( - example["instances"], anchor_ind=self.anchor_ind + for ex in self.source_dp: + ex["centroids"] = find_centroids( + ex["instances"], anchor_ind=self.anchor_ind ) - yield example + yield ex diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 81f02056..6f57599e 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -1,5 +1,5 @@ """Handle cropping of instances.""" -from typing import Optional +from typing import Dict, Iterator, Optional, Tuple import numpy as np import sleap_io as sio @@ -58,50 +58,46 @@ class InstanceCropper(IterDataPipe): Attributes: source_dp: The previous `DataPipe` with samples that contain an `instances` key. - crop_width: Width of the crop in pixels - crop_height: Height of the crop in pixels + crop_hw: Height and Width of the crop in pixels """ - def __init__( - self, - source_dp: IterDataPipe, - crop_width: int, - crop_height: int, - ): + def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]) -> None: """Initialize InstanceCropper with the source `DataPipe.""" self.source_dp = source_dp - self.crop_width = crop_width - self.crop_height = crop_height + self.crop_hw = crop_hw - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Generate instance cropped examples.""" - for example in self.source_dp: - image = example["image"] # (frames, channels, height, width) - instances = example["instances"] # (frames, n_instances, n_nodes, 2) - centroids = example["centroids"] # (frames, n_instances, 2) + for ex in self.source_dp: + image = ex["image"] # (B, channels, height, width) + instances = ex["instances"] # (B, n_instances, num_nodes, 2) + centroids = ex["centroids"] # (B, n_instances, 2) for instance, centroid in zip(instances[0], centroids[0]): # Generate bounding boxes from centroid. - bbox = torch.unsqueeze( - make_centered_bboxes(centroid, self.crop_height, self.crop_width), 0 - ) # (frames, 4, 2) + instance_bbox = torch.unsqueeze( + make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0 + ) # (B, 4, 2) - box_size = (self.crop_height, self.crop_width) + box_size = (self.crop_hw[0], self.crop_hw[1]) - # Generate cropped image of shape (frames, channels, crop_height, crop_width) + # Generate cropped image of shape (B, channels, crop_height, crop_width) instance_image = crop_and_resize( image, - boxes=bbox, + boxes=instance_bbox, size=box_size, ) # Access top left point (x,y) of bounding box and subtract this offset from # position of nodes. - point = bbox[0][0] + point = instance_bbox[0][0] center_instance = instance - point instance_example = { - "instance_image": instance_image, # (frames, channels, crop_height, crop_width) - "bbox": bbox, # (frames, 4, 2) - "instance": center_instance, # (n_instances, 2) + "instance_image": instance_image.squeeze( + 0 + ), # (B=1, channels, crop_height, crop_width) + "instance_bbox": instance_bbox, # (B, 4, 2) + "instance": center_instance, # (num_nodes, 2) } - yield instance_example + ex.update(instance_example) + yield ex diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py index 8c150ebc..ae7a7b6d 100644 --- a/sleap_nn/data/normalization.py +++ b/sleap_nn/data/normalization.py @@ -1,4 +1,6 @@ """This module implements data pipeline blocks for normalization operations.""" +from typing import Dict, Iterator + import torch from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -16,11 +18,11 @@ class Normalizer(IterDataPipe): def __init__( self, source_dp: IterDataPipe, - ): + ) -> None: """Initialize the block.""" self.source_dp = source_dp - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary with the augmented image and instance.""" for ex in self.source_dp: if not torch.is_floating_point(ex["image"]): diff --git a/sleap_nn/data/pipelines.py b/sleap_nn/data/pipelines.py new file mode 100644 index 00000000..a443d114 --- /dev/null +++ b/sleap_nn/data/pipelines.py @@ -0,0 +1,67 @@ +"""This module defines high level pipeline configurations from providers/transformers. + +This allows for convenient ways to configure individual variants of common pipelines, as +well as to define training vs inference versions based on the same configurations. +""" +import torch +from omegaconf.dictconfig import DictConfig +from torch.utils.data.datapipes.datapipe import IterDataPipe + +from sleap_nn.data.augmentation import KorniaAugmenter +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.general import KeyFilter +from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper +from sleap_nn.data.normalization import Normalizer + + +class TopdownConfmapsPipeline: + """Pipeline builder for instance-centered confidence map models. + + Attributes: + data_config: Data-related configuration. + """ + + def __init__(self, data_config: DictConfig) -> None: + """Initialize the data config.""" + self.data_config = data_config + + def make_training_pipeline( + self, data_provider: IterDataPipe, filename: str + ) -> IterDataPipe: + """Create training pipeline with input data only. + + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + filename: A string path to the name of the `.slp` file. + + Returns: + An `IterDataPipe` instance configured to produce input examples. + """ + datapipe = data_provider.from_filename(filename=filename) + datapipe = Normalizer(datapipe) + + datapipe = InstanceCentroidFinder(datapipe) + datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw) + + if self.data_config.augmentation_config.random_crop.random_crop_p: + datapipe = KorniaAugmenter( + datapipe, + random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw, + random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p, + ) + + if self.data_config.augmentation_config.use_augmentations: + datapipe = KorniaAugmenter( + datapipe, **dict(self.data_config.augmentation_config.augmentations) + ) + + datapipe = ConfidenceMapGenerator( + datapipe, + sigma=self.data_config.preprocessing.conf_map_gen.sigma, + output_stride=self.data_config.preprocessing.conf_map_gen.output_stride, + ) + datapipe = KeyFilter(datapipe, keep_keys=self.data_config.general.keep_keys) + + return datapipe diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index d090775c..6d778102 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -1,4 +1,6 @@ """This module implements pipeline blocks for reading input data such as labels.""" +from typing import Dict, Iterator + import numpy as np import sleap_io as sio import torch @@ -16,17 +18,17 @@ class LabelsReader(IterDataPipe): accessed through a torchdata DataPipe """ - def __init__(self, labels: sio.Labels): + def __init__(self, labels: sio.Labels) -> None: """Initialize labels attribute of the class.""" self.labels = labels @classmethod - def from_filename(cls, filename: str): + def from_filename(cls, filename: str) -> "LabelsReader": """Create LabelsReader from a .slp filename.""" labels = sio.load_slp(filename) return cls(labels) - def __iter__(self): + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary containing the following elements. "image": A torch.Tensor containing full raw frame image as a uint8 array diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index a9e907a2..a5dbf9a2 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -47,8 +47,8 @@ def test_kornia_augmentation(minimal_instance): erase_p=1.0, mixup_p=1.0, mixup_lambda=(0.0, 1.0), - crop_hw=(384, 384), - crop_p=1.0, + random_crop_hw=(384, 384), + random_crop_p=1.0, ) # Test all augmentations. @@ -68,6 +68,6 @@ def test_kornia_augmentation(minimal_instance): ): p = KorniaAugmenter( p, - crop_hw=(0, 0), - crop_p=1.0, + random_crop_hw=(0, 0), + random_crop_p=1.0, ) diff --git a/tests/data/test_confmaps.py b/tests/data/test_confmaps.py index 21531009..f7821e6d 100644 --- a/tests/data/test_confmaps.py +++ b/tests/data/test_confmaps.py @@ -12,7 +12,7 @@ def test_confmaps(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) datapipe = Normalizer(datapipe) - datapipe = InstanceCropper(datapipe, 100, 100) + datapipe = InstanceCropper(datapipe, (100, 100)) datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1) sample = next(iter(datapipe1)) diff --git a/tests/data/test_instance_cropping.py b/tests/data/test_instance_cropping.py index 966e5572..8f112991 100644 --- a/tests/data/test_instance_cropping.py +++ b/tests/data/test_instance_cropping.py @@ -26,13 +26,25 @@ def test_instance_cropper(minimal_instance): datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) datapipe = Normalizer(datapipe) - datapipe = InstanceCropper(datapipe, 100, 100) + datapipe = InstanceCropper(datapipe, (100, 100)) sample = next(iter(datapipe)) + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + ] + # Test shapes. + assert len(sample.keys()) == len(gt_sample_keys) + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key assert sample["instance"].shape == (2, 2) - assert sample["instance_image"].shape == (1, 1, 100, 100) - assert sample["bbox"].shape == (1, 4, 2) + assert sample["instance_image"].shape == (1, 100, 100) + assert sample["instance_bbox"].shape == (1, 4, 2) # Test samples. gt = torch.Tensor( diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py new file mode 100644 index 00000000..15d7ebbf --- /dev/null +++ b/tests/data/test_pipelines.py @@ -0,0 +1,145 @@ +import torch +from omegaconf import OmegaConf + +from sleap_nn.data.confidence_maps import ConfidenceMapGenerator +from sleap_nn.data.general import KeyFilter +from sleap_nn.data.instance_centroids import InstanceCentroidFinder +from sleap_nn.data.instance_cropping import InstanceCropper +from sleap_nn.data.normalization import Normalizer +from sleap_nn.data.pipelines import TopdownConfmapsPipeline +from sleap_nn.data.providers import LabelsReader + + +def test_sleap_dataset(minimal_instance): + datapipe = LabelsReader.from_filename(filename=minimal_instance) + datapipe = Normalizer(datapipe) + datapipe = InstanceCentroidFinder(datapipe) + datapipe = InstanceCropper(datapipe, (160, 160)) + datapipe = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=2) + datapipe = KeyFilter(datapipe, keep_keys=None) + + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ] + + sample = next(iter(datapipe)) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80) + + +def test_topdownconfmapspipeline(minimal_instance): + base_topdown_data_config = OmegaConf.create( + { + "general": {"keep_keys": ["instance_image", "confidence_maps"]}, + "preprocessing": { + "crop_hw": (160, 160), + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": { + "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, + "use_augmentations": False, + "augmentations": { + "rotation": 15.0, + "scale": 0.05, + "translate": (0.02, 0.02), + "affine_p": 0.5, + "uniform_noise": (0.0, 0.04), + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast": (0.5, 2.0), + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + "erase_scale": (0.0001, 0.01), + "erase_ratio": (1, 1), + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, + }, + }, + } + ) + + pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) + datapipe = pipeline.make_training_pipeline( + data_provider=LabelsReader, filename=minimal_instance + ) + + gt_sample_keys = ["instance_image", "confidence_maps"] + + sample = next(iter(datapipe)) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80) + + base_topdown_data_config = OmegaConf.create( + { + "general": {"keep_keys": None}, + "preprocessing": { + "crop_hw": (160, 160), + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": { + "random_crop": {"random_crop_p": 1.0, "random_crop_hw": (160, 160)}, + "use_augmentations": True, + "augmentations": { + "rotation": 15.0, + "scale": 0.05, + "translate": (0.02, 0.02), + "affine_p": 0.5, + "uniform_noise": (0.0, 0.04), + "uniform_noise_p": 0.5, + "gaussian_noise_mean": 0.02, + "gaussian_noise_std": 0.004, + "gaussian_noise_p": 0.5, + "contrast": (0.5, 2.0), + "contrast_p": 0.5, + "brightness": 0.0, + "brightness_p": 0.5, + "erase_scale": (0.0001, 0.01), + "erase_ratio": (1, 1), + "erase_p": 0.5, + "mixup_lambda": None, + "mixup_p": 0.5, + }, + }, + } + ) + + pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) + datapipe = pipeline.make_training_pipeline( + data_provider=LabelsReader, filename=minimal_instance + ) + + gt_sample_keys = [ + "image", + "instances", + "centroids", + "instance", + "instance_bbox", + "instance_image", + "confidence_maps", + ] + + sample = next(iter(datapipe)) + assert len(sample.keys()) == len(gt_sample_keys) + + for gt_key, key in zip(sorted(gt_sample_keys), sorted(sample.keys())): + assert gt_key == key + assert sample["instance_image"].shape == (1, 160, 160) + assert sample["confidence_maps"].shape == (2, 80, 80)