From 12dd8593bd3bce41b494b6e4ec478e8c83fc2f1f Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 08:49:31 -0800 Subject: [PATCH 01/12] Filter OOB points while training --- docs/conf.py | 2 +- sleap/nn/data/providers.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index a429c7928..796497f6b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info): # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) html_css_files = [ - 'css/tabs.css', + "css/tabs.css", ] # Custom sidebar templates, must be a dictionary that maps document names diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 9e93d0b18..3952c2b55 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -5,6 +5,7 @@ import attr from typing import Text, Optional, List, Sequence, Union, Tuple import sleap +from sleap.instance import Instance @attr.s(auto_attribs=True) @@ -198,6 +199,25 @@ def py_fetch_lf(ind): raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") + height, width = raw_image_size + + # Filter OOB points + instances = [] + for instance in lf.instances: + pts = instance.numpy() + # negative coords + pts[pts < 0] = np.NaN + + # coordinates outside img frame + pts[:, 0][pts[:, 0] > height - 1] = np.NaN + pts[:, 1][pts[:, 1] > width - 1] = np.NaN + + # remove all nans + pts = pts[~np.isnan(pts).any(axis=1), :] + + instances.append(Instance.from_numpy(pts, lf.skeleton, lf.track)) + lf.instances = instances + if self.user_instances_only: insts = lf.user_instances else: From d99b51bfb821d52873f690523297978d5c6fd920 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 09:14:16 -0800 Subject: [PATCH 02/12] Fix img shape --- sleap/nn/data/providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 3952c2b55..d2ab62f8b 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -199,7 +199,7 @@ def py_fetch_lf(ind): raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") - height, width = raw_image_size + height, width, _ = raw_image_size # Filter OOB points instances = [] From f00afd6fbe60a0e0167e4ca3c7439ba3b958092a Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 10:26:34 -0800 Subject: [PATCH 03/12] Convert oob pts to nans --- sleap/nn/data/providers.py | 7 +++---- tests/nn/data/test_providers.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index d2ab62f8b..1d05bc5d9 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -212,10 +212,9 @@ def py_fetch_lf(ind): pts[:, 0][pts[:, 0] > height - 1] = np.NaN pts[:, 1][pts[:, 1] > width - 1] = np.NaN - # remove all nans - pts = pts[~np.isnan(pts).any(axis=1), :] - - instances.append(Instance.from_numpy(pts, lf.skeleton, lf.track)) + instances.append( + Instance.from_numpy(pts, instance.skeleton, instance.track) + ) lf.instances = instances if self.user_instances_only: diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index f30216e6a..3d5159780 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -68,6 +68,37 @@ def test_labels_reader_no_visible_points(min_labels): assert len(labels_reader) == 0 +def test_labels_filter_oob_points(min_labels): + # There should be two instances in the labels dataset + labels = min_labels.copy() + assert len(labels.labeled_frames[0].instances) == 2 + + assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes + + labels[0].instances[0][0] = (390, 100) # exceeds img height + print(labels[0].instances) + + labels_reader = providers.LabelsReader.from_user_instances(labels) + examples = list(iter(labels_reader.make_dataset())) + assert len(examples) == 1 + + assert all(np.isnan(examples[0]["instances"][0][0])) + + # test with negative keypoints + labels = min_labels.copy() + assert len(labels.labeled_frames[0].instances) == 2 + + assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes + + labels[0].instances[0][1] = (-10, 100) + + labels_reader = providers.LabelsReader.from_user_instances(labels) + examples = list(iter(labels_reader.make_dataset())) + assert len(examples) == 1 + + assert all(np.isnan(examples[0]["instances"][0][1])) + + def test_labels_reader_subset(min_labels): labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]]) assert len(labels) == 3 From ed56b36f48efc3a1ddf05ef9662a9b714cebf1f9 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 10:47:48 -0800 Subject: [PATCH 04/12] Filter oob while computing crop size --- sleap/nn/data/instance_cropping.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index 1cfd5eee7..09347881e 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -41,9 +41,15 @@ def find_instance_crop_size( # Calculate crop size min_crop_size_no_pad = min_crop_size - padding + height, width, _ = labels[0].image.shape max_length = 0.0 for inst in labels.user_instances: pts = inst.points_array + + pts[pts < 0] = np.NaN + pts[:, 0][pts[:, 0] > height - 1] = np.NaN + pts[:, 1][pts[:, 1] > width - 1] = np.NaN + pts *= input_scaling max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])) max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])) From f9efd81158d5bc6d7b7c589687f75c9d329e939c Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 11:39:50 -0800 Subject: [PATCH 05/12] Fix test and img size --- sleap/nn/data/instance_cropping.py | 29 +++++++++++++++++------------ tests/nn/test_inference.py | 2 +- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index 09347881e..7e8167ac0 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -41,19 +41,24 @@ def find_instance_crop_size( # Calculate crop size min_crop_size_no_pad = min_crop_size - padding - height, width, _ = labels[0].image.shape max_length = 0.0 - for inst in labels.user_instances: - pts = inst.points_array - - pts[pts < 0] = np.NaN - pts[:, 0][pts[:, 0] > height - 1] = np.NaN - pts[:, 1][pts[:, 1] > width - 1] = np.NaN - - pts *= input_scaling - max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])) - max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])) - max_length = np.maximum(max_length, min_crop_size_no_pad) + for lf in labels: + for inst in lf.user_instances: + pts = inst.points_array + + pts[pts < 0] = np.NaN + height, width, _ = lf.image.shape + pts[:, 0][pts[:, 0] > height - 1] = np.NaN + pts[:, 1][pts[:, 1] > width - 1] = np.NaN + + pts *= input_scaling + max_length = np.maximum( + max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]) + ) + max_length = np.maximum( + max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]) + ) + max_length = np.maximum(max_length, min_crop_size_no_pad) max_length += float(padding) crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 0a978de0a..adbd51715 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -393,7 +393,7 @@ def test_topdown_model(test_pipeline): assert tuple(out["instance_peak_vals"].shape) == (8, 2, 2) assert tuple(out["n_valid"].shape) == (8,) - assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 2, 2]).all() + assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 1, 1]).all() def test_inference_layer(): From 83fa2e5ab639915157174766ba4a5a50c4664c02 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Tue, 17 Dec 2024 14:05:21 -0800 Subject: [PATCH 06/12] Fix inference test --- tests/nn/test_inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index adbd51715..8cb6fbc6f 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -2039,7 +2039,11 @@ def test_movenet_predictor(min_dance_labels, movenet_video): [labels_pr[0][0].numpy(), labels_pr[1][0].numpy()], axis=0 ) - np.testing.assert_allclose(points_gt, points_pr, atol=0.75) + assert_allclose( + points_gt[~np.isnan(points_gt).any(axis=1)], + points_pr[~np.isnan(points_gt).any(axis=1)], + atol=0.75, + ) @pytest.mark.parametrize( From 6a2e1641fc76033269f61621d9469e2d36ad0556 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 18 Dec 2024 11:27:27 -0800 Subject: [PATCH 07/12] Fix img shape --- sleap/nn/data/instance_cropping.py | 6 +++--- sleap/nn/data/providers.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index 7e8167ac0..dd9f60737 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -47,9 +47,9 @@ def find_instance_crop_size( pts = inst.points_array pts[pts < 0] = np.NaN - height, width, _ = lf.image.shape - pts[:, 0][pts[:, 0] > height - 1] = np.NaN - pts[:, 1][pts[:, 1] > width - 1] = np.NaN + height, width = lf.image.shape[:2] + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN pts *= input_scaling max_length = np.maximum( diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 1d05bc5d9..c9fd30cac 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -199,7 +199,7 @@ def py_fetch_lf(ind): raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") - height, width, _ = raw_image_size + height, width = raw_image_size[:2] # Filter OOB points instances = [] @@ -209,8 +209,8 @@ def py_fetch_lf(ind): pts[pts < 0] = np.NaN # coordinates outside img frame - pts[:, 0][pts[:, 0] > height - 1] = np.NaN - pts[:, 1][pts[:, 1] > width - 1] = np.NaN + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN instances.append( Instance.from_numpy(pts, instance.skeleton, instance.track) From 042e55bd8cc8287410893eb167386bcd6e6ed448 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 18 Dec 2024 12:22:55 -0800 Subject: [PATCH 08/12] Add more tests --- tests/nn/data/test_instance_cropping.py | 18 ++++++++++++ tests/nn/data/test_providers.py | 1 - tests/nn/test_training.py | 38 +++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/nn/data/test_instance_cropping.py b/tests/nn/data/test_instance_cropping.py index 688f50dbd..cc4894516 100644 --- a/tests/nn/data/test_instance_cropping.py +++ b/tests/nn/data/test_instance_cropping.py @@ -11,6 +11,24 @@ from sleap.nn.config import InstanceCroppingConfig +def test_find_instance_crop_size(min_labels): + labels = min_labels.copy() + assert len(labels.labeled_frames[0].instances) == 2 + + crop_size = instance_cropping.find_instance_crop_size(labels) + assert crop_size == 74 + + assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes + + labels[0].instances[1][0] = (390, 187.9) # exceeds img height + crop_size = instance_cropping.find_instance_crop_size(labels) + assert crop_size == 60 + + labels[0].instances[1][0] = (-100, 187.9) # exceeds img height + crop_size = instance_cropping.find_instance_crop_size(labels) + assert crop_size == 60 + + def test_normalize_bboxes(): bbox = tf.convert_to_tensor([[0, 0, 3, 3]], tf.float32) norm_bbox = instance_cropping.normalize_bboxes(bbox, 9, 9) diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index 3d5159780..5b8140904 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -76,7 +76,6 @@ def test_labels_filter_oob_points(min_labels): assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes labels[0].instances[0][0] = (390, 100) # exceeds img height - print(labels[0].instances) labels_reader = providers.LabelsReader.from_user_instances(labels) examples = list(iter(labels_reader.make_dataset())) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index 72db17bb5..58485b761 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -220,6 +220,44 @@ def test_train_topdown(training_labels, cfg): assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2) +def test_train_topdown_with_oob_pts(min_labels, cfg): + # pt exceeding img dim + labels = min_labels + labels.append( + sleap.LabeledFrame( + video=labels.videos[0], frame_idx=1, instances=labels[0].instances + ) + ) + labels[0].instances[1][0] = (390, 187.9) # crop size=60 + + cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=False + ) + trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels) + trainer.setup() + trainer.train() + assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead" + assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2) + + # negative pts + labels = min_labels + labels.append( + sleap.LabeledFrame( + video=labels.videos[0], frame_idx=1, instances=labels[0].instances + ) + ) + labels[0].instances[1][0] = (-100, 187.9) # crop size=60 + + cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=False + ) + trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels) + trainer.setup() + trainer.train() + assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead" + assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2) + + def test_train_topdown_with_offset(training_labels, cfg): cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=True From a75b0df309bd9c6d8645d2d5b62b1bab1483fe3e Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 18 Dec 2024 19:46:41 -0800 Subject: [PATCH 09/12] Modify test cases --- tests/nn/data/test_providers.py | 23 +++++++---------------- tests/nn/test_training.py | 26 ++++++-------------------- 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index 5b8140904..4540c539b 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from sleap.nn.system import use_cpu_only @@ -68,14 +69,18 @@ def test_labels_reader_no_visible_points(min_labels): assert len(labels_reader) == 0 -def test_labels_filter_oob_points(min_labels): +@pytest.mark.parametrize( + "oob_point,test_case", + [((390, 187.9), "exceeding_image_dim"), ((-100, 187.9), "negative_coordinates")], +) +def test_labels_filter_oob_points(min_labels, oob_point, test_case): # There should be two instances in the labels dataset labels = min_labels.copy() assert len(labels.labeled_frames[0].instances) == 2 assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes - labels[0].instances[0][0] = (390, 100) # exceeds img height + labels[0].instances[0][0] = oob_point labels_reader = providers.LabelsReader.from_user_instances(labels) examples = list(iter(labels_reader.make_dataset())) @@ -83,20 +88,6 @@ def test_labels_filter_oob_points(min_labels): assert all(np.isnan(examples[0]["instances"][0][0])) - # test with negative keypoints - labels = min_labels.copy() - assert len(labels.labeled_frames[0].instances) == 2 - - assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes - - labels[0].instances[0][1] = (-10, 100) - - labels_reader = providers.LabelsReader.from_user_instances(labels) - examples = list(iter(labels_reader.make_dataset())) - assert len(examples) == 1 - - assert all(np.isnan(examples[0]["instances"][0][1])) - def test_labels_reader_subset(min_labels): labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]]) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index 58485b761..96bc6e4c0 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -220,7 +220,11 @@ def test_train_topdown(training_labels, cfg): assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2) -def test_train_topdown_with_oob_pts(min_labels, cfg): +@pytest.mark.parametrize( + "oob_point,test_case", + [((390, 187.9), "exceeding_image_dim"), ((-100, 187.9), "negative_coordinates")], +) +def test_train_topdown_with_oob_pts(min_labels, cfg, oob_point, test_case): # pt exceeding img dim labels = min_labels labels.append( @@ -228,25 +232,7 @@ def test_train_topdown_with_oob_pts(min_labels, cfg): video=labels.videos[0], frame_idx=1, instances=labels[0].instances ) ) - labels[0].instances[1][0] = (390, 187.9) # crop size=60 - - cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig( - sigma=1.5, output_stride=1, offset_refinement=False - ) - trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels) - trainer.setup() - trainer.train() - assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead" - assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2) - - # negative pts - labels = min_labels - labels.append( - sleap.LabeledFrame( - video=labels.videos[0], frame_idx=1, instances=labels[0].instances - ) - ) - labels[0].instances[1][0] = (-100, 187.9) # crop size=60 + labels[0].instances[1][0] = oob_point # crop size=60 cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=False From 11af641ef95ca90d72ad93ec3a6559aad6864d8e Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 19 Dec 2024 09:12:52 -0800 Subject: [PATCH 10/12] Refactor filter oob --- sleap/nn/data/providers.py | 79 +++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index c9fd30cac..319a2d007 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -201,50 +201,59 @@ def py_fetch_lf(ind): height, width = raw_image_size[:2] - # Filter OOB points + if self.user_instances_only: + insts = lf.user_instances + else: + insts = lf.instances + instances = [] - for instance in lf.instances: - pts = instance.numpy() - # negative coords - pts[pts < 0] = np.NaN - # coordinates outside img frame - pts[:, 0][pts[:, 0] > width - 1] = np.NaN - pts[:, 1][pts[:, 1] > height - 1] = np.NaN + for inst in insts: + + if len(inst) > 0: + + # Filter OOB + pts = inst.numpy() + pts[pts < 0] = np.NaN + + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN + + instance = Instance.from_numpy(pts, inst.skeleton, inst.track) + + if self.with_track_only: + if instance.track is not None: + instances.append(instance) + + else: + instances.append(instance) - instances.append( - Instance.from_numpy(pts, instance.skeleton, instance.track) + n_instances = len(instances) + n_nodes = len(instances[0].skeleton) if n_instances > 0 else 0 + + insts = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") + track_inds = [] + skeleton_inds = [] + for i, instance in enumerate(instances): + + track_inds.append( + self.tracks.index(instance.track) + if instance.track is not None + else -1 ) - lf.instances = instances - if self.user_instances_only: - insts = lf.user_instances - else: - insts = lf.instances - insts = [inst for inst in insts if len(inst) > 0] - if self.with_track_only: - insts = [inst for inst in insts if inst.track is not None] - n_instances = len(insts) - n_nodes = len(insts[0].skeleton) if n_instances > 0 else 0 - - instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") - for i, instance in enumerate(insts): - instances[i] = instance.numpy() - - skeleton_inds = np.array( - [self.labels.skeletons.index(inst.skeleton) for inst in insts] - ).astype("int32") - track_inds = np.array( - [ - self.tracks.index(inst.track) if inst.track is not None else -1 - for inst in insts - ] - ).astype("int32") + skeleton_inds.append(self.labels.skeletons.index(instance.skeleton)) + + insts[i] = instance.numpy() + + track_inds = np.array(track_inds).astype("int32") + skeleton_inds = np.array(skeleton_inds).astype("int32") + n_tracks = np.array(len(self.tracks)).astype("int32") return ( raw_image, raw_image_size, - instances, + insts, video_ind, frame_ind, skeleton_inds, From 889fc5afa39ecd21171afc5dd7029c0cd998ee03 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 19 Dec 2024 11:11:31 -0800 Subject: [PATCH 11/12] Fix for empty instances --- sleap/nn/data/providers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 319a2d007..b52a8601e 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -210,16 +210,16 @@ def py_fetch_lf(ind): for inst in insts: - if len(inst) > 0: + # Filter OOB + pts = inst.numpy() + pts[pts < 0] = np.NaN - # Filter OOB - pts = inst.numpy() - pts[pts < 0] = np.NaN + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN - pts[:, 0][pts[:, 0] > width - 1] = np.NaN - pts[:, 1][pts[:, 1] > height - 1] = np.NaN + instance = Instance.from_numpy(pts, inst.skeleton, inst.track) - instance = Instance.from_numpy(pts, inst.skeleton, inst.track) + if len(instance) > 0: if self.with_track_only: if instance.track is not None: From 768ca909f9af4dc8ee03df751a2027178b2e2141 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 20 Dec 2024 08:59:33 -0800 Subject: [PATCH 12/12] Add function to filter oob --- sleap/nn/data/instance_cropping.py | 21 ++++++++++----------- sleap/nn/data/providers.py | 9 ++------- sleap/nn/data/utils.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index dd9f60737..73cab215c 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -6,6 +6,7 @@ from typing import Optional, List, Text import sleap from sleap.nn.config import InstanceCroppingConfig +from sleap.nn.data.utils import filter_oob_points def find_instance_crop_size( @@ -43,22 +44,20 @@ def find_instance_crop_size( min_crop_size_no_pad = min_crop_size - padding max_length = 0.0 for lf in labels: - for inst in lf.user_instances: - pts = inst.points_array + for inst in lf: + if isinstance(inst, sleap.PredictedInstance): + continue - pts[pts < 0] = np.NaN - height, width = lf.image.shape[:2] - pts[:, 0][pts[:, 0] > width - 1] = np.NaN - pts[:, 1][pts[:, 1] > height - 1] = np.NaN + pts = filter_oob_points(inst.numpy(), lf.image.shape[:2]) pts *= input_scaling - max_length = np.maximum( - max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]) + max_length: float = np.nanmax( + [max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])] ) - max_length = np.maximum( - max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]) + max_length: float = np.nanmax( + [max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])] ) - max_length = np.maximum(max_length, min_crop_size_no_pad) + max_length: float = np.nanmax([max_length, min_crop_size_no_pad]) max_length += float(padding) crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index b52a8601e..d0a54858f 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -6,6 +6,7 @@ from typing import Text, Optional, List, Sequence, Union, Tuple import sleap from sleap.instance import Instance +from sleap.nn.data.utils import filter_oob_points @attr.s(auto_attribs=True) @@ -199,8 +200,6 @@ def py_fetch_lf(ind): raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") - height, width = raw_image_size[:2] - if self.user_instances_only: insts = lf.user_instances else: @@ -211,11 +210,7 @@ def py_fetch_lf(ind): for inst in insts: # Filter OOB - pts = inst.numpy() - pts[pts < 0] = np.NaN - - pts[:, 0][pts[:, 0] > width - 1] = np.NaN - pts[:, 1][pts[:, 1] > height - 1] = np.NaN + pts = filter_oob_points(inst.numpy(), raw_image_size[:2]) instance = Instance.from_numpy(pts, inst.skeleton, inst.track) diff --git a/sleap/nn/data/utils.py b/sleap/nn/data/utils.py index 6b938220a..153be89bc 100644 --- a/sleap/nn/data/utils.py +++ b/sleap/nn/data/utils.py @@ -6,6 +6,16 @@ from typing import Any, List, Tuple, Dict, Text, Optional +def filter_oob_points(pts: np.ndarray, img_hw: tuple) -> np.ndarray: + """Convert negative/ out-of-boundary pts to NaNs.""" + pts[pts < 0] = np.NaN + height, width = img_hw + pts[:, 0][pts[:, 0] > width - 1] = np.NaN + pts[:, 1][pts[:, 1] > height - 1] = np.NaN + + return pts + + def ensure_list(x: Any) -> List[Any]: """Convert the input into a list if it is not already.""" if not isinstance(x, list):