diff --git a/docs/conf.py b/docs/conf.py index a429c7928..796497f6b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info): # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) html_css_files = [ - 'css/tabs.css', + "css/tabs.css", ] # Custom sidebar templates, must be a dictionary that maps document names diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index 1cfd5eee7..73cab215c 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -6,6 +6,7 @@ from typing import Optional, List, Text import sleap from sleap.nn.config import InstanceCroppingConfig +from sleap.nn.data.utils import filter_oob_points def find_instance_crop_size( @@ -42,12 +43,21 @@ def find_instance_crop_size( # Calculate crop size min_crop_size_no_pad = min_crop_size - padding max_length = 0.0 - for inst in labels.user_instances: - pts = inst.points_array - pts *= input_scaling - max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])) - max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])) - max_length = np.maximum(max_length, min_crop_size_no_pad) + for lf in labels: + for inst in lf: + if isinstance(inst, sleap.PredictedInstance): + continue + + pts = filter_oob_points(inst.numpy(), lf.image.shape[:2]) + + pts *= input_scaling + max_length: float = np.nanmax( + [max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])] + ) + max_length: float = np.nanmax( + [max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])] + ) + max_length: float = np.nanmax([max_length, min_crop_size_no_pad]) max_length += float(padding) crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 9e93d0b18..d0a54858f 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -5,6 +5,8 @@ import attr 
from typing import Text, Optional, List, Sequence, Union, Tuple import sleap +from sleap.instance import Instance +from sleap.nn.data.utils import filter_oob_points @attr.s(auto_attribs=True) @@ -202,30 +204,51 @@ def py_fetch_lf(ind): insts = lf.user_instances else: insts = lf.instances - insts = [inst for inst in insts if len(inst) > 0] - if self.with_track_only: - insts = [inst for inst in insts if inst.track is not None] - n_instances = len(insts) - n_nodes = len(insts[0].skeleton) if n_instances > 0 else 0 - - instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") - for i, instance in enumerate(insts): - instances[i] = instance.numpy() - - skeleton_inds = np.array( - [self.labels.skeletons.index(inst.skeleton) for inst in insts] - ).astype("int32") - track_inds = np.array( - [ - self.tracks.index(inst.track) if inst.track is not None else -1 - for inst in insts - ] - ).astype("int32") + + instances = [] + + for inst in insts: + + # Filter OOB + pts = filter_oob_points(inst.numpy(), raw_image_size[:2]) + + instance = Instance.from_numpy(pts, inst.skeleton, inst.track) + + if len(instance) > 0: + + if self.with_track_only: + if instance.track is not None: + instances.append(instance) + + else: + instances.append(instance) + + n_instances = len(instances) + n_nodes = len(instances[0].skeleton) if n_instances > 0 else 0 + + insts = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") + track_inds = [] + skeleton_inds = [] + for i, instance in enumerate(instances): + + track_inds.append( + self.tracks.index(instance.track) + if instance.track is not None + else -1 + ) + + skeleton_inds.append(self.labels.skeletons.index(instance.skeleton)) + + insts[i] = instance.numpy() + + track_inds = np.array(track_inds).astype("int32") + skeleton_inds = np.array(skeleton_inds).astype("int32") + n_tracks = np.array(len(self.tracks)).astype("int32") return ( raw_image, raw_image_size, - instances, + insts, video_ind, frame_ind, skeleton_inds, diff --git 
def filter_oob_points(pts: np.ndarray, img_hw: tuple) -> np.ndarray:
    """Replace negative or out-of-bounds point coordinates with NaN.

    Column 0 of `pts` is compared against the image width and column 1 against
    the image height. A coordinate is kept only if it is non-negative and not
    greater than `size - 1` along its axis. Existing NaN entries are left as-is.

    Args:
        pts: Array-like of points with at least two columns. The input is not
            modified.
        img_hw: Image size as a `(height, width)` pair.

    Returns:
        A new floating point array with the same shape as `pts` in which every
        negative value, every column-0 value above `width - 1` and every
        column-1 value above `height - 1` is NaN. Floating point inputs keep
        their dtype; any other dtype is promoted to float64 so NaN can be
        stored.
    """
    arr = np.asarray(pts)
    if np.issubdtype(arr.dtype, np.floating):
        # Work on a copy so the caller's array is never mutated.
        filtered = arr.copy()
    else:
        # Integer (or other) arrays cannot hold NaN; astype already copies.
        filtered = arr.astype(np.float64)

    height, width = img_hw

    # Inputs may already contain NaNs; comparisons against them are simply
    # False, so silence the "invalid value" warning some NumPy versions emit.
    with np.errstate(invalid="ignore"):
        filtered[filtered < 0] = np.nan
        filtered[filtered[:, 0] > width - 1, 0] = np.nan
        filtered[filtered[:, 1] > height - 1, 1] = np.nan

    return filtered
@pytest.mark.parametrize(
    "oob_point,test_case",
    [
        ((390, 187.9), "exceeding_image_dim"),
        ((-100, 187.9), "negative_coordinates"),
    ],
)
def test_labels_filter_oob_points(min_labels, oob_point, test_case):
    """Check that the labels reader yields NaNs for a node moved out of bounds.

    `test_case` is only a readable tag for the parametrized case; the test
    body does not use it.
    """
    labels = min_labels.copy()

    # Fixture sanity checks: two instances in the first frame, two nodes each.
    assert len(labels.labeled_frames[0].instances) == 2
    assert labels[0].instances[0].numpy().shape[0] == 2

    # Relocate the first node of the first instance to the test coordinate.
    labels[0].instances[0][0] = oob_point

    reader = providers.LabelsReader.from_user_instances(labels)
    dataset = reader.make_dataset()
    examples = list(iter(dataset))
    assert len(examples) == 1

    # The relocated node must come back as NaN in both coordinates.
    relocated_node = examples[0]["instances"][0][0]
    assert all(np.isnan(relocated_node))
@pytest.mark.parametrize(
    "oob_point,test_case",
    [
        ((390, 187.9), "exceeding_image_dim"),
        ((-100, 187.9), "negative_coordinates"),
    ],
)
def test_train_topdown_with_oob_pts(min_labels, cfg, oob_point, test_case):
    """Train a centered-instance model on labels holding an out-of-bounds node.

    `test_case` is only a readable tag for the parametrized case; the test
    body does not use it.
    """
    labels = min_labels

    # Add a second labeled frame that reuses the first frame's instances.
    extra_frame = sleap.LabeledFrame(
        video=labels.videos[0], frame_idx=1, instances=labels[0].instances
    )
    labels.append(extra_frame)

    # Move the first node of the second instance to the test coordinate.
    labels[0].instances[1][0] = oob_point  # crop size=60

    head_config = CenteredInstanceConfmapsHeadConfig(
        sigma=1.5, output_stride=1, offset_refinement=False
    )
    cfg.model.heads.centered_instance = head_config

    trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels)
    trainer.setup()
    trainer.train()

    assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead"
    assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2)