From 06ef1361f8db00cafbf946742ba767a11e36931d Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 2 Aug 2024 15:04:04 +0100 Subject: [PATCH] option to merge instances --- sleap/gui/learning/runners.py | 1 + sleap/instance.py | 73 +++++++++++++++++++++++++++++++++- sleap/nn/tracker/components.py | 25 +++++++++--- sleap/nn/tracking.py | 22 +++++++++- tests/nn/test_inference.py | 2 +- 5 files changed, 114 insertions(+), 9 deletions(-) diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 641e2c4ab..fb2d799e0 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -253,6 +253,7 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", + "tracking.pre_cull_merge_instances", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", "tracking.oks_score_weighting", diff --git a/sleap/instance.py b/sleap/instance.py index 08a5c6ae6..90c7af78e 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -18,12 +18,14 @@ """ import math +from functools import reduce +from itertools import chain, combinations import numpy as np import cattr from copy import copy -from typing import Dict, List, Optional, Union, Tuple, ForwardRef +from typing import Dict, List, Optional, Union, Sequence, Tuple from numpy.lib.recfunctions import structured_to_unstructured @@ -1177,6 +1179,75 @@ def from_numpy( ) +def all_disjoint(x: Sequence[Sequence]) -> bool: + return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2)) + + +def create_merged_instances( + instances: List[PredictedInstance], + penalty: float = 0.2, +) -> List[PredictedInstance]: + """Create merged instances from the list of PredictedInstance. + + Only instances with non-overlapping visible nodes are merged. + + Args: + instances: a list of original PredictedInstances to try to merge. + penalty: a float between 0 and 1. All scores of the merged instance + are multplied by (1 - penalty). + + Returns: + a list of PredictedInstance that were merged. + """ + # Ensure same skeleton + skeletons = {inst.skeleton for inst in instances} + if len(skeletons) != 1: + return [] + skeleton = list(skeletons)[0] + + # Ensure same track + tracks = {inst.track for inst in instances} + if len(tracks) != 1: + return [] + track = list(tracks)[0] + + # Ensure non-intersecting visible nodes + merged_instances = [] + instance_subsets = ( + combinations(instances, n) for n in range(2, len(instances) + 1) + ) + instance_subsets = chain.from_iterable(instance_subsets) + for subset in instance_subsets: + nodes = [s.nodes for s in subset] + if not all_disjoint(nodes): + continue + + nodes_points_gen = chain.from_iterable( + instance.nodes_points for instance in subset + ) + predicted_points = {node: point for node, point in nodes_points_gen} + + instance_score = reduce(lambda x, y: x * y, [s.score for s in subset]) + + # Penalize scores of merged instances + if 0 < penalty <= 1: + factor = 1 - penalty + instance_score *= factor + for point in predicted_points.values(): + point.score *= factor + + merged_instance = PredictedInstance( + points=predicted_points, + skeleton=skeleton, + score=instance_score, + track=track, + ) + + merged_instances.append(merged_instance) + + return merged_instances + + def make_instance_cattr() -> cattr.Converter: """Create a cattr converter for Lists of Instances/PredictedInstances. diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index b2f35b21f..492c01d9f 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,6 +12,7 @@ """ + import operator from collections import defaultdict import logging @@ -23,6 +24,7 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +from sleap.instance import create_merged_instances logger = logging.getLogger(__name__) @@ -249,7 +251,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]: # keep looping while some indexes still remain in the indexes list while len(idxs) > 0: - # we want to add the best box which is the last box in sorted list picked_box_idx = idxs[-1] @@ -351,6 +352,8 @@ def cull_frame_instances( instances_list: List[InstanceType], instance_count: int, iou_threshold: Optional[float] = None, + merge_instances: bool = False, + merging_penalty: float = 0.2, ) -> List["LabeledFrame"]: """ Removes instances (for single frame) over instance per frame threshold. @@ -361,6 +364,9 @@ def cull_frame_instances( iou_threshold: Intersection over Union (IOU) threshold to use when removing overlapping instances over target count; if None, then only use score to determine which instances to remove. + merge_instances: If True, allow merging instances with no overlapping + merging_penalty: a float between 0 and 1. All scores of the merged + instance are multplied by (1 - penalty). Returns: Updated list of frames, also modifies frames in place. @@ -368,6 +374,14 @@ def cull_frame_instances( if not instances_list: return + # Merge instances + if merge_instances: + logger.info("Merging instances with penalty: %f", merging_penalty) + merged_instances = create_merged_instances( + instances_list, penalty=merging_penalty + ) + instances_list.extend(merged_instances) + if len(instances_list) > instance_count: # List of instances which we'll pare down keep_instances = instances_list @@ -387,9 +401,10 @@ def cull_frame_instances( if len(keep_instances) > instance_count: # Sort by ascending score, get target number of instances # from the end of list (i.e., with highest score) - extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[ - :-instance_count - ] + extra_instances = sorted( + keep_instances, + key=operator.attrgetter("score"), + )[:-instance_count] # Remove the extra instances for inst in extra_instances: @@ -523,7 +538,6 @@ def from_candidate_instances( candidate_tracks = [] if candidate_instances: - # Group candidate instances by track. candidate_instances_by_track = defaultdict(list) for instance in candidate_instances: @@ -536,7 +550,6 @@ def from_candidate_instances( matching_similarities = np.full(dims, np.nan) for i, untracked_instance in enumerate(untracked_instances): - for j, candidate_track in enumerate(candidate_tracks): # Compute similarity between untracked instance and all track # candidates. diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 770b337c6..b32f9b8a0 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -947,6 +947,8 @@ def make_tracker_by_name( target_instance_count: int = 0, pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, + pre_cull_merge_instances: bool = False, + pre_cull_merging_penalty: float = 0.2, # Post-tracking options to connect broken tracks post_connect_single_breaks: bool = False, # TODO: deprecate these post-tracking cleaning options @@ -1011,13 +1013,15 @@ def make_tracker_by_name( ) pre_cull_function = None - if target_instance_count and pre_cull_to_target: + if (target_instance_count and pre_cull_to_target) or pre_cull_merge_instances: def pre_cull_function(inst_list): cull_frame_instances( inst_list, instance_count=target_instance_count, iou_threshold=pre_cull_iou_threshold, + merge_instances=pre_cull_merge_instances, + merging_penalty=pre_cull_merging_penalty, ) tracker_obj = cls( @@ -1094,6 +1098,22 @@ def get_by_name_factory_options(cls): ) options.append(option) + option = dict(name="pre_cull_merge_instances", default=False) + option["type"] = bool + option["help"] = ( + "If True, allow merging instances with non-overlapping visible nodes " + "to create new instances *before* tracking." + ) + options.append(option) + + option = dict(name="pre_cull_merging_penalty", default=0.2) + option["type"] = float + option["help"] = ( + "A float between 0 and 1. All scores of the merged instances " + "are multplied by (1 - penalty)." + ) + options.append(option) + option = dict(name="post_connect_single_breaks", default=0) option["type"] = int option["help"] = ( diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 9f9332338..a3852ec20 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1976,7 +1976,7 @@ def test_max_tracks_matching_queue( if trackername == "flow": # Check that saved instances are pruned to track window - for key in tracker.candidate_maker.shifted_instances.keys(): + for key in tracker.candidate_maker.shifted_instances: assert lf.frame_idx - key[0] <= track_window # Keys are pruned assert abs(key[0] - key[1]) <= track_window