From 02413bb24e45a38f38cdb523a2fb383c78d8e195 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:34:03 -0700 Subject: [PATCH 1/9] Move fixture data to assets to avoid conflict with data submodule --- tests/{data => assets}/minimal_instance.pkg.slp | Bin tests/fixtures/datasets.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/{data => assets}/minimal_instance.pkg.slp (100%) diff --git a/tests/data/minimal_instance.pkg.slp b/tests/assets/minimal_instance.pkg.slp similarity index 100% rename from tests/data/minimal_instance.pkg.slp rename to tests/assets/minimal_instance.pkg.slp diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 85d11b14..c818b430 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -5,7 +5,7 @@ @pytest.fixture def sleap_data_dir(pytestconfig): """Dir path to sleap data.""" - return Path(pytestconfig.rootdir) / "tests/data" + return Path(pytestconfig.rootdir) / "tests/assets" @pytest.fixture From cd861ef570db514d838416d767735bb33d8c30e8 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:34:43 -0700 Subject: [PATCH 2/9] Switch to torch built-in datapipes --- sleap_nn/data/instance_centroids.py | 12 +++++------- sleap_nn/data/providers.py | 9 ++++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/sleap_nn/data/instance_centroids.py b/sleap_nn/data/instance_centroids.py index ce6e461b..81b8e59a 100644 --- a/sleap_nn/data/instance_centroids.py +++ b/sleap_nn/data/instance_centroids.py @@ -1,9 +1,6 @@ """Handle calculation of instance centroids.""" -import torchdata.datapipes.iter as dp -import lightning.pytorch as pl +from torch.utils.data.datapipes.datapipe import IterDataPipe from typing import Optional -import sleap_io as sio -import numpy as np import torch @@ -64,13 +61,14 @@ def find_centroids( return centroids -class InstanceCentroidFinder(dp.IterDataPipe): +class InstanceCentroidFinder(IterDataPipe): """Datapipe for finding 
centroids of instances. This DataPipe will produce examples that contain a 'centroids' key. Attributes: - source_dp: The previous `DataPipe` with samples that contain an `instances` key. + source_dp: The input `IterDataPipe` with examples that contain an `"instances"` + key. anchor_ind: The index of the node to use as the anchor for the centroid. If not provided or if not present in the instance, the midpoint of the bounding box is used instead. @@ -78,7 +76,7 @@ class InstanceCentroidFinder(dp.IterDataPipe): def __init__( self, - source_dp: dp.IterDataPipe, + source_dp: IterDataPipe, anchor_ind: Optional[int] = None, ): """Initialize InstanceCentroidFinder with the source `DataPipe.""" diff --git a/sleap_nn/data/providers.py b/sleap_nn/data/providers.py index 60ea4999..c5503add 100644 --- a/sleap_nn/data/providers.py +++ b/sleap_nn/data/providers.py @@ -1,12 +1,11 @@ -"""Handle importing of sleap data.""" -import torchdata.datapipes.iter as dp -import lightning.pytorch as pl +"""This module implements pipeline blocks for reading input data such as labels.""" +from torch.utils.data.datapipes.datapipe import IterDataPipe import torch import sleap_io as sio import numpy as np -class LabelsReader(dp.IterDataPipe): +class LabelsReader(IterDataPipe): """Datapipe for reading frames from Labels object. 
This DataPipe will produce examples containing a frame and an sleap_io.Instance @@ -51,5 +50,5 @@ def __iter__(self): yield { "image": torch.from_numpy(image), - "instances": torch.from_numpy(instances), + "instances": torch.from_numpy(instances.astype("float32")), } From 3a3c80552f06a5635c74f441e9fc03d5b84a7ada Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:34:53 -0700 Subject: [PATCH 3/9] Refactor augmentation block --- sleap_nn/data/augmentation.py | 98 +++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 44 deletions(-) diff --git a/sleap_nn/data/augmentation.py b/sleap_nn/data/augmentation.py index dfbfba93..cb0732e9 100644 --- a/sleap_nn/data/augmentation.py +++ b/sleap_nn/data/augmentation.py @@ -1,64 +1,74 @@ -"""Handle rotation and scaling augmentations.""" -from typing import List, Tuple, Union -import torchdata.datapipes.iter as dp -import torch +"""This module implements data pipeline blocks for augmentations.""" +from typing import Optional +from torch.utils.data.datapipes.datapipe import IterDataPipe import kornia.augmentation as K -class KorniaAugmenter(dp.IterDataPipe): - """DataPipe for applying Rotation and Scaling augmentations using Kornia. +class KorniaAugmenter(IterDataPipe): + """DataPipe for applying rotation and scaling augmentations using Kornia. - This DataPipe will generate augmented samples containing the augmented frame and an sleap_io.Instance - from sleap_io.Labels instance. + This DataPipe will apply augmentations to images and instances in examples from the + input pipeline. Attributes: - source_dp: DataPipe which is an instance of the LabelsReader class - rotation: range of degrees to select from. If float, randomly selects a value from (-rotation, +rotation) - probability: probability of applying transformation - scale: scaling factor interval. Randomly selects a scale from the range + source_dp: The input `IterDataPipe` with examples that contain `"instances"` and + `"image"` keys. 
+ rotation: Angles in degrees as a scalar float of the amount of rotation. A + random angle in `(-rotation, rotation)` will be sampled and applied to both + images and keypoints. Set to 0 to disable rotation augmentation. + scale: A scaling factor as a scalar float specifying the amount of scaling. A + random factor between `(1 - scale, 1 + scale)` will be sampled and applied + to both images and keypoints. If `None`, no scaling augmentation will be + applied. + probability: Probability of applying the transformations. + Notes: + This block expects the "image" and "instances" keys to be present in the input + examples. + + The `"image"` key should contain a torch.Tensor of dtype torch.float32 + and of shape `(..., C, H, W)`, i.e., rank >= 3. + + The `"instances"` key should contain a torch.Tensor of dtype torch.float32 and + of shape `(..., n_instances, n_nodes, 2)`, i.e., rank >= 3. + + The augmented versions will be returned with the same keys and shapes. """ def __init__( self, - source_dp: dp.IterDataPipe, - rotation: Union[float, Tuple[float, float], List[float]] = 90, + source_dp: IterDataPipe, + rotation: float = 15.0, + scale: Optional[float] = 0.05, probability: float = 0.5, - scale: Tuple[float, float] = (0.1, 0.3), ): - """Initialize the class variables with the DataPipe and the augmenter with rotation and scaling.""" + """Initialize the block and the augmentation pipeline.""" self.source_dp = source_dp - self.datapipe = self.source_dp.map(self.normalize) + self.rotation = rotation + self.scale = (1 - scale, 1 + scale) + self.probability = probability self.augmenter = K.AugmentationSequential( - K.RandomRotation(degrees=rotation, p=probability, keepdim=True), - K.RandomAffine(degrees=0, scale=scale, keepdim=True), + K.RandomAffine( + degrees=self.rotation, + scale=self.scale, + p=self.probability, + keepdim=True, + same_on_batch=True, + ), data_keys=["input", "keypoints"], keepdim=True, + same_on_batch=True, ) - @classmethod - def normalize(self, data): 
- """Function to normalize the image. - - This function will convert the image to type Double and normalizes it. - - Args: - data: A dictionary sample (`image and key-points`) from the LabelsReader class. - - Returns: - A dictionary with the normalized image and instance. - - """ - image = data["image"] - instance = data["instance"] - image = image.type(torch.DoubleTensor) - image = image / 255 - return {"image": image, "instance": instance} - def __iter__(self): - """Returns a dictionary sample with the augmented image and the transformed instance.""" - for dict in self.datapipe: - image = dict["image"] - instance = dict["instance"] - aug_image, aug_instance = self.augmenter(image, instance) - yield {"image": aug_image, "instance": aug_instance} + """Return an example dictionary with the augmented image and instance.""" + for ex in self.source_dp: + img = ex["image"] + pts = ex["instances"] + pts_shape = pts.shape + pts = pts.reshape(-1, pts_shape[-2], pts_shape[-1]) + img, pts = self.augmenter(img, pts) + pts = pts.reshape(pts_shape) + ex["image"] = img + ex["instances"] = pts + yield ex From 7f4bbf504aa5a760d1564ff780677e9ac3384de2 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:35:14 -0700 Subject: [PATCH 4/9] Fix tests --- tests/data/test_augmentation.py | 30 +++++++-------------- tests/{ => data}/test_instance_centroids.py | 6 ----- tests/{ => data}/test_providers.py | 7 ----- 3 files changed, 9 insertions(+), 34 deletions(-) rename tests/{ => data}/test_instance_centroids.py (91%) rename tests/{ => data}/test_providers.py (71%) diff --git a/tests/data/test_augmentation.py b/tests/data/test_augmentation.py index 69f5ea72..114a93d4 100644 --- a/tests/data/test_augmentation.py +++ b/tests/data/test_augmentation.py @@ -1,34 +1,22 @@ """Module for testing augmentations with Kornia""" -import pytest from sleap_nn.data.augmentation import KorniaAugmenter from sleap_nn.data.providers import LabelsReader -import sleap_io as sio +from 
sleap_nn.data.normalization import Normalizer from torch.utils.data import DataLoader import torch -def test_kornia_augmentation(minimal_instance: sio.Labels): +def test_kornia_augmentation(minimal_instance): """Test the Kornia augmentations.""" - labels = sio.load_slp(minimal_instance) - lf = labels[0] - org_img = lf.image - org_img = torch.Tensor(org_img).permute(2, 0, 1) - org_pts = torch.from_numpy(lf[0].numpy()) + p = LabelsReader.from_filename(minimal_instance) + p = Normalizer(p) + p = KorniaAugmenter(p, rotation=90, probability=1.0, scale=0.05) - datapipe = LabelsReader.from_filename(minimal_instance) - datapipe = KorniaAugmenter(datapipe, rotation=90, probability=1.0, scale=(0.1, 0.3)) - - dataloader = DataLoader(datapipe) - sample = next(iter(dataloader)) - image, instance = sample["image"], sample["instance"] - img, pts = image[0], instance[0] + sample = next(iter(p)) + img, pts = sample["image"], sample["instances"] assert torch.is_tensor(img) assert torch.is_tensor(pts) - assert img.shape == org_img.shape - assert pts.shape == org_pts.shape - - -if __name__ == "__main__": - pytest.main([f"{__file__}::test_kornia_augmentation"]) + assert img.shape == (1, 1, 384, 384) + assert pts.shape == (1, 2, 2, 2) diff --git a/tests/test_instance_centroids.py b/tests/data/test_instance_centroids.py similarity index 91% rename from tests/test_instance_centroids.py rename to tests/data/test_instance_centroids.py index c9de5e03..2984aacf 100644 --- a/tests/test_instance_centroids.py +++ b/tests/data/test_instance_centroids.py @@ -8,12 +8,6 @@ def test_instance_centroids(minimal_instance): - """Test InstanceCentroidFinder - - Args: - minimal_instance: minimal_instance testing fixture - """ - # Undefined anchor_ind datapipe = LabelsReader.from_filename(minimal_instance) datapipe = InstanceCentroidFinder(datapipe) diff --git a/tests/test_providers.py b/tests/data/test_providers.py similarity index 71% rename from tests/test_providers.py rename to 
tests/data/test_providers.py index cf5dfeb1..d1be712d 100644 --- a/tests/test_providers.py +++ b/tests/data/test_providers.py @@ -1,15 +1,8 @@ from sleap_nn.data.providers import LabelsReader -import sleap_io as sio -import pytest import torch def test_providers(minimal_instance): - """Test LabelsReader - - Args: - minimal_instance: minimal_instance testing fixture - """ l = LabelsReader.from_filename(minimal_instance) sample = next(iter(l)) instances, image = sample["instances"], sample["image"] From d9ce00da13e04b7f3169cc72f042db930a35f302 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:35:25 -0700 Subject: [PATCH 5/9] Add normalization block --- sleap_nn/data/normalization.py | 28 ++++++++++++++++++++++++++++ tests/data/test_normalization.py | 11 +++++++++++ 2 files changed, 39 insertions(+) create mode 100644 sleap_nn/data/normalization.py create mode 100644 tests/data/test_normalization.py diff --git a/sleap_nn/data/normalization.py b/sleap_nn/data/normalization.py new file mode 100644 index 00000000..8c150ebc --- /dev/null +++ b/sleap_nn/data/normalization.py @@ -0,0 +1,28 @@ +"""This module implements data pipeline blocks for normalization operations.""" +import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +class Normalizer(IterDataPipe): + """DataPipe for applying normalization. + + This DataPipe will normalize the image from `uint8` to `float32` and scale the + values to the range `[0, 1]`. + + Attributes: + source_dp: The input `IterDataPipe` with examples that contain an `"image"` key.
+ """ + + def __init__( + self, + source_dp: IterDataPipe, + ): + """Initialize the block.""" + self.source_dp = source_dp + + def __iter__(self): + """Return an example dictionary with the normalized image.""" + for ex in self.source_dp: + if not torch.is_floating_point(ex["image"]): + ex["image"] = ex["image"].to(torch.float32) / 255.0 + yield ex diff --git a/tests/data/test_normalization.py b/tests/data/test_normalization.py new file mode 100644 index 00000000..77373320 --- /dev/null +++ b/tests/data/test_normalization.py @@ -0,0 +1,11 @@ +from sleap_nn.data.providers import LabelsReader +from sleap_nn.data.normalization import Normalizer +import torch + + +def test_normalizer(minimal_instance): + p = LabelsReader.from_filename(minimal_instance) + p = Normalizer(p) + + ex = next(iter(p)) + assert ex["image"].dtype == torch.float32 From c40efc0d2ff81f91f79d668a401bd1e5b0a2ee9a Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:36:00 -0700 Subject: [PATCH 6/9] Add .vscode to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 68bc17f9..5e8f6be6 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ + +.vscode/ \ No newline at end of file From 906f9f29eafcaef54cccef3788621748e64ff414 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:38:08 -0700 Subject: [PATCH 7/9] Remove torchdata dependency --- environment.yml | 1 - environment_cpu.yml | 1 - environment_osx-arm64.yml | 1 - pyproject.toml | 4 ++-- 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/environment.yml b/environment.yml index a0af165c..f4ed8632 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,6 @@ dependencies: - cudnn - pytorch - kornia - - torchdata - torchvision - imageio - av diff --git a/environment_cpu.yml b/environment_cpu.yml index a89f8dff..22d6dd8c 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -12,7 +12,6 @@ dependencies: - lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027) - cpuonly - kornia - - torchdata - torchvision - imageio - av diff --git a/environment_osx-arm64.yml b/environment_osx-arm64.yml index 0a5f4ef0..0a750b2c 100644 --- a/environment_osx-arm64.yml +++ b/environment_osx-arm64.yml @@ -11,7 +11,6 @@ dependencies: - lightning=2.0.5 # due to dependency conflict Lightning Issue (#18027) - pytorch - kornia - - torchdata - torchvision - imageio - av diff --git a/pyproject.toml b/pyproject.toml index a914ee32..248072de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ classifiers = [ ] dependencies = [ "torch>=2.0.0", - "torchdata", "torchvision", "pydantic<2.0", "lightning==2.0.5", @@ -27,7 +26,8 @@ dependencies = [ "av", "albumentations", "hydra-core", - "sleap-io>=0.0.7" + "sleap-io>=0.0.7", + "kornia" ] dynamic = ["version", "readme"] From cf7b92726c9bbe913037436bbba1204eaced05c8 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:40:44 -0700 Subject: [PATCH 8/9] Add ipython to dev requirements --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 248072de..e0d33a0e 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,8 @@ dev = [ "pydocstyle", "toml", "twine", - "build" + "build", + "ipython" ] [project.scripts] From fc87457d1860ac0fa447d4eccdcb3f50af988d34 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Fri, 21 Jul 2023 17:44:58 -0700 Subject: [PATCH 9/9] Remove unused albumentations dependency --- environment.yml | 1 - environment_cpu.yml | 1 - environment_osx-arm64.yml | 1 - pyproject.toml | 1 - 4 files changed, 4 deletions(-) diff --git a/environment.yml b/environment.yml index f4ed8632..b6db05c5 100644 --- a/environment.yml +++ b/environment.yml @@ -18,7 +18,6 @@ dependencies: - imageio - av - ffmpeg - - albumentations - matplotlib - pip - pip: diff --git a/environment_cpu.yml b/environment_cpu.yml index 22d6dd8c..26c794f0 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -16,7 +16,6 @@ dependencies: - imageio - av - ffmpeg - - albumentations - matplotlib - pip - pip: diff --git a/environment_osx-arm64.yml b/environment_osx-arm64.yml index 0a750b2c..65db9f6f 100644 --- a/environment_osx-arm64.yml +++ b/environment_osx-arm64.yml @@ -15,7 +15,6 @@ dependencies: - imageio - av - ffmpeg - - albumentations - matplotlib - pip - pip: diff --git a/pyproject.toml b/pyproject.toml index e0d33a0e..45fc3c0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "imageio", "imageio-ffmpeg", "av", - "albumentations", "hydra-core", "sleap-io>=0.0.7", "kornia"