Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor datapipes #9

Merged
merged 9 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
2 changes: 0 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ dependencies:
- cudnn
- pytorch
- kornia
- torchdata
- torchvision
- imageio
- av
- ffmpeg
- albumentations
- matplotlib
- pip
- pip:
Expand Down
2 changes: 0 additions & 2 deletions environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ dependencies:
- lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027)
- cpuonly
- kornia
- torchdata
- torchvision
- imageio
- av
- ffmpeg
- albumentations
- matplotlib
- pip
- pip:
Expand Down
2 changes: 0 additions & 2 deletions environment_osx-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ dependencies:
- lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027)
- pytorch
- kornia
- torchdata
- torchvision
- imageio
- av
- ffmpeg
- albumentations
- matplotlib
- pip
- pip:
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ classifiers = [
]
dependencies = [
"torch>=2.0.0",
"torchdata",
"torchvision",
"pydantic<2.0",
"lightning==2.0.5",
"imageio",
"imageio-ffmpeg",
"av",
"albumentations",
"hydra-core",
"sleap-io>=0.0.7"
"sleap-io>=0.0.7",
"kornia"
]
dynamic = ["version", "readme"]

Expand All @@ -43,7 +42,8 @@ dev = [
"pydocstyle",
"toml",
"twine",
"build"
"build",
"ipython"
]

[project.scripts]
Expand Down
98 changes: 54 additions & 44 deletions sleap_nn/data/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,74 @@
"""Handle rotation and scaling augmentations."""
from typing import List, Tuple, Union
import torchdata.datapipes.iter as dp
import torch
"""This module implements data pipeline blocks for augmentations."""
from typing import Optional
from torch.utils.data.datapipes.datapipe import IterDataPipe
import kornia.augmentation as K


class KorniaAugmenter(dp.IterDataPipe):
"""DataPipe for applying Rotation and Scaling augmentations using Kornia.
class KorniaAugmenter(IterDataPipe):
"""DataPipe for applying rotation and scaling augmentations using Kornia.

This DataPipe will generate augmented samples containing the augmented frame and an sleap_io.Instance
from sleap_io.Labels instance.
This DataPipe will apply augmentations to images and instances in examples from the
input pipeline.

Attributes:
source_dp: DataPipe which is an instance of the LabelsReader class
rotation: range of degrees to select from. If float, randomly selects a value from (-rotation, +rotation)
probability: probability of applying transformation
scale: scaling factor interval. Randomly selects a scale from the range
source_dp: The input `IterDataPipe` with examples that contain `"instances"` and
`"image"` keys.
rotation: Angles in degrees as a scalar float of the amount of rotation. A
random angle in `(-rotation, rotation)` will be sampled and applied to both
images and keypoints. Set to 0 to disable rotation augmentation.
scale: A scaling factor as a scalar float specifying the amount of scaling. A
random factor between `(1 - scale, 1 + scale)` will be sampled and applied
to both images and keypoints. If `None`, no scaling augmentation will be
applied.
probability: Probability of applying the transformations.

Notes:
This block expects the "image" and "instances" keys to be present in the input
examples.

The `"image"` key should contain a torch.Tensor of dtype torch.float32
and of shape `(..., C, H, W)`, i.e., rank >= 3.

The `"instances"` key should contain a torch.Tensor of dtype torch.float32 and
of shape `(..., n_instances, n_nodes, 2)`, i.e., rank >= 3.

The augmented versions will be returned with the same keys and shapes.
"""

def __init__(
self,
source_dp: dp.IterDataPipe,
rotation: Union[float, Tuple[float, float], List[float]] = 90,
source_dp: IterDataPipe,
rotation: float = 15.0,
scale: Optional[float] = 0.05,
probability: float = 0.5,
scale: Tuple[float, float] = (0.1, 0.3),
):
"""Initialize the class variables with the DataPipe and the augmenter with rotation and scaling."""
"""Initialize the block and the augmentation pipeline."""
self.source_dp = source_dp
self.datapipe = self.source_dp.map(self.normalize)
self.rotation = rotation
self.scale = (1 - scale, 1 + scale)
self.probability = probability
self.augmenter = K.AugmentationSequential(
K.RandomRotation(degrees=rotation, p=probability, keepdim=True),
K.RandomAffine(degrees=0, scale=scale, keepdim=True),
K.RandomAffine(
degrees=self.rotation,
scale=self.scale,
p=self.probability,
keepdim=True,
same_on_batch=True,
),
data_keys=["input", "keypoints"],
keepdim=True,
same_on_batch=True,
)

@classmethod
def normalize(self, data):
"""Function to normalize the image.

This function will convert the image to type Double and normalizes it.

Args:
data: A dictionary sample (`image and key-points`) from the LabelsReader class.

Returns:
A dictionary with the normalized image and instance.

"""
image = data["image"]
instance = data["instance"]
image = image.type(torch.DoubleTensor)
image = image / 255
return {"image": image, "instance": instance}

def __iter__(self):
"""Returns a dictionary sample with the augmented image and the transformed instance."""
for dict in self.datapipe:
image = dict["image"]
instance = dict["instance"]
aug_image, aug_instance = self.augmenter(image, instance)
yield {"image": aug_image, "instance": aug_instance}
"""Return an example dictionary with the augmented image and instance."""
for ex in self.source_dp:
img = ex["image"]
pts = ex["instances"]
pts_shape = pts.shape
pts = pts.reshape(-1, pts_shape[-2], pts_shape[-1])
img, pts = self.augmenter(img, pts)
pts = pts.reshape(pts_shape)
ex["image"] = img
ex["instances"] = pts
yield ex
12 changes: 5 additions & 7 deletions sleap_nn/data/instance_centroids.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Handle calculation of instance centroids."""
import torchdata.datapipes.iter as dp
import lightning.pytorch as pl
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Optional
import sleap_io as sio
import numpy as np
import torch


