Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Top-down Centered-instance Pipeline #16

Merged
merged 61 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
2dfcd90
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
1088e7f
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
2d0a009
created test_instance_cropping.py
alckasoc Aug 3, 2023
02ea629
added test normalize bboxes; added find_global_peaks_rough
alckasoc Aug 6, 2023
711b3aa
black formatted
alckasoc Aug 6, 2023
3a0cfb7
fixed merges
alckasoc Aug 6, 2023
9a728aa
black formatted peak_finding
alckasoc Aug 6, 2023
e84535f
added make_grid_vectors, normalize_bboxes, integral_regression, added…
alckasoc Aug 10, 2023
36f6573
finished find_global_peaks with integral regression over centroid crops!
alckasoc Aug 10, 2023
b17af28
reformatted with pydocstyle & black
alckasoc Aug 10, 2023
3ea75ae
Merge remote-tracking branch 'origin/main' into vincent/find_peaks
alckasoc Aug 10, 2023
a506579
moved make_grid_vectors to data/utils
alckasoc Aug 10, 2023
02babb1
removed normalize_bboxes
alckasoc Aug 10, 2023
373f4b1
added tests docstrings
alckasoc Aug 10, 2023
6351314
sorted imports with isort
alckasoc Aug 10, 2023
008a994
remove unused imports
alckasoc Aug 10, 2023
b45619c
updated test cases for instance cropping
alckasoc Aug 10, 2023
381a49f
added minimal_cms.pt fixture + unit tests
alckasoc Aug 11, 2023
0ad336c
added minimal_bboxes fixture; added unit tests for crop_bboxes & inte…
alckasoc Aug 11, 2023
da1ba7e
added find_global_peaks unit tests
alckasoc Aug 11, 2023
7778512
finished find_local_peaks_rough!
alckasoc Aug 17, 2023
9f7ac3f
finished find_local_peaks!
alckasoc Aug 17, 2023
b9869d6
added unit tests for find_local_peaks and find_local_peaks_rough
alckasoc Aug 17, 2023
bfd1cac
updated test cases
alckasoc Aug 17, 2023
a8b3c31
added more test cases for find_local_peaks
alckasoc Aug 17, 2023
125625d
updated test cases
alckasoc Aug 17, 2023
a25d920
added architectures folder
alckasoc Aug 17, 2023
3ba92b6
added maxpool2d same padding, get_act_fn; added simpleconvblock, simp…
alckasoc Aug 17, 2023
f9558f2
added test_unet_reference
alckasoc Aug 18, 2023
28d57ca
black formatted common.py & test_unet.py
alckasoc Aug 18, 2023
8ca4538
fixed merge conflicts
alckasoc Aug 18, 2023
6df3c20
Merge branch 'main' into vincent/unet
alckasoc Aug 18, 2023
c4792a6
Merge branch 'vincent/unet' of https://github.com/talmolab/sleap-nn i…
alckasoc Aug 18, 2023
87cd034
deleted tmp nb
alckasoc Aug 18, 2023
7004869
_calc_same_pad returns int
alckasoc Aug 19, 2023
680778d
fixed test case
alckasoc Aug 19, 2023
7cd75dc
added simpleconvblock tests
alckasoc Aug 19, 2023
79b535d
added tests
alckasoc Aug 19, 2023
691af45
added tests for simple upsampling block
alckasoc Aug 19, 2023
2520fa2
updated test_unet
alckasoc Aug 28, 2023
bcf4069
removed unnecessary variables
alckasoc Aug 30, 2023
dbccdcf
updated augmentation random erase default values
alckasoc Aug 30, 2023
029a545
created data/pipelines.py
alckasoc Aug 30, 2023
3e5ae68
added base config in config/data; temporary till config system settled
alckasoc Aug 31, 2023
1b8002b
updated variable defaults to 0 and edited variable names in augmentation
alckasoc Aug 31, 2023
f1c64f4
updated parameter names in data/instance_cropping
alckasoc Aug 31, 2023
2a22674
added data/pipelines topdown pipeline make_base_pipeline
alckasoc Aug 31, 2023
f3ddf2f
added test_pipelines
alckasoc Aug 31, 2023
c861c72
removed configs
alckasoc Sep 5, 2023
31aadc1
updated augmentation class
alckasoc Sep 6, 2023
6630155
modified test
alckasoc Sep 6, 2023
55cf1a9
updated pipelines docstring
alckasoc Sep 6, 2023
9715c01
removed make_base_pipeline and updated tests
alckasoc Sep 6, 2023
7deec65
removed empty_cache in SleapDataset
alckasoc Sep 6, 2023
e32a0a9
Merge branch 'main' into vincent/topdownpipeline
alckasoc Sep 7, 2023
b1ef93c
updated test_pipelines
alckasoc Sep 7, 2023
ae523d1
updated sleapdataset to return a dict
alckasoc Sep 12, 2023
b421937
added key filter transformer block, removed sleap dataset, added type…
alckasoc Sep 12, 2023
fe61f15
updated type hints
alckasoc Sep 12, 2023
0214abb
added coderabbit suggestions
alckasoc Sep 12, 2023
e3b28da
fixed small squeeze issue
alckasoc Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sleap_nn/architectures/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Common utilities for architecture and model building."""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F
Expand Down Expand Up @@ -134,7 +136,7 @@ def get_act_fn(activation: str) -> nn.Module:
return activations[activation]


def get_children_layers(model: torch.nn.Module):
def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
"""Recursively retrieves a flattened list of all children modules and submodules within the given model.

Args:
Expand Down
79 changes: 50 additions & 29 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ class KorniaAugmenter(IterDataPipe):
Attributes:
source_dp: The input `IterDataPipe` with examples that contain `"instances"` and
`"image"` keys.
crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
crop_p: Probability of applying random crop.
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
Expand All @@ -125,11 +122,14 @@ class KorniaAugmenter(IterDataPipe):
contrast_p: Probability of applying random contrast.
brightness: The brightness factor to apply Default: `(1.0, 1.0)`.
brightness_p: Probability of applying random brightness.
erase_scale: Range of proportion of erased area against input image. Default: `(0.02, 0.33)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(0.3, 3.3)`.
erase_scale: Range of proportion of erased area against input image. Default: `(0.0001, 0.01)`.
erase_ratio: Range of aspect ratio of erased area. Default: `(1, 1)`.
erase_p: Probability of applying random erase.
mixup_lambda: min-max value of mixup strength. Default is 0-1. Default: `None`.
mixup_p: Probability of applying random mixup v2.
random_crop_hw: Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int],
then out_h = size[0], out_w = size[1].
random_crop_p: Probability of applying random crop.

Notes:
This block expects the "image" and "instances" keys to be present in the input
Expand All @@ -150,24 +150,24 @@ def __init__(
rotation: Optional[float] = 15.0,
scale: Optional[float] = 0.05,
translate: Optional[Tuple[float, float]] = (0.02, 0.02),
affine_p: float = 0.5,
affine_p: float = 0.0,
uniform_noise: Optional[Tuple[float, float]] = (0.0, 0.04),
uniform_noise_p: float = 0.5,
uniform_noise_p: float = 0.0,
gaussian_noise_mean: Optional[float] = 0.02,
gaussian_noise_std: Optional[float] = 0.004,
gaussian_noise_p: float = 0.5,
gaussian_noise_p: float = 0.0,
contrast: Optional[Tuple[float, float]] = (0.5, 2.0),
contrast_p: float = 0.5,
contrast_p: float = 0.0,
brightness: Optional[float] = 0.0,
brightness_p: float = 0.5,
brightness_p: float = 0.0,
erase_scale: Optional[Tuple[float, float]] = (0.0001, 0.01),
erase_ratio: Optional[Tuple[float, float]] = (1, 1),
erase_p: float = 0.5,
erase_p: float = 0.0,
mixup_lambda: Union[Optional[float], Tuple[float, float], None] = None,
mixup_p: float = 0.5,
crop_hw: Tuple[int, int] = (0, 0),
crop_p: float = 0.0,
):
mixup_p: float = 0.0,
random_crop_hw: Tuple[int, int] = (0, 0),
random_crop_p: float = 0.0,
) -> None:
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
self.rotation = rotation
Expand All @@ -188,8 +188,8 @@ def __init__(
self.erase_p = erase_p
self.mixup_lambda = mixup_lambda
self.mixup_p = mixup_p
self.crop_hw = crop_hw
self.crop_p = crop_p
self.random_crop_hw = random_crop_hw
self.random_crop_p = random_crop_p

aug_stack = []
if self.affine_p > 0:
Expand Down Expand Up @@ -259,19 +259,21 @@ def __init__(
same_on_batch=True,
)
)
if self.crop_p > 0:
if self.crop_hw[0] > 0 and self.crop_hw[1] > 0:
if self.random_crop_p > 0:
if self.random_crop_hw[0] > 0 and self.random_crop_hw[1] > 0:
aug_stack.append(
K.augmentation.RandomCrop(
size=self.crop_hw,
size=self.random_crop_hw,
pad_if_needed=True,
p=self.crop_p,
p=self.random_crop_p,
keepdim=True,
same_on_batch=True,
)
)
else:
raise ValueError(f"crop_hw height and width must be greater than 0.")
raise ValueError(
f"random_crop_hw height and width must be greater than 0."
)
alckasoc marked this conversation as resolved.
Show resolved Hide resolved

self.augmenter = AugmentationSequential(
*aug_stack,
Expand All @@ -280,12 +282,31 @@ def __init__(
same_on_batch=True,
)

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Return an example dictionary with the augmented image and instances."""
for ex in self.source_dp:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
)
aug_image, aug_instances = self.augmenter(image, instances)
yield {"image": aug_image, "instances": aug_instances.reshape(*inst_shape)}
if "instance_image" in ex and "instance" in ex:
inst_shape = ex["instance"].shape
# (B, channels, height, width), (1, num_nodes, 2)
image, instances = ex["instance_image"], ex["instance"].unsqueeze(0)
aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"instance_image": aug_image,
"instance": aug_instances.reshape(*inst_shape),
}
)
elif "image" in ex and "instances" in ex:
inst_shape = ex["instances"].shape # (B, num_instances, num_nodes, 2)
image, instances = ex["image"], ex["instances"].reshape(
inst_shape[0], -1, 2
) # (B, channels, height, width), (B, num_instances x num_nodes, 2)

aug_image, aug_instances = self.augmenter(image, instances)
ex.update(
{
"image": aug_image,
"instances": aug_instances.reshape(*inst_shape),
}
)
yield ex
8 changes: 4 additions & 4 deletions sleap_nn/data/confidence_maps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Generate confidence maps."""
from typing import Optional
from typing import Dict, Optional

import sleap_io as sio
import torch
Expand All @@ -10,7 +10,7 @@

def make_confmaps(
points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float
):
) -> torch.Tensor:
"""Make confidence maps from a set of points from a single instance.

Args:
Expand Down Expand Up @@ -70,15 +70,15 @@ def __init__(
output_stride: int = 1,
instance_key: str = "instance",
image_key: str = "instance_image",
):
) -> None:
"""Initialize ConfidenceMapGenerator with input `DataPipe`, sigma, and output stride."""
self.source_dp = source_dp
self.sigma = sigma
self.output_stride = output_stride
self.instance_key = instance_key
self.image_key = image_key

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Generate confidence maps for each example."""
for example in self.source_dp:
instance = example[self.instance_key]
alckasoc marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
38 changes: 38 additions & 0 deletions sleap_nn/data/general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""General purpose transformers for common pipeline processing tasks."""
from typing import Callable, Dict, List, Text

from torch.utils.data.datapipes.datapipe import IterDataPipe


class KeyFilter(IterDataPipe):
"""Transformer for filtering example keys."""

def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None:
"""Initialize KeyFilter with the source `DataPipe."""
self.dp = source_dp
self.keep_keys = keep_keys

def __iter__(self):
"""Return a dictionary filtered for the relevant outputs.

The input dictionary includes:
- image: the full frame image
- instances: all keypoints of all instances in the frame image
- centroids: all centroids of all instances in the frame image
- instance: the individual instance's keypoints
- instance_bbox: the individual instance's bbox
- instance_image: the individual instance's cropped image
- confidence_maps: the individual instance's heatmap
"""
for example in self.dp:
if self.keep_keys is None:
# If keep_keys is not provided, yield the entire example.
yield example
else:
# Filter the example dictionary based on keep_keys.
filtered_example = {
key: value
for key, value in example.items()
if key in self.keep_keys
}
yield filtered_example
14 changes: 7 additions & 7 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle calculation of instance centroids."""
from typing import Optional
from typing import Dict, Optional

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe
Expand Down Expand Up @@ -79,15 +79,15 @@ def __init__(
self,
source_dp: IterDataPipe,
anchor_ind: Optional[int] = None,
):
) -> None:
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
self.source_dp = source_dp
self.anchor_ind = anchor_ind

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Add `"centroids"` key to example."""
for example in self.source_dp:
example["centroids"] = find_centroids(
example["instances"], anchor_ind=self.anchor_ind
for ex in self.source_dp:
ex["centroids"] = find_centroids(
ex["instances"], anchor_ind=self.anchor_ind
)
yield example
yield ex
48 changes: 22 additions & 26 deletions sleap_nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Handle cropping of instances."""
from typing import Optional
from typing import Optional, Tuple

import numpy as np
import sleap_io as sio
Expand Down Expand Up @@ -58,50 +58,46 @@ class InstanceCropper(IterDataPipe):

Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
crop_width: Width of the crop in pixels
crop_height: Height of the crop in pixels
crop_hw: Height and Width of the crop in pixels
"""

def __init__(
self,
source_dp: IterDataPipe,
crop_width: int,
crop_height: int,
):
def __init__(self, source_dp: IterDataPipe, crop_hw: Tuple[int, int]):
"""Initialize InstanceCropper with the source `DataPipe."""
self.source_dp = source_dp
self.crop_width = crop_width
self.crop_height = crop_height
self.crop_hw = crop_hw

def __iter__(self):
"""Generate instance cropped examples."""
for example in self.source_dp:
image = example["image"] # (frames, channels, height, width)
instances = example["instances"] # (frames, n_instances, n_nodes, 2)
centroids = example["centroids"] # (frames, n_instances, 2)
for ex in self.source_dp:
image = ex["image"] # (B, channels, height, width)
instances = ex["instances"] # (B, n_instances, num_nodes, 2)
centroids = ex["centroids"] # (B, n_instances, 2)
for instance, centroid in zip(instances[0], centroids[0]):
# Generate bounding boxes from centroid.
bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_height, self.crop_width), 0
) # (frames, 4, 2)
instance_bbox = torch.unsqueeze(
make_centered_bboxes(centroid, self.crop_hw[0], self.crop_hw[1]), 0
) # (B, 4, 2)

box_size = (self.crop_height, self.crop_width)
box_size = (self.crop_hw[0], self.crop_hw[1])

# Generate cropped image of shape (frames, channels, crop_height, crop_width)
# Generate cropped image of shape (B, channels, crop_height, crop_width)
instance_image = crop_and_resize(
image,
boxes=bbox,
boxes=instance_bbox,
size=box_size,
)

# Access top left point (x,y) of bounding box and subtract this offset from
# position of nodes.
point = bbox[0][0]
point = instance_bbox[0][0]
center_instance = instance - point

instance_example = {
"instance_image": instance_image, # (frames, channels, crop_height, crop_width)
"bbox": bbox, # (frames, 4, 2)
"instance": center_instance, # (n_instances, 2)
"instance_image": instance_image.squeeze(
0
), # (B=1, channels, crop_height, crop_width)
"instance_bbox": instance_bbox, # (B, 4, 2)
"instance": center_instance, # (num_nodes, 2)
}
yield instance_example
ex.update(instance_example)
yield ex
6 changes: 4 additions & 2 deletions sleap_nn/data/normalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""This module implements data pipeline blocks for normalization operations."""
from typing import Dict

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe

Expand All @@ -16,11 +18,11 @@ class Normalizer(IterDataPipe):
def __init__(
self,
source_dp: IterDataPipe,
):
) -> None:
"""Initialize the block."""
self.source_dp = source_dp

def __iter__(self):
def __iter__(self) -> Dict[str, torch.Tensor]:
"""Return an example dictionary with the augmented image and instance."""
for ex in self.source_dp:
if not torch.is_floating_point(ex["image"]):
Expand Down
Loading