Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter OOB points while training #2061

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info):
# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
'css/tabs.css',
"css/tabs.css",
]

# Custom sidebar templates, must be a dictionary that maps document names
Expand Down
6 changes: 6 additions & 0 deletions sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ def find_instance_crop_size(

# Calculate crop size
min_crop_size_no_pad = min_crop_size - padding
height, width, _ = labels[0].image.shape
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
max_length = 0.0
for inst in labels.user_instances:
pts = inst.points_array

pts[pts < 0] = np.NaN
pts[:, 0][pts[:, 0] > height - 1] = np.NaN
pts[:, 1][pts[:, 1] > width - 1] = np.NaN

pts *= input_scaling
max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]))
max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]))
Expand Down
19 changes: 19 additions & 0 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import attr
from typing import Text, Optional, List, Sequence, Union, Tuple
import sleap
from sleap.instance import Instance


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -198,6 +199,24 @@ def py_fetch_lf(ind):
raw_image = lf.image
raw_image_size = np.array(raw_image.shape).astype("int32")

height, width, _ = raw_image_size

# Filter OOB points
instances = []
for instance in lf.instances:
pts = instance.numpy()
# negative coords
pts[pts < 0] = np.NaN

# coordinates outside img frame
pts[:, 0][pts[:, 0] > height - 1] = np.NaN
pts[:, 1][pts[:, 1] > width - 1] = np.NaN
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

instances.append(
Instance.from_numpy(pts, instance.skeleton, instance.track)
)
lf.instances = instances

if self.user_instances_only:
insts = lf.user_instances
else:
Expand Down
31 changes: 31 additions & 0 deletions tests/nn/data/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ def test_labels_reader_no_visible_points(min_labels):
assert len(labels_reader) == 0


def test_labels_filter_oob_points(min_labels):
# There should be two instances in the labels dataset
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[0][0] = (390, 100) # exceeds img height
print(labels[0].instances)

labels_reader = providers.LabelsReader.from_user_instances(labels)
examples = list(iter(labels_reader.make_dataset()))
assert len(examples) == 1

assert all(np.isnan(examples[0]["instances"][0][0]))

# test with negative keypoints
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[0][1] = (-10, 100)

eberrigan marked this conversation as resolved.
Show resolved Hide resolved
labels_reader = providers.LabelsReader.from_user_instances(labels)
examples = list(iter(labels_reader.make_dataset()))
assert len(examples) == 1

assert all(np.isnan(examples[0]["instances"][0][1]))


gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
def test_labels_reader_subset(min_labels):
labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]])
assert len(labels) == 3
Expand Down
Loading