Expand Down Expand Up @@ -64,21 +61,22 @@ def find_centroids(
return centroids


class InstanceCentroidFinder(dp.IterDataPipe):
class InstanceCentroidFinder(IterDataPipe):
"""Datapipe for finding centroids of instances.

This DataPipe will produce examples that contain a 'centroids' key.

Attributes:
source_dp: The previous `DataPipe` with samples that contain an `instances` key.
source_dp: The input `IterDataPipe` with examples that contain an `"instances"`
key.
anchor_ind: The index of the node to use as the anchor for the centroid. If not
provided or if not present in the instance, the midpoint of the bounding box
is used instead.
"""

def __init__(
self,
source_dp: dp.IterDataPipe,
source_dp: IterDataPipe,
anchor_ind: Optional[int] = None,
):
"""Initialize InstanceCentroidFinder with the source `DataPipe."""
Expand Down
28 changes: 28 additions & 0 deletions sleap_nn/data/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""This module implements data pipeline blocks for normalization operations."""
import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe


class Normalizer(IterDataPipe):
"""DataPipe for applying normalization.

This DataPipe will normalize the image from `uint8` to `float32` and scale the
values to the range `[0, 1]`.

Attributes:
source_dp: The input `IterDataPipe` with examples that contain `"images"` key.
"""

def __init__(
self,
source_dp: IterDataPipe,
):
"""Initialize the block."""
self.source_dp = source_dp

def __iter__(self):
"""Return an example dictionary with the augmented image and instance."""
for ex in self.source_dp:
if not torch.is_floating_point(ex["image"]):
ex["image"] = ex["image"].to(torch.float32) / 255.0
yield ex
9 changes: 4 additions & 5 deletions sleap_nn/data/providers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Handle importing of sleap data."""
import torchdata.datapipes.iter as dp
import lightning.pytorch as pl
"""This module implements pipeline blocks for reading input data such as labels."""
from torch.utils.data.datapipes.datapipe import IterDataPipe
import torch
import sleap_io as sio
import numpy as np


class LabelsReader(dp.IterDataPipe):
class LabelsReader(IterDataPipe):
"""Datapipe for reading frames from Labels object.

This DataPipe will produce examples containing a frame and an sleap_io.Instance
Expand Down Expand Up @@ -51,5 +50,5 @@ def __iter__(self):

yield {
"image": torch.from_numpy(image),
"instances": torch.from_numpy(instances),
"instances": torch.from_numpy(instances.astype("float32")),
}
30 changes: 9 additions & 21 deletions tests/data/test_augmentation.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@
"""Module for testing augmentations with Kornia"""

import pytest
from sleap_nn.data.augmentation import KorniaAugmenter
from sleap_nn.data.providers import LabelsReader
import sleap_io as sio
from sleap_nn.data.normalization import Normalizer
from torch.utils.data import DataLoader
import torch


def test_kornia_augmentation(minimal_instance: sio.Labels):
def test_kornia_augmentation(minimal_instance):
"""Test the Kornia augmentations."""
labels = sio.load_slp(minimal_instance)
lf = labels[0]
org_img = lf.image
org_img = torch.Tensor(org_img).permute(2, 0, 1)
org_pts = torch.from_numpy(lf[0].numpy())
p = LabelsReader.from_filename(minimal_instance)
p = Normalizer(p)
p = KorniaAugmenter(p, rotation=90, probability=1.0, scale=0.05)

datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = KorniaAugmenter(datapipe, rotation=90, probability=1.0, scale=(0.1, 0.3))

dataloader = DataLoader(datapipe)
sample = next(iter(dataloader))
image, instance = sample["image"], sample["instance"]
img, pts = image[0], instance[0]
sample = next(iter(p))
img, pts = sample["image"], sample["instances"]

assert torch.is_tensor(img)
assert torch.is_tensor(pts)
assert img.shape == org_img.shape
assert pts.shape == org_pts.shape


if __name__ == "__main__":
pytest.main([f"{__file__}::test_kornia_augmentation"])
assert img.shape == (1, 1, 384, 384)
assert pts.shape == (1, 2, 2, 2)
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@


def test_instance_centroids(minimal_instance):
"""Test InstanceCentroidFinder

Args:
minimal_instance: minimal_instance testing fixture
"""

# Undefined anchor_ind
datapipe = LabelsReader.from_filename(minimal_instance)
datapipe = InstanceCentroidFinder(datapipe)
Expand Down
11 changes: 11 additions & 0 deletions tests/data/test_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sleap_nn.data.providers import LabelsReader
from sleap_nn.data.normalization import Normalizer
import torch


def test_normalizer(minimal_instance):
p = LabelsReader.from_filename(minimal_instance)
p = Normalizer(p)

ex = next(iter(p))
assert ex["image"].dtype == torch.float32
7 changes: 0 additions & 7 deletions tests/test_providers.py → tests/data/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from sleap_nn.data.providers import LabelsReader
import sleap_io as sio
import pytest
import torch


def test_providers(minimal_instance):
"""Test LabelsReader

Args:
minimal_instance: minimal_instance testing fixture
"""
l = LabelsReader.from_filename(minimal_instance)
sample = next(iter(l))
instances, image = sample["instances"], sample["image"]
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@pytest.fixture
def sleap_data_dir(pytestconfig):
"""Dir path to sleap data."""
return Path(pytestconfig.rootdir) / "tests/data"
return Path(pytestconfig.rootdir) / "tests/assets"


@pytest.fixture
Expand Down