From a7b3a0a9b52bc1e97518afae5bd73bf5bcff3e50 Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 2 Aug 2024 15:06:15 +0100 Subject: [PATCH] simplify max_tracks tracker --- docs/guides/cli.md | 12 +- sleap/config/pipeline_form.yaml | 4 - sleap/gui/learning/runners.py | 23 +- sleap/nn/inference.py | 6 - sleap/nn/tracking.py | 504 +++++++++++--------------- tests/nn/test_inference.py | 22 +- tests/nn/test_tracker_components.py | 38 +- tests/nn/test_tracking_integration.py | 53 +-- 8 files changed, 261 insertions(+), 401 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 03b806903..c59048ecd 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -36,7 +36,7 @@ optional arguments: ```none usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS] - [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] + [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] [--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] [--suffix SUFFIX] training_job_path [labels_path] @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [- [--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT] [--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO] [--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD] - [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING] + [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT] [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD] [--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS] @@ -187,10 +187,8 @@ optional arguments: Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. --tracking.tracker TRACKING.TRACKER Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None) - --tracking.max_tracking TRACKING.MAX_TRACKING - If true then the tracker will cap the max number of tracks. (default: False) --tracking.max_tracks TRACKING.MAX_TRACKS - Maximum number of tracks to be tracked by the tracker. (default: None) + Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None) --tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT Target number of instances to track per frame. (default: 0) --tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET @@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction **5. Inference with max tracks limit:** ```none -sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" +sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" ``` **6. Re-tracking without pose inference:** ```none -sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" +sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" ``` **7. Select GPU for pose inference:** diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c34faea55..406a02ea8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -521,10 +521,6 @@ inference: text: 'Tracking:
This tracker assigns track identities by matching instances from prior frames to instances on subsequent frames.' - # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - # default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index d0bb1f3ba..641e2c4ab 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -204,7 +204,6 @@ def make_predict_cli_call( # Make path where we'll save predictions (if not specified) if output_path is None: - if self.labels_filename: # Make a predictions directory next to the labels dataset file predictions_dir = os.path.join( @@ -239,15 +238,12 @@ def make_predict_cli_call( if key in self.inference_params and self.inference_params[key] is None: del self.inference_params[key] - # Setting max_tracks to True means we want to use the max_tracking mode. - if "tracking.max_tracks" in self.inference_params: - self.inference_params["tracking.max_tracking"] = True - - # Hacky: Update the tracker name to include "maxtracks" suffix. - if self.inference_params["tracking.tracker"] in ("simple", "flow"): - self.inference_params["tracking.tracker"] = ( - self.inference_params["tracking.tracker"] + "maxtracks" - ) + # Compatibility with using the "maxtracks" suffix to the tracker name. + if "tracking.tracker" in self.inference_params: + compat_trackers = ("simplemaxtracks", "flowmaxtracks") + if self.inference_params["tracking.tracker"] in compat_trackers: + tname = self.inference_params["tracking.tracker"][: -len("maxtracks")] + self.inference_params["tracking.tracker"] = tname # --tracking.kf_init_frame_count enables the kalman filter tracking # so if not set, then remove other (unused) args @@ -257,10 +253,11 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", - "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", "tracking.oks_score_weighting", + "tracking.prefer_reassigning_track", + "tracking.allow_reassigning_track", ) for key in bool_items_as_ints: @@ -303,10 +300,8 @@ def predict_subprocess( # Run inference CLI capturing output. with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc: - # Poll until finished. while proc.poll() is None: - # Read line. line = proc.stdout.readline() line = line.decode().rstrip() @@ -635,7 +630,6 @@ def run_gui_training( for config_info in config_info_list: if config_info.dont_retrain: - if not config_info.has_trained_model: raise ValueError( "Config is set to not retrain but no trained model found: " @@ -849,7 +843,6 @@ def train_subprocess( success = False with tempfile.TemporaryDirectory() as temp_dir: - # Write a temporary file of the TrainingJob so that we can respect # any changed made to the job attributes after it was loaded. temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json" diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index a1a083ba5..e5cb1dc49 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4947,16 +4947,10 @@ def unpack_sleap_model(model_path): ) predictor.verbosity = progress_reporting if tracker is not None: - use_max_tracker = tracker_max_instances is not None - if use_max_tracker and not tracker.endswith("maxtracks"): - # Append maxtracks to the tracker name to use the right tracker variants. - tracker += "maxtracks" - predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, track_window=tracker_window, post_connect_single_breaks=True, - max_tracking=use_max_tracker, max_tracks=tracker_max_instances, # clean_instance_count=tracker_max_instances, ) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index a99aab8b5..770b337c6 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -2,6 +2,7 @@ import abc import json +import operator import sys from collections import deque from time import time @@ -114,7 +115,23 @@ class MatchedShiftedFrameInstances: @attr.s(auto_attribs=True) -class FlowCandidateMaker: +class CandidateMaker(abc.ABC): + """Abstract base class for candidate maker.""" + + @abc.abstractmethod + def get_candidates( + self, + track_matching_queue: Deque[MatchedFrameInstances], + t: int, + img: np.ndarray, + *args, + **kwargs, + ) -> List[ShiftedInstance]: + pass + + +@attr.s(auto_attribs=True) +class FlowCandidateMaker(CandidateMaker): """Class for producing optical flow shift matching candidates. Attributes: @@ -209,37 +226,6 @@ def get_shifted_instances( return shifted_instances - def get_candidates( - self, - track_matching_queue: Deque[MatchedFrameInstances], - t: int, - img: np.ndarray, - ) -> List[ShiftedInstance]: - candidate_instances = [] - - # Prune old shifted instances to save time and memory - self.prune_shifted_instances(t) - - for matched_item in track_matching_queue: - ref_t, ref_img, ref_instances = ( - matched_item.t, - matched_item.img_t, - matched_item.instances_t, - ) - - # Check if shifted instance was computed at earlier time - if self.save_shifted_instances: - ref_img, ref_instances = self.get_shifted_instances_from_earlier_time( - ref_t, ref_img, ref_instances, t - ) - - if len(ref_instances) > 0: - candidate_instances.extend( - self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t) - ) - - return candidate_instances - def prune_shifted_instances(self, t: int): """Prune the shifted instances older than `self.track_window`. @@ -363,45 +349,9 @@ def flow_shift_instances( return shifted_instances - -@attr.s(auto_attribs=True) -class FlowMaxTracksCandidateMaker(FlowCandidateMaker): - """Class for producing optical flow shift matching candidates with maximum tracks. - - Attributes: - max_tracks: The maximum number of tracks to avoid redundant tracks. - - """ - - max_tracks: int = None - - @staticmethod - def get_ref_instances( - ref_t: int, - ref_img: np.ndarray, - track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], - ) -> List[InstanceType]: - """Generates a list of instances based on the reference time and image. - - Args: - ref_t: Previous frame time instance. - ref_img: Previous frame image as a numpy array. - track_matching_queue_dict: A dictionary of mapping between the tracks - and the corresponding instances associated with the track. - """ - instances = [] - for track, matched_items in track_matching_queue_dict.items(): - instances += [ - item.instance_t - for item in matched_items - if item.t == ref_t and np.all(item.img_t == ref_img) - ] - return instances - def get_candidates( self, - track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], - max_tracking: bool, + track_matching_queue: Deque[MatchedFrameInstances], t: int, img: np.ndarray, *args, @@ -411,42 +361,33 @@ def get_candidates( # Prune old shifted instances to save time and memory self.prune_shifted_instances(t) - # Storing the tracks from the dictionary for counting purpose. - tracks = [] - - for track, matched_items in track_matching_queue_dict.items(): - if not max_tracking or len(tracks) < self.max_tracks: - tracks.append(track) - for matched_item in matched_items: - ref_t, ref_img = ( - matched_item.t, - matched_item.img_t, - ) - ref_instances = self.get_ref_instances( - ref_t, ref_img, track_matching_queue_dict - ) - # Check if shifted instance was computed at earlier time - if self.save_shifted_instances: - ( - ref_img, - ref_instances, - ) = self.get_shifted_instances_from_earlier_time( - ref_t, ref_img, ref_instances, t - ) - - if len(ref_instances) > 0: - candidate_instances.extend( - self.get_shifted_instances( - ref_instances, ref_img, ref_t, img, t - ) - ) + for matched_item in track_matching_queue: + ref_t, ref_img, ref_instances = ( + matched_item.t, + matched_item.img_t, + matched_item.instances_t, + ) + + # Check if shifted instance was computed at earlier time + if self.save_shifted_instances: + ( + ref_img, + ref_instances, + ) = self.get_shifted_instances_from_earlier_time( + ref_t, ref_img, ref_instances, t + ) + + if len(ref_instances) > 0: + candidate_instances.extend( + self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t) + ) return candidate_instances @attr.s(auto_attribs=True) -class SimpleCandidateMaker: +class SimpleCandidateMaker(CandidateMaker): """Class for producing list of matching candidates from prior frames.""" min_points: int = 0 @@ -456,48 +397,25 @@ def uses_image(self): return False def get_candidates( - self, track_matching_queue: Deque[MatchedFrameInstances], *args, **kwargs + self, + track_matching_queue: Deque[MatchedFrameInstances], + *args, + **kwargs, ) -> List[InstanceType]: - # Build a pool of matchable candidate instances. + # Create set of matchable candidate instances from each track. candidate_instances = [] for matched_item in track_matching_queue: ref_t, ref_instances = matched_item.t, matched_item.instances_t for ref_instance in ref_instances: if ref_instance.n_visible_points >= self.min_points: candidate_instances.append(ref_instance) - return candidate_instances - - -@attr.s(auto_attribs=True) -class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): - """Class to generate instances with maximum number of tracks from prior frames.""" - - max_tracks: int = None - def get_candidates( - self, - track_matching_queue_dict: Dict, - max_tracking: bool, - *args, - **kwargs, - ) -> List[InstanceType]: - # Create set of matchable candidate instances from each track. - candidate_instances = [] - tracks = [] - for track, matched_instances in track_matching_queue_dict.items(): - if not max_tracking or len(tracks) < self.max_tracks: - tracks.append(track) - for ref_instance in matched_instances: - if ref_instance.instance_t.n_visible_points >= self.min_points: - candidate_instances.append(ref_instance.instance_t) return candidate_instances tracker_policies = dict( simple=SimpleCandidateMaker, flow=FlowCandidateMaker, - simplemaxtracks=SimpleMaxTracksCandidateMaker, - flowmaxtracks=FlowMaxTracksCandidateMaker, ) similarity_policies = dict( @@ -708,7 +626,7 @@ class Tracker(BaseTracker): use a robust quantile similarity score for the track. If the value is 1, use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. - max_tracking: Max tracking is incorporated when this is set to true. + max_tracks: Maximum number of tracks. No limit if set to -1. verbosity: Mode of inference progress reporting. If `"rich"` (the default), an updating progress bar is displayed in the console or notebook. If `"json"`, a JSON-serialized message is printed out which can be captured @@ -717,12 +635,11 @@ class Tracker(BaseTracker): machines where the output is captured to a log file. """ - max_tracks: int = None track_window: int = 5 similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching - candidate_maker: object = attr.ib(factory=FlowCandidateMaker) - max_tracking: bool = False # To enable maximum tracking. + candidate_maker: CandidateMaker = attr.ib(factory=FlowCandidateMaker) + max_tracks: int = -1 cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 @@ -731,15 +648,23 @@ class Tracker(BaseTracker): robust_best_instance: float = 1.0 min_new_track_points: int = 0 + prefer_reassigning_track: bool = False + allow_reassigning_track: bool = False + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + + #: Hold frames with matched instances as deque of length `track_window`. track_matching_queue: Deque[MatchedFrameInstances] = attr.ib() - # Hold track, instances with instances as a deque with length as track_window. - track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib( - factory=dict - ) spawned_tracks: List[Track] = attr.ib(factory=list) + #: Found tracks with last time an instance was found + found_tracks: Dict[Track, InstanceType] = attr.ib(factory=dict) + save_tracked_instances: bool = False tracked_instances: Dict[int, List[InstanceType]] = attr.ib( factory=dict @@ -758,44 +683,39 @@ def is_valid(self): return self.similarity_function is not None @track_matching_queue.default - def _init_matching_queue(self): + def _init_track_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) - @property - def has_max_tracking(self) -> bool: - return isinstance( - self.candidate_maker, - (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), - ) - def reset_candidates(self): - if self.has_max_tracking: - for track in self.track_matching_queue_dict: - self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) - else: - self.track_matching_queue = deque(maxlen=self.track_window) + self.track_matching_queue = deque(maxlen=self.track_window) @property def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" - - unique_tracks = set() - if self.has_max_tracking: - for track in self.track_matching_queue_dict.keys(): - unique_tracks.add(track) - - else: - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - return list(unique_tracks) + return { + instance.track + for item in self.track_matching_queue + for instance in item.instances_t + } @property def uses_image(self): return getattr(self.candidate_maker, "uses_image", False) + def infer_next_timestep(self, t: Optional[int] = None) -> int: + """Infer timestep if not provided.""" + # Timestep was provided + if t is not None: + return t + + # Default to last timestep + 1 if available. + if len(self.track_matching_queue) > 0: + return self.track_matching_queue[-1].t + 1 + + # Default to 0 + return 0 + def track( self, untracked_instances: List[InstanceType], @@ -817,55 +737,22 @@ def track( return untracked_instances # Infer timestep if not provided. - if t is None: - if self.has_max_tracking: - if len(self.track_matching_queue_dict) > 0: - # Default to last timestep + 1 if available. - # Here we find the track that has the most instances. - track_with_max_instances = max( - self.track_matching_queue_dict, - key=lambda track: len(self.track_matching_queue_dict[track]), - ) - t = ( - self.track_matching_queue_dict[track_with_max_instances][-1].t - + 1 - ) - - else: - t = 0 - else: - if len(self.track_matching_queue) > 0: - # Default to last timestep + 1 if available. - t = self.track_matching_queue[-1].t + 1 - - else: - t = 0 + t = self.infer_next_timestep(t) # Initialize containers for tracked instances at the current timestep. tracked_instances = [] - # Make cache so similarity function doesn't have to recompute everything. - # similarity_cache = dict() - # Process untracked instances. if untracked_instances: if self.pre_cull_function: self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. - if self.has_max_tracking: - candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue_dict=self.track_matching_queue_dict, - max_tracking=self.max_tracking, - t=t, - img=img, - ) - else: - candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue=self.track_matching_queue, - t=t, - img=img, - ) + candidate_instances = self.candidate_maker.get_candidates( + track_matching_queue=self.track_matching_queue, + t=t, + img=img, + ) # Determine matches for untracked instances in current frame. frame_matches = FrameMatches.from_candidate_instances( @@ -881,37 +768,18 @@ def track( # Set track for each of the matched instances. tracked_instances.extend( - self.update_matched_instance_tracks(frame_matches.matches) + self.update_matched_instance_tracks(frame_matches.matches, t) ) - # Spawn a new track for each remaining untracked instance. + # Assign unmatched instances to new tracks or already existing tracks tracked_instances.extend( self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t) ) - # Add the tracked instances to the dictionary of matched instances. - if self.has_max_tracking: - for tracked_instance in tracked_instances: - if tracked_instance.track in self.track_matching_queue_dict: - self.track_matching_queue_dict[tracked_instance.track].append( - MatchedFrameInstance(t, tracked_instance, img) - ) - elif ( - not self.max_tracking - or len(self.track_matching_queue_dict) < self.max_tracks - ): - self.track_matching_queue_dict[tracked_instance.track] = deque( - maxlen=self.track_window - ) - self.track_matching_queue_dict[tracked_instance.track].append( - MatchedFrameInstance(t, tracked_instance, img) - ) - - else: - # Add the tracked instances to the matching buffer. - self.track_matching_queue.append( - MatchedFrameInstances(t, tracked_instances, img) - ) + # Add the tracked instances to the matching buffer. + self.track_matching_queue.append( + MatchedFrameInstances(t, tracked_instances, img) + ) # Save tracked instances internally. if self.save_tracked_instances: @@ -919,8 +787,11 @@ def track( return tracked_instances - @staticmethod - def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]: + def update_matched_instance_tracks( + self, + matches: List[Match], + t: int, + ) -> List[InstanceType]: inst_list = [] for match in matches: # Assign to track and save. @@ -931,31 +802,108 @@ def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]: tracking_score=match.score, ) ) + # Keep the last instance for this track + self.found_tracks[match.track] = t return inst_list + def spawn_new_track(self, inst: InstanceType, t: int) -> Optional[InstanceType]: + """Try spawning a new track for instance.""" + if self.max_tracks >= 0 and len(self.found_tracks) >= self.max_tracks: + return None + + # Spawn new track. + new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}") + self.spawned_tracks.append(new_track) + + # After setting track, keep the last instance for this track + self.found_tracks[new_track] = t + + # Assign instance to the new track and save. + return attr.evolve(inst, track=new_track) + + def assign_to_track( + self, + inst: InstanceType, + t: int, + not_assigned_tracks: List[Track], + ) -> Optional[InstanceType]: + """Try assigning instance to a track from a candidate list (best last). + + `not_assigned_tracks` will be modified. + """ + if len(not_assigned_tracks) == 0: + return None + + existing_track = not_assigned_tracks.pop() + + # After setting track, keep the last instance for this track + self.found_tracks[existing_track] = t + + # Assign instance to the existing track and save. + return attr.evolve(inst, track=existing_track, tracking_score=0) + def spawn_for_untracked_instances( - self, unmatched_instances: List[InstanceType], t: int + self, + unmatched_instances: List[InstanceType], + t: int, ) -> List[InstanceType]: + """Assign an existing track or spawn a new track for untracked instances.""" + # Early return + if len(unmatched_instances) == 0: + return [] + + # Use the tracks that have not been assigned an instance since track_window + not_assigned_tracks_dict = { + track: last_t + for track, last_t in self.found_tracks.items() + if last_t < t - self.track_window + } + # Sort tracks by last used last, so we can pop the last used track + not_assigned_tracks = sorted( + not_assigned_tracks_dict, + key=self.found_tracks.get, + ) + + # Tracks left to assign (all if max_tracks is negative + n_remaining_tracks = ( + max(0, self.max_tracks - len(not_assigned_tracks)) + if self.max_tracks >= 0 + else len(unmatched_instances) + ) + + # Sort instances by descending instance-level grouping score + sorted_instances = sorted( + unmatched_instances, + key=operator.attrgetter("score"), + reverse=True, + )[:n_remaining_tracks] + results = [] - for inst in unmatched_instances: - # Skip if this instance is too small to spawn a new track with. + for inst in sorted_instances: + # Skip if not enough visible nodes to assign to a track or spawn a new track if inst.n_visible_points < self.min_new_track_points: continue - # Skip if we've reached the maximum number of tracks. - if ( - self.has_max_tracking - and self.max_tracking - and len(self.track_matching_queue_dict) >= self.max_tracks - ): - break - - # Spawn new track. - new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}") - self.spawned_tracks.append(new_track) - - # Assign instance to the new track and save. - results.append(attr.evolve(inst, track=new_track)) + # Try spawning new track (if this is preferred, before reassigning track) + if not self.prefer_reassigning_track: + matched_inst = self.spawn_new_track(inst, t) + if matched_inst is not None: + results.append(matched_inst) + continue + + # Try assigning to an existing track + if self.allow_reassigning_track: + matched_inst = self.assign_to_track(inst, t, not_assigned_tracks) + if matched_inst is not None: + results.append(matched_inst) + continue + + # Try spawning new track (if this is not preferred, after reassigning track) + if self.prefer_reassigning_track: + matched_inst = self.spawn_new_track(inst, t) + if matched_inst is not None: + results.append(matched_inst) + continue return results @@ -1009,7 +957,8 @@ def make_tracker_by_name( kf_node_indices: Optional[list] = None, # Max tracking options max_tracks: Optional[int] = None, - max_tracking: bool = False, + prefer_reassigning_track: bool = False, + allow_reassigning_track: bool = False, # Object keypoint similarity options oks_errors: Optional[list] = None, oks_score_weighting: bool = False, @@ -1018,11 +967,8 @@ def make_tracker_by_name( report_rate: float = 2.0, **kwargs, ) -> BaseTracker: - # Parse max_tracking arguments, only True if max_tracks is not None and > 0 - max_tracking = max_tracking if max_tracks else False - if max_tracking and tracker in ("simple", "flow"): - # Force a candidate maker of 'maxtracks' type - tracker += "maxtracks" + # Parse max_tracks, set to -1 if None + max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1 if tracker.lower() == "none": candidate_maker = None @@ -1058,9 +1004,6 @@ def make_tracker_by_name( candidate_maker.save_shifted_instances = save_shifted_instances candidate_maker.track_window = track_window - if tracker == "simplemaxtracks" or tracker == "flowmaxtracks": - candidate_maker.max_tracks = max_tracks - cleaner = None if clean_instance_count: cleaner = TrackCleaner( @@ -1081,14 +1024,15 @@ def pre_cull_function(inst_list): track_window=track_window, robust_best_instance=robust, min_new_track_points=min_new_track_points, + max_tracks=max_tracks, similarity_function=similarity_function, matching_function=matching_function, candidate_maker=candidate_maker, cleaner=cleaner, pre_cull_function=pre_cull_function, - max_tracking=max_tracking, - max_tracks=max_tracks, target_instance_count=target_instance_count, + allow_reassigning_track=allow_reassigning_track, + prefer_reassigning_track=prefer_reassigning_track, post_connect_single_breaks=post_connect_single_breaks, verbosity=progress_reporting, report_rate=report_rate, @@ -1120,17 +1064,12 @@ def get_by_name_factory_options(cls): ] options.append(option) - option = dict(name="max_tracking", default=False) - option["type"] = bool - option["help"] = ( - "If true then the tracker will cap the max number of tracks. " - "Falls back to false if `max_tracks` is not defined or 0." - ) - options.append(option) - option = dict(name="max_tracks", default=None) option["type"] = int - option["help"] = "Maximum number of tracks to be tracked by the tracker." + option["help"] = ( + "Maximum number of tracks to be tracked by the tracker. " + "No maximum if set to -1." + ) options.append(option) option = dict(name="target_instance_count", default=0) @@ -1202,6 +1141,21 @@ def get_by_name_factory_options(cls): option["help"] = "Minimum number of instance points for spawning new track" options.append(option) + option = dict(name="allow_reassigning_track", default=False) + option["type"] = bool + option[ + "help" + ] = "Allow assigning existing but unused track to unmatched instances." + options.append(option) + + option = dict(name="prefer_reassigning_track", default=False) + option["type"] = bool + option["help"] = ( + "Try first to reassign to an existing track before trying to " + "spawn a new track with unmatched instances." + ) + options.append(option) + option = dict(name="min_match_points", default=0) option["type"] = int option["help"] = "Minimum points for match candidates" @@ -1328,19 +1282,6 @@ class FlowTracker(Tracker): candidate_maker: object = attr.ib(factory=FlowCandidateMaker) -attr.s(auto_attribs=True) - - -class FlowMaxTracker(Tracker): - """Pre-configured tracker to use optical flow shifted candidates with max tracks.""" - - max_tracks: int = attr.ib(kw_only=True) - similarity_function: Callable = instance_similarity - matching_function: Callable = greedy_matching - candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker) - max_tracking: bool = True - - @attr.s(auto_attribs=True) class SimpleTracker(Tracker): """A Tracker pre-configured to use simple, non-image-based candidates.""" @@ -1350,17 +1291,6 @@ class SimpleTracker(Tracker): candidate_maker: object = attr.ib(factory=SimpleCandidateMaker) -@attr.s(auto_attribs=True) -class SimpleMaxTracker(Tracker): - """Pre-configured tracker to use simple, non-image-based candidates with max tracks.""" - - max_tracks: int = attr.ib(kw_only=True) - similarity_function: Callable = instance_iou - matching_function: Callable = hungarian_matching - candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker) - max_tracking: bool = True - - @attr.s(auto_attribs=True) class KalmanInitSet: init_frame_count: int diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fd615ea81..9f9332338 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -58,7 +58,6 @@ from sleap.nn.tracking import ( MatchedFrameInstance, FlowCandidateMaker, - FlowMaxTracksCandidateMaker, Tracker, ) from sleap.instance import Track @@ -1273,6 +1272,7 @@ def test_make_export_cli(): assert args.max_instances == max_instances +@pytest.mark.slow() def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path ): @@ -1315,6 +1315,7 @@ def test_topdown_predictor_save( ) +@pytest.mark.slow() def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path ): @@ -1361,9 +1362,7 @@ def test_topdown_id_predictor_save( "output_path,tracker_method", [ ("not_default", "flow"), - ("not_default", "flowmaxtracks"), (None, "simple"), - (None, "simplemaxtracks"), ], ) def test_retracking( @@ -1380,7 +1379,6 @@ def test_retracking( if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks": - cmd += " --tracking.max_tracking 1" cmd += " --tracking.max_tracks 2" if output_path == "not_default": output_path = Path(tmpdir, "tracked_slp.slp") @@ -1944,8 +1942,8 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): @pytest.mark.parametrize( "max_tracks, trackername", [ - (2, "flowmaxtracks"), - (2, "simplemaxtracks"), + (2, "flow"), + (2, "simple"), ], ) def test_max_tracks_matching_queue( @@ -1953,7 +1951,6 @@ def test_max_tracks_matching_queue( ): """Test flow max tracks instance generation.""" labels: Labels = centered_pair_predictions - max_tracking = True track_window = 5 # Setup flow max tracker @@ -1961,11 +1958,10 @@ def test_max_tracks_matching_queue( tracker=trackername, track_window=track_window, save_shifted_instances=True, - max_tracking=max_tracking, max_tracks=max_tracks, ) - tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker) + tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker) # Run tracking frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) @@ -1978,18 +1974,17 @@ def test_max_tracks_matching_queue( track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) tracker.track(**track_args) - if trackername == "flowmaxtracks": + if trackername == "flow": # Check that saved instances are pruned to track window for key in tracker.candidate_maker.shifted_instances.keys(): assert lf.frame_idx - key[0] <= track_window # Keys are pruned assert abs(key[0] - key[1]) <= track_window # Check if the length of each of the tracks is not more than the track window - for track in tracker.track_matching_queue_dict.keys(): - assert len(tracker.track_matching_queue_dict[track]) <= track_window + assert len(tracker.track_matching_queue) <= track_window # Check if number of tracks that are generated are not more than the maximum tracks - assert len(tracker.track_matching_queue_dict) <= max_tracks + assert len(tracker.unique_tracks_in_queue) <= max_tracks def test_movenet_inference(movenet_video): @@ -2012,6 +2007,7 @@ def test_movenet_inference(movenet_video): assert preds["instance_peaks"].shape == (1, 1, 17, 2) +@pytest.mark.slow() def test_movenet_predictor(min_dance_labels, movenet_video): predictor = MoveNetPredictor.from_trained_models("thunder") predictor.verbosity = "none" diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index ffdd35257..be5789dc9 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -32,9 +32,7 @@ def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs): assert len(new_frames) == len(frames) -@pytest.mark.parametrize( - "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] -) +@pytest.mark.parametrize("tracker", ["simple", "flow"]) @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("img_scale", [0, 1, 0.25]) @@ -60,9 +58,7 @@ def test_tracker_by_name( ) -@pytest.mark.parametrize( - "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] -) +@pytest.mark.parametrize("tracker", ["simple", "flow"]) @pytest.mark.parametrize("oks_score_weighting", ["True", "False"]) @pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"]) def test_oks_tracker_by_name( @@ -245,7 +241,7 @@ def make_inst(x, y): return insts -def test_max_tracking_large_gap_single_track(): +def test_max_tracks_large_gap_single_track(): # Track 2 instances with gap > window size preds = make_insts( [ @@ -280,11 +276,9 @@ def test_max_tracking_large_gap_single_track(): tracker = Tracker.make_tracker_by_name( tracker="simple", - # tracker="simplemaxtracks", match="hungarian", track_window=2, - # max_tracks=2, - # max_tracking=True, + max_tracks=-1, ) tracked = [] @@ -296,12 +290,10 @@ def test_max_tracking_large_gap_single_track(): assert len(all_tracks) == 3 tracker = Tracker.make_tracker_by_name( - # tracker="simple", - tracker="simplemaxtracks", + tracker="simple", match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] @@ -313,7 +305,7 @@ def test_max_tracking_large_gap_single_track(): assert len(all_tracks) == 2 -def test_max_tracking_small_gap_on_both_tracks(): +def test_max_tracks_small_gap_on_both_tracks(): # Test 2 instances with both tracks with gap > window size preds = make_insts( [ @@ -344,11 +336,9 @@ def test_max_tracking_small_gap_on_both_tracks(): tracker = Tracker.make_tracker_by_name( tracker="simple", - # tracker="simplemaxtracks", match="hungarian", track_window=2, - # max_tracks=2, - # max_tracking=True, + max_tracks=-1, ) tracked = [] @@ -360,12 +350,10 @@ def test_max_tracking_small_gap_on_both_tracks(): assert len(all_tracks) == 4 tracker = Tracker.make_tracker_by_name( - # tracker="simple", - tracker="simplemaxtracks", + tracker="simple", match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] @@ -377,7 +365,7 @@ def test_max_tracking_small_gap_on_both_tracks(): assert len(all_tracks) == 2 -def test_max_tracking_extra_detections(): +def test_max_tracks_extra_detections(): # Test having more than 2 detected instances in a frame preds = make_insts( [ @@ -413,11 +401,9 @@ def test_max_tracking_extra_detections(): tracker = Tracker.make_tracker_by_name( tracker="simple", - # tracker="simplemaxtracks", match="hungarian", track_window=2, - # max_tracks=2, - # max_tracking=True, + max_tracks=-1, ) tracked = [] @@ -429,12 +415,10 @@ def test_max_tracking_extra_detections(): assert len(all_tracks) == 4 tracker = Tracker.make_tracker_by_name( - # tracker="simple", - tracker="simplemaxtracks", + tracker="simple", match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 84707fd44..ab8d52b87 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -22,10 +22,10 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): assert len(labels.tracks) == 8 -def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): +def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path): cli = ( - "--tracking.tracker simplemaxtracks " - "--tracking.max_tracking 1 --tracking.max_tracks 2 " + "--tracking.tracker simple " + "--tracking.max_tracks 2 " "--frames 200-300 " f"-o {tmpdir}/simplemaxtracks.slp " f"{centered_pair_predictions_slp_path}" @@ -91,8 +91,6 @@ def main(f, dir): trackers = dict( simple=sleap.nn.tracker.simple.SimpleTracker, flow=sleap.nn.tracker.flow.FlowTracker, - simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker, - flowmaxtracks=sleap.nn.tracker.FlowMaxTracker, ) matchers = dict( hungarian=sleap.nn.tracker.components.hungarian_matching, @@ -108,21 +106,12 @@ def main(f, dir): 0.25, ) - def make_tracker( - tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 - ): - if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks": - tracker = trackers[tracker_name]( - matching_function=matchers[matcher_name], - similarity_function=similarities[sim_name], - max_tracks=max_tracks, - max_tracking=max_tracking, - ) - else: - tracker = trackers[tracker_name]( - matching_function=matchers[matcher_name], - similarity_function=similarities[sim_name], - ) + def make_tracker(tracker_name, matcher_name, sim_name, max_tracks, scale=0): + tracker = trackers[tracker_name]( + matching_function=matchers[matcher_name], + similarity_function=similarities[sim_name], + max_tracks=max_tracks, + ) if scale: tracker.candidate_maker.img_scale = scale return tracker @@ -151,36 +140,16 @@ def make_tracker_and_filename(*args, **kwargs): tracker, gt_filename = make_tracker_and_filename( tracker_name=tracker_name, matcher_name=matcher_name, - sim_name=sim_name, - scale=scale, - ) - f(frames, tracker, gt_filename) - elif tracker_name == "flowmaxtracks": - # If this tracker supports scale, try multiple scales - for scale in scales: - tracker, gt_filename = make_tracker_and_filename( - tracker_name=tracker_name, - matcher_name=matcher_name, - sim_name=sim_name, max_tracks=2, - max_tracking=True, + sim_name=sim_name, scale=scale, ) f(frames, tracker, gt_filename) - elif tracker_name == "simplemaxtracks": - tracker, gt_filename = make_tracker_and_filename( - tracker_name=tracker_name, - matcher_name=matcher_name, - sim_name=sim_name, - max_tracks=2, - max_tracking=True, - scale=0, - ) - f(frames, tracker, gt_filename) else: tracker, gt_filename = make_tracker_and_filename( tracker_name=tracker_name, matcher_name=matcher_name, + max_tracks=2, sim_name=sim_name, scale=0, )