Skip to content

Commit

Permalink
Don't create instances during inference if no points were found (#1073)
Browse files Browse the repository at this point in the history
* Don't create instances during inference if no points were found

* Add points check for all predictors

* Fix single instance predictor logic and test

* Add tests for all predictors

Co-authored-by: roomrys <[email protected]>
  • Loading branch information
talmo and roomrys authored Dec 9, 2022
1 parent eac2e2b commit ebd2e1e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 11 deletions.
27 changes: 18 additions & 9 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def call(
tf.int64
) # (batch_size, n_centroids, 1, 1, 2)
dists = a - b # (batch_size, n_centroids, n_insts, n_nodes, 2)
dists = tf.sqrt(tf.reduce_sum(dists ** 2, axis=-1)) # reduce over xy
dists = tf.sqrt(tf.reduce_sum(tf.math.square(dists), axis=-1)) # reduce over xy
dists = tf.reduce_min(dists, axis=-1) # reduce over nodes
dists = dists.to_tensor(
tf.cast(np.NaN, tf.float32)
Expand Down Expand Up @@ -1453,14 +1453,17 @@ def _make_labeled_frames_from_generator(
ex["instance_peak_vals"],
):
# Loop over instances.
predicted_instances = [
sleap.instance.PredictedInstance.from_arrays(
points=points[0],
point_confidences=confidences[0],
instance_score=np.nansum(confidences[0]),
skeleton=skeleton,
)
]
if np.isnan(points[0]).all():
predicted_instances = []
else:
predicted_instances = [
sleap.instance.PredictedInstance.from_arrays(
points=points[0],
point_confidences=confidences[0],
instance_score=np.nansum(confidences[0]),
skeleton=skeleton,
)
]

predicted_frames.append(
sleap.LabeledFrame(
Expand Down Expand Up @@ -2434,6 +2437,9 @@ def _make_labeled_frames_from_generator(
# Loop over instances.
predicted_instances = []
for pts, confs, score in zip(points, confidences, scores):
if np.isnan(pts).all():
continue

predicted_instances.append(
sleap.instance.PredictedInstance.from_arrays(
points=pts,
Expand Down Expand Up @@ -2999,6 +3005,9 @@ def _make_labeled_frames_from_generator(
# Loop over instances.
predicted_instances = []
for pts, confs, score in zip(points, confidences, scores):
if np.isnan(pts).all():
continue

predicted_instances.append(
sleap.instance.PredictedInstance.from_arrays(
points=pts,
Expand Down
75 changes: 73 additions & 2 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,25 @@ def test_single_instance_predictor(
def test_single_instance_predictor_high_peak_thresh(
min_labels_robot, min_single_instance_robot_model_path
):
predictor = SingleInstancePredictor.from_trained_models(
min_single_instance_robot_model_path, peak_threshold=0
)
predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels_robot)
assert len(labels_pr) == 2
assert len(labels_pr[0]) == 1
assert labels_pr[0][0].n_visible_points == 2
assert len(labels_pr[1]) == 1
assert labels_pr[1][0].n_visible_points == 2

predictor = SingleInstancePredictor.from_trained_models(
min_single_instance_robot_model_path, peak_threshold=1.5
)
predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels_robot)
assert len(labels_pr) == 2
assert labels_pr[0][0].n_visible_points == 0
assert labels_pr[1][0].n_visible_points == 0
assert len(labels_pr[0]) == 0
assert len(labels_pr[1]) == 0


def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
Expand Down Expand Up @@ -612,6 +623,16 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
assert len(labels_pr[0].instances) == 2


def test_topdown_predictor_centroid_high_threshold(min_labels, min_centroid_model_path):
predictor = TopDownPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path, peak_threshold=1.5
)
predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 0


def test_topdown_predictor_centered_instance(
min_labels, min_centered_instance_model_path
):
Expand All @@ -636,6 +657,18 @@ def test_topdown_predictor_centered_instance(
assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5)


def test_topdown_predictor_centered_instance_high_threshold(
min_labels, min_centered_instance_model_path
):
predictor = TopDownPredictor.from_trained_models(
confmap_model_path=min_centered_instance_model_path, peak_threshold=1.5
)
predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 0


def test_bottomup_predictor(min_labels, min_bottomup_model_path):
predictor = BottomUpPredictor.from_trained_models(
model_path=min_bottomup_model_path
Expand Down Expand Up @@ -666,6 +699,16 @@ def test_bottomup_predictor(min_labels, min_bottomup_model_path):
assert len(labels_pr[0]) == 0


def test_bottomup_predictor_high_peak_thresh(min_labels, min_bottomup_model_path):
predictor = BottomUpPredictor.from_trained_models(
model_path=min_bottomup_model_path, peak_threshold=1.5
)
predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 0


def test_bottomup_multiclass_predictor(
min_tracks_2node_labels, min_bottomup_multiclass_model_path
):
Expand Down Expand Up @@ -698,6 +741,20 @@ def test_bottomup_multiclass_predictor(
labels_pr[0][1].track.name == "male"


def test_bottomup_multiclass_predictor_high_threshold(
min_tracks_2node_labels, min_bottomup_multiclass_model_path
):
labels_gt = sleap.Labels(min_tracks_2node_labels[[0]])
predictor = BottomUpMultiClassPredictor.from_trained_models(
model_path=min_bottomup_multiclass_model_path,
peak_threshold=1.5,
integral_refinement=False,
)
labels_pr = predictor.predict(labels_gt)
assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 0


def test_topdown_multiclass_predictor(
min_tracks_2node_labels, min_topdown_multiclass_model_path
):
Expand All @@ -724,6 +781,20 @@ def test_topdown_multiclass_predictor(
)


def test_topdown_multiclass_predictor_high_threshold(
min_tracks_2node_labels, min_topdown_multiclass_model_path
):
labels_gt = sleap.Labels(min_tracks_2node_labels[[0]])
predictor = TopDownMultiClassPredictor.from_trained_models(
confmap_model_path=min_topdown_multiclass_model_path,
peak_threshold=1.5,
integral_refinement=False,
)
labels_pr = predictor.predict(labels_gt)
assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 0


def test_load_model(
min_single_instance_robot_model_path,
min_centroid_model_path,
Expand Down

0 comments on commit ebd2e1e

Please sign in to comment.