Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Top-down Centered-instance Pipeline #16

Merged
merged 61 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
2dfcd90
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
1088e7f
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
2d0a009
created test_instance_cropping.py
alckasoc Aug 3, 2023
02ea629
added test normalize bboxes; added find_global_peaks_rough
alckasoc Aug 6, 2023
711b3aa
black formatted
alckasoc Aug 6, 2023
3a0cfb7
fixed merges
alckasoc Aug 6, 2023
9a728aa
black formatted peak_finding
alckasoc Aug 6, 2023
e84535f
added make_grid_vectors, normalize_bboxes, integral_regression, added…
alckasoc Aug 10, 2023
36f6573
finished find_global_peaks with integral regression over centroid crops!
alckasoc Aug 10, 2023
b17af28
reformatted with pydocstyle & black
alckasoc Aug 10, 2023
3ea75ae
Merge remote-tracking branch 'origin/main' into vincent/find_peaks
alckasoc Aug 10, 2023
a506579
moved make_grid_vectors to data/utils
alckasoc Aug 10, 2023
02babb1
removed normalize_bboxes
alckasoc Aug 10, 2023
373f4b1
added tests docstrings
alckasoc Aug 10, 2023
6351314
sorted imports with isort
alckasoc Aug 10, 2023
008a994
remove unused imports
alckasoc Aug 10, 2023
b45619c
updated test cases for instance cropping
alckasoc Aug 10, 2023
381a49f
added minimal_cms.pt fixture + unit tests
alckasoc Aug 11, 2023
0ad336c
added minimal_bboxes fixture; added unit tests for crop_bboxes & inte…
alckasoc Aug 11, 2023
da1ba7e
added find_global_peaks unit tests
alckasoc Aug 11, 2023
7778512
finished find_local_peaks_rough!
alckasoc Aug 17, 2023
9f7ac3f
finished find_local_peaks!
alckasoc Aug 17, 2023
b9869d6
added unit tests for find_local_peaks and find_local_peaks_rough
alckasoc Aug 17, 2023
bfd1cac
updated test cases
alckasoc Aug 17, 2023
a8b3c31
added more test cases for find_local_peaks
alckasoc Aug 17, 2023
125625d
updated test cases
alckasoc Aug 17, 2023
a25d920
added architectures folder
alckasoc Aug 17, 2023
3ba92b6
added maxpool2d same padding, get_act_fn; added simpleconvblock, simp…
alckasoc Aug 17, 2023
f9558f2
added test_unet_reference
alckasoc Aug 18, 2023
28d57ca
black formatted common.py & test_unet.py
alckasoc Aug 18, 2023
8ca4538
fixed merge conflicts
alckasoc Aug 18, 2023
6df3c20
Merge branch 'main' into vincent/unet
alckasoc Aug 18, 2023
c4792a6
Merge branch 'vincent/unet' of https://github.com/talmolab/sleap-nn i…
alckasoc Aug 18, 2023
87cd034
deleted tmp nb
alckasoc Aug 18, 2023
7004869
_calc_same_pad returns int
alckasoc Aug 19, 2023
680778d
fixed test case
alckasoc Aug 19, 2023
7cd75dc
added simpleconvblock tests
alckasoc Aug 19, 2023
79b535d
added tests
alckasoc Aug 19, 2023
691af45
added tests for simple upsampling block
alckasoc Aug 19, 2023
2520fa2
updated test_unet
alckasoc Aug 28, 2023
bcf4069
removed unnecessary variables
alckasoc Aug 30, 2023
dbccdcf
updated augmentation random erase default values
alckasoc Aug 30, 2023
029a545
created data/pipelines.py
alckasoc Aug 30, 2023
3e5ae68
added base config in config/data; temporary till config system settled
alckasoc Aug 31, 2023
1b8002b
updated variable defaults to 0 and edited variable names in augmentation
alckasoc Aug 31, 2023
f1c64f4
updated parameter names in data/instance_cropping
alckasoc Aug 31, 2023
2a22674
added data/pipelines topdown pipeline make_base_pipeline
alckasoc Aug 31, 2023
f3ddf2f
added test_pipelines
alckasoc Aug 31, 2023
c861c72
removed configs
alckasoc Sep 5, 2023
31aadc1
updated augmentation class
alckasoc Sep 6, 2023
6630155
modified test
alckasoc Sep 6, 2023
55cf1a9
updated pipelines docstring
alckasoc Sep 6, 2023
9715c01
removed make_base_pipeline and updated tests
alckasoc Sep 6, 2023
7deec65
removed empty_cache in SleapDataset
alckasoc Sep 6, 2023
e32a0a9
Merge branch 'main' into vincent/topdownpipeline
alckasoc Sep 7, 2023
b1ef93c
updated test_pipelines
alckasoc Sep 7, 2023
ae523d1
updated sleapdataset to return a dict
alckasoc Sep 12, 2023
b421937
added key filter transformer block, removed sleap dataset, added type…
alckasoc Sep 12, 2023
fe61f15
updated type hints
alckasoc Sep 12, 2023
0214abb
added coderabbit suggestions
alckasoc Sep 12, 2023
e3b28da
fixed small squeeze issue
alckasoc Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ class KorniaAugmenter(IterDataPipe):
Attributes:
source_dp: The input `IterDataPipe` with examples that contain `"instances"` and
`"image"` keys.
crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
crop_p: Probability of applying random crop.
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
Expand All @@ -125,11 +122,14 @@ class KorniaAugmenter(IterDataPipe):
contrast_p: Probability of applying random contrast.
brightness: The brightness factor to apply Default: `(1.0, 1.0)`.
brightness_p: Probability of applying random brightness.
erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`.
erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`.
erase_p: Probability of applying random erase.
mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.
random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
random_crop_p: Probability of applying random crop.

Notes:
This block expects the "image" and "instances" keys to be present in the input
Expand All @@ -150,23 +150,23 @@ def __init__(
rotation: Optional[float] = 15.0,
scale: Optional[float] = 0.05,
translate: Optional[Tuple[float, float]] = (0.02, 0.02),
affine_p: float = 0.5,
affine_p: float = 0.0,
uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04),
uniform_noise_p: float = 0.5,
uniform_noise_p: float = 0.0,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.5,
gaussian_noise_p: float = 0.0,
contrast: Optional[Tuple[float, float]] = (0.5, 2.0),
contrast_p: float = 0.5,
contrast_p: float = 0.0,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.5,
brightness_p: float = 0.0,
erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01),
erase_ratio: Optional[Tuple[float, float]] = (1, 1),
erase_p: float = 0.5,
erase_p: float = 0.0,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.5,
crop_hw: Tuple[int, int] = (0, 0),
crop_p: float = 0.0,
mixup_p: float = 0.0,
random_crop_hw: Tuple[int, int] = (0, 0),
random_crop_p: float = 0.0,
):
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
Expand All @@ -188,8 +188,8 @@ def __init__(
self.erase_p = erase_p
self.mixup_lambda = mixup_lambda
self.mixup_p = mixup_p
self.crop_hw = crop_hw
self.crop_p = crop_p
self.random_crop_hw = random_crop_hw
self.random_crop_p = random_crop_p

aug_stack = []
if self.affine_p > 0:
Expand Down Expand Up @@ -259,19 +259,21 @@ def __init__(
same_on_batch=True,
)
)
if self.crop_p > 0:
if self.crop_hw[0] > 0 and self.crop_hw[1] > 0:
if self.random_crop_p > 0:
if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=self.crop_hw,
size=self.random_crop_hw,
pad_if_needed=True,
p=self.crop_p,
p=self.random_crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"crop_hw height and width must be greater than 0.")
raise ValueError(
f"random_crop_hw height and width must be greater than 0."
)
alckasoc marked this conversation as resolved.
Show resolved Hide resolved

self.augmenter = AugmentationSequential(
*aug_stack,
Expand All @@ -283,9 +285,28 @@ def __init__(
def __iter__(self):
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
)
aug_image, aug_instances = self.augmenter(image, instances)
yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)}
if "instance_image" in ex and "instance" in ex:
inst_shape = ex["instance"].shape
# (B, channels, height, width), (1, num_nodes, 2)
image, instances = ex["instance_image"], ex["instance"].unsqueeze(0)
aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"instance_image": aug_image,
"instance": aug_instances.reshape(*inst_shape),
}
)
elif "image" in ex and "instances" in ex:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
) # (B, channels, height, width), (B, num_instances x num_nodes, 2)

aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"image": aug_image,
"instances": aug_instances.reshape(*inst_shape),
}
)
yield ex
8 changes: 4 additions & 4 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def __init__(

def __iter__(self):
"""Add `"centroids"` key to example."""
for example in self.source_dp:
example["centroids"] = find_centroids(
example["instances"], anchor_ind=self.anchor_ind
for ex in self.source_dp:
ex["centroids"] = find_centroids(
ex["instances"], anchor_ind=self.anchor_ind
)
yield example
yield ex
43 changes: 18 additions & 25 deletions sleap_nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle cropping of instances."""
from typing import Optional
from typing import Optional, Tuple

import numpy as np
import sleap_io as sio
Expand Down Expand Up @@ -58,50 +58,43 @@ class InstanceCropper(IterDataPipe):

Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
crop_width: Width of the crop in pixels
crop_height: Height of the crop in pixels
crop_hw: Height and Width of the crop in pixels
"""

def __init__(
self,
source_dp: IterDataPipe,
crop_width: int,
crop_height: int,
):
def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]):
"""Initialize InstanceCropper with the source `DataPipe."""
self.source_dp = source_dp
self.crop_width = crop_width
self.crop_height = crop_height
self.crop_hw = crop_hw

def __iter__(self):
"""Generate instance cropped examples."""
for example in self.source_dp:
image = example["image"] # (frames, channels, height, width)
instances = example["instances"] # (frames, n_instances, n_nodes, 2)
centroids = example["centroids"] # (frames, n_instances, 2)
for ex in self.source_dp:
image = ex["image"] # (B, channels, height, width)
instances = ex["instances"] # (B, n_instances, num_nodes, 2)
centroids = ex["centroids"] # (B, n_instances, 2)
for instance, centroid in zip(instances[0], centroids[0]):
# Generate bounding boxes from centroid.
bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_height, self.crop_width), 0
) # (frames, 4, 2)
instance_bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0
) # (B, 4, 2)

box_size = (self.crop_height, self.crop_width)
box_size = (self.crop_hw[0], self.crop_hw[1])

# Generate cropped image of shape (frames, channels, crop_height, crop_width)
# Generate cropped image of shape (B, channels, crop_height, crop_width)
instance_image = crop_and_resize(
image,
boxes=bbox,
boxes=instance_bbox,
size=box_size,
)

# Access top left point (x,y) of bounding box and subtract this offset from
# position of nodes.
point = bbox[0][0]
point = instance_bbox[0][0]
center_instance = instance - point

instance_example = {
"instance_image": instance_image, # (frames, channels, crop_height, crop_width)
"bbox": bbox, # (frames, 4, 2)
"instance": center_instance, # (n_instances, 2)
"instance_image": instance_image, # (B, channels, crop_height, crop_width)
"instance_bbox": instance_bbox, # (B, 4, 2)
"instance": center_instance, # (num_nodes, 2)
}
yield instance_example
89 changes: 89 additions & 0 deletions sleap_nn/data/pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""This module defines high level pipeline configurations from providers/transformers.

This allows for convenient ways to configure individual variants of common pipelines, as
well as to define training vs inference versions based on the same configurations.
"""
import torch
from omegaconf.dictconfig import DictConfig
from torch.utils.data.datapipes.datapipe import IterDataPipe

from sleap_nn.data.augmentation import KorniaAugmenter
from sleap_nn.data.confidence_maps import ConfidenceMapGenerator
from sleap_nn.data.instance_centroids import InstanceCentroidFinder
from sleap_nn.data.instance_cropping import InstanceCropper
from sleap_nn.data.normalization import Normalizer
from sleap_nn.data.providers import LabelsReader


class SleapDataset(IterDataPipe):
"""Returns image and corresponding heatmap for the DataLoader.

This class is to return the image and its corresponding confidence map
to load the dataset with the DataLoader class

Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
"""

def __init__(self, source_dp: IterDataPipe):
"""Initialize SleapDataset with the source `DataPipe."""
self.dp = source_dp

def __iter__(self):
"""Return a tuple with the cropped image and the heatmap."""
for example in self.dp:
if len(example["instance_image"].shape) == 4:
example["instance_image"] = example["instance_image"].squeeze(dim=0)
yield example["instance_image"], example["confidence_maps"]
alckasoc marked this conversation as resolved.
Show resolved Hide resolved
alckasoc marked this conversation as resolved.
Show resolved Hide resolved


class TopdownConfmapsPipeline:
"""Pipeline builder for instance-centered confidence map models.

Attributes:
data_config: Data-related configuration.
"""

def __init__(self, data_config: DictConfig) -> None:
"""Initialize the data config."""
self.data_config = data_config
alckasoc marked this conversation as resolved.
Show resolved Hide resolved

alckasoc marked this conversation as resolved.
Show resolved Hide resolved
def make_training_pipeline(
self, data_provider: IterDataPipe, filename: str
) -> IterDataPipe:
"""Create training pipeline with input data only.

Args:
data_provider: A `Provider` that generates data examples, typically a
`LabelsReader` instance.
filename: A string path to the name of the `.slp` file.

Returns:
An `IterDataPipe` instance configured to produce input examples.
"""
datapipe = data_provider.from_filename(filename=filename)
datapipe = Normalizer(datapipe)

datapipe = InstanceCentroidFinder(datapipe)
datapipe = InstanceCropper(datapipe, self.data_config.preprocessing.crop_hw)

if self.data_config.augmentation_config.random_crop.random_crop_p:
datapipe = KorniaAugmenter(
datapipe,
random_crop_hw=self.data_config.augmentation_config.random_crop.random_crop_hw,
random_crop_p=self.data_config.augmentation_config.random_crop.random_crop_p,
)

if self.data_config.augmentation_config.use_augmentations:
datapipe = KorniaAugmenter(
datapipe, **dict(self.data_config.augmentation_config.augmentations)
)
alckasoc marked this conversation as resolved.
Show resolved Hide resolved

datapipe = ConfidenceMapGenerator(
datapipe,
sigma=self.data_config.preprocessing.conf_map_gen.sigma,
output_stride=self.data_config.preprocessing.conf_map_gen.output_stride,
)
datapipe = SleapDataset(datapipe)

return datapipe
alckasoc marked this conversation as resolved.
Show resolved Hide resolved
alckasoc marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions tests/data/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_kornia_augmentation(minimal_instance):
erase_p=1.0,
mixup_p=1.0,
mixup_lambda=(0.0, 1.0),
crop_hw=(384, 384),
crop_p=1.0,
random_crop_hw=(384, 384),
random_crop_p=1.0,
)

# Test all augmentations.
Expand All @@ -68,6 +68,6 @@ def test_kornia_augmentation(minimal_instance):
):
p = KorniaAugmenter(
p,
crop_hw=(0, 0),
crop_p=1.0,
random_crop_hw=(0, 0),
random_crop_p=1.0,
)
2 changes: 1 addition & 1 deletion tests/data/test_confmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_confmaps(minimal_instance):
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
datapipe = Normalizer(datapipe)
datapipe = InstanceCropper(datapipe, 100, 100)
datapipe = InstanceCropper(datapipe, (100, 100))
datapipe1 = ConfidenceMapGenerator(datapipe, sigma=1.5, output_stride=1)
sample = next(iter(datapipe1))

Expand Down
4 changes: 2 additions & 2 deletions tests/data/test_instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def test_instance_cropper(minimal_instance):
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
datapipe = Normalizer(datapipe)
datapipe = InstanceCropper(datapipe, 100, 100)
datapipe = InstanceCropper(datapipe, (100, 100))
sample = next(iter(datapipe))

# Test shapes.
assert sample["instance"].shape == (2, 2)
assert sample["instance_image"].shape == (1, 1, 100, 100)
assert sample["bbox"].shape == (1, 4, 2)
assert sample["instance_bbox"].shape == (1, 4, 2)

# Test samples.
gt = torch.Tensor(
Expand Down
Loading