diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 03b806903..c59048ecd 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -36,7 +36,7 @@ optional arguments: ```none usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS] - [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] + [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] [--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] [--suffix SUFFIX] training_job_path [labels_path] @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [- [--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT] [--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO] [--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD] - [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING] + [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT] [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD] [--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS] @@ -187,10 +187,8 @@ optional arguments: Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. --tracking.tracker TRACKING.TRACKER Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None) - --tracking.max_tracking TRACKING.MAX_TRACKING - If true then the tracker will cap the max number of tracks. (default: False) --tracking.max_tracks TRACKING.MAX_TRACKS - Maximum number of tracks to be tracked by the tracker. (default: None) + Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None) --tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT Target number of instances to track per frame. (default: 0) --tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET @@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction **5. Inference with max tracks limit:** ```none -sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" +sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" ``` **6. Re-tracking without pose inference:** ```none -sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" +sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" ``` **7. Select GPU for pose inference:** diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index d130b9cb9..406a02ea8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -424,10 +424,6 @@ inference: This tracker "shifts" instances from previous frames using optical flow before matching instances in each frame to the shifted instances from prior frames.' 
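Editor's note: the CLI examples above drop the old `--tracking.max_tracking` flag and the `simplemaxtracks`/`flowmaxtracks` tracker variants in favor of a single `--tracking.max_tracks` option. A minimal sketch of the equivalent call through the Python API, assuming the updated `Tracker.make_tracker_by_name` signature introduced later in this diff:

```python
# Sketch only: build a capped tracker with the new API (no max_tracking flag,
# no "simplemaxtracks"/"flowmaxtracks" tracker names).
from sleap.nn.tracking import Tracker

tracker = Tracker.make_tracker_by_name(
    tracker="simple",   # or "flow"
    track_window=5,
    max_tracks=4,       # cap on the number of tracks; None or -1 means no limit
)
```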
- # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int @@ -459,10 +455,12 @@ inference: none_label: Use max (non-robust) range: 0,1 default: 0.95 - # - name: tracking.save_shifted_instances - # label: Save shifted instances - # type: bool - # default: false + - name: tracking.save_shifted_instances + label: Save shifted instances + help: 'Save the flow-shifted instances between elapsed frames. It improves + instance matching at the cost of using a bit more of memory.' + type: bool + default: true - type: text text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial @@ -523,10 +521,6 @@ inference: text: 'Tracking:
This tracker assigns track identities by matching instances from prior frames to instances on subsequent frames.' - # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - # default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 2c2617036..7a9c94358 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -733,7 +733,7 @@ def run(self): # count < 0 means there was an error and we didn't get any results. if new_counts is not None and new_counts >= 0: total_count = items_for_inference.total_frame_count - no_result_count = total_count - new_counts + no_result_count = max(0, total_count - new_counts) message = ( f"Inference ran on {total_count} frames." diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index d0bb1f3ba..fb2d799e0 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -204,7 +204,6 @@ def make_predict_cli_call( # Make path where we'll save predictions (if not specified) if output_path is None: - if self.labels_filename: # Make a predictions directory next to the labels dataset file predictions_dir = os.path.join( @@ -239,15 +238,12 @@ def make_predict_cli_call( if key in self.inference_params and self.inference_params[key] is None: del self.inference_params[key] - # Setting max_tracks to True means we want to use the max_tracking mode. - if "tracking.max_tracks" in self.inference_params: - self.inference_params["tracking.max_tracking"] = True - - # Hacky: Update the tracker name to include "maxtracks" suffix. - if self.inference_params["tracking.tracker"] in ("simple", "flow"): - self.inference_params["tracking.tracker"] = ( - self.inference_params["tracking.tracker"] + "maxtracks" - ) + # Compatibility with using the "maxtracks" suffix to the tracker name. + if "tracking.tracker" in self.inference_params: + compat_trackers = ("simplemaxtracks", "flowmaxtracks") + if self.inference_params["tracking.tracker"] in compat_trackers: + tname = self.inference_params["tracking.tracker"][: -len("maxtracks")] + self.inference_params["tracking.tracker"] = tname # --tracking.kf_init_frame_count enables the kalman filter tracking # so if not set, then remove other (unused) args @@ -257,10 +253,12 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", - "tracking.max_tracking", + "tracking.pre_cull_merge_instances", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", "tracking.oks_score_weighting", + "tracking.prefer_reassigning_track", + "tracking.allow_reassigning_track", ) for key in bool_items_as_ints: @@ -303,10 +301,8 @@ def predict_subprocess( # Run inference CLI capturing output. with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc: - # Poll until finished. while proc.poll() is None: - # Read line. line = proc.stdout.readline() line = line.decode().rstrip() @@ -635,7 +631,6 @@ def run_gui_training( for config_info in config_info_list: if config_info.dont_retrain: - if not config_info.has_trained_model: raise ValueError( "Config is set to not retrain but no trained model found: " @@ -849,7 +844,6 @@ def train_subprocess( success = False with tempfile.TemporaryDirectory() as temp_dir: - # Write a temporary file of the TrainingJob so that we can respect # any changed made to the job attributes after it was loaded. 
temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json" diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 949703020..4c2370e09 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -816,6 +816,8 @@ def __init__(self, state=None, player=None, *args, **kwargs): self.click_mode = "" self.in_zoom = False + self._down_pos = None + self.zoomFactor = 1 anchor_mode = QGraphicsView.AnchorUnderMouse self.setTransformationAnchor(anchor_mode) @@ -1039,7 +1041,7 @@ def mouseReleaseEvent(self, event): scenePos = self.mapToScene(event.pos()) # check if mouse moved during click - has_moved = event.pos() != self._down_pos + has_moved = self._down_pos is not None and event.pos() != self._down_pos if event.button() == Qt.LeftButton: if self.in_zoom: diff --git a/sleap/instance.py b/sleap/instance.py index 08a5c6ae6..90c7af78e 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -18,12 +18,14 @@ """ import math +from functools import reduce +from itertools import chain, combinations import numpy as np import cattr from copy import copy -from typing import Dict, List, Optional, Union, Tuple, ForwardRef +from typing import Dict, List, Optional, Union, Sequence, Tuple from numpy.lib.recfunctions import structured_to_unstructured @@ -1177,6 +1179,75 @@ def from_numpy( ) +def all_disjoint(x: Sequence[Sequence]) -> bool: + return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2)) + + +def create_merged_instances( + instances: List[PredictedInstance], + penalty: float = 0.2, +) -> List[PredictedInstance]: + """Create merged instances from the list of PredictedInstance. + + Only instances with non-overlapping visible nodes are merged. + + Args: + instances: a list of original PredictedInstances to try to merge. + penalty: a float between 0 and 1. All scores of the merged instance + are multplied by (1 - penalty). + + Returns: + a list of PredictedInstance that were merged. + """ + # Ensure same skeleton + skeletons = {inst.skeleton for inst in instances} + if len(skeletons) != 1: + return [] + skeleton = list(skeletons)[0] + + # Ensure same track + tracks = {inst.track for inst in instances} + if len(tracks) != 1: + return [] + track = list(tracks)[0] + + # Ensure non-intersecting visible nodes + merged_instances = [] + instance_subsets = ( + combinations(instances, n) for n in range(2, len(instances) + 1) + ) + instance_subsets = chain.from_iterable(instance_subsets) + for subset in instance_subsets: + nodes = [s.nodes for s in subset] + if not all_disjoint(nodes): + continue + + nodes_points_gen = chain.from_iterable( + instance.nodes_points for instance in subset + ) + predicted_points = {node: point for node, point in nodes_points_gen} + + instance_score = reduce(lambda x, y: x * y, [s.score for s in subset]) + + # Penalize scores of merged instances + if 0 < penalty <= 1: + factor = 1 - penalty + instance_score *= factor + for point in predicted_points.values(): + point.score *= factor + + merged_instance = PredictedInstance( + points=predicted_points, + skeleton=skeleton, + score=instance_score, + track=track, + ) + + merged_instances.append(merged_instance) + + return merged_instances + + def make_instance_cattr() -> cattr.Converter: """Create a cattr converter for Lists of Instances/PredictedInstances. 
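Editor's note: the new `create_merged_instances` helper above combines complementary partial predictions from one frame. A minimal usage sketch, assuming `instances` is a list of `PredictedInstance` objects for a single frame:

```python
from sleap.instance import create_merged_instances

# `instances`: PredictedInstance objects from one frame that share a skeleton
# and a track; only subsets with disjoint sets of visible nodes get merged.
merged = create_merged_instances(instances, penalty=0.2)

# Each merged instance keeps the union of the points; its instance score and
# every point score are multiplied by (1 - penalty), i.e. 0.8 here.
instances.extend(merged)
```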
diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 421378d56..e5cb1dc49 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -46,6 +46,11 @@ from threading import Thread from queue import Queue +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property + import tensorflow as tf import numpy as np @@ -54,7 +59,7 @@ from sleap.nn.config import TrainingJobConfig, DataConfig from sleap.nn.data.resizing import SizeMatcher from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker +from sleap.nn.tracking import Tracker from sleap.nn.paf_grouping import PAFScorer from sleap.nn.data.pipelines import ( Provider, @@ -69,7 +74,7 @@ ) from sleap.nn.utils import reset_input_layer from sleap.io.dataset import Labels -from sleap.util import frame_list, make_scoped_dictionary +from sleap.util import frame_list, make_scoped_dictionary, RateColumn from sleap.instance import PredictedInstance, LabeledFrame from tensorflow.python.framework.convert_to_constants import ( @@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str: return os.path.join(path, "best_model.h5") -class RateColumn(rich.progress.ProgressColumn): - """Renders the progress rate.""" - - def render(self, task: "Task") -> rich.progress.Text: - """Show progress rate.""" - speed = task.speed - if speed is None: - return rich.progress.Text("?", style="progress.data.speed") - return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") - - @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" @@ -167,7 +161,7 @@ class Predictor(ABC): report_rate: float = attr.ib(default=2.0, kw_only=True) model_paths: List[str] = attr.ib(factory=list, kw_only=True) - @property + @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" return 1.0 / self.report_rate @@ -374,6 +368,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: def _initialize_inference_model(self): pass + def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Run prediction model on batch. + + This method handles running inference on a batch and postprocessing. + + Args: + ex: a dictionary holding the input for inference. + + Returns: + The input dictionary updated with the predictions. + """ + # Skip inference if model is not loaded + if self.inference_model is None: + return ex + + # Run inference on current batch. + preds = self.inference_model.predict_on_batch(ex, numpy=True) + + # Add model outputs to the input data example. + ex.update(preds) + + # Convert to numpy arrays if not already. + if isinstance(ex["video_ind"], tf.Tensor): + ex["video_ind"] = ex["video_ind"].numpy().flatten() + if isinstance(ex["frame_ind"], tf.Tensor): + ex["frame_ind"] = ex["frame_ind"].numpy().flatten() + + # Adjust for potential SizeMatcher scaling. 
+ offset_x = ex.get("offset_x", 0) + offset_y = ex.get("offset_y", 0) + ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) + ex["instance_peaks"] /= np.expand_dims( + np.expand_dims(ex["scale"], axis=1), axis=1 + ) + + return ex + + def _run_batch_json( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + max_length: int = 30, + ) -> Iterator[Dict[str, np.ndarray]]: + n_processed = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in examples: + # Process batch of examples. + ex = self._process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + if time() > last_report + self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + # Return results. + yield ex + + def _run_batch_rich( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + ) -> Iterator[Dict[str, np.ndarray]]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=n_total) + last_report = time() + for ex in examples: + ex = self._process_batch(ex) + + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + # Return results. + yield ex + def _predict_generator( self, data_provider: Provider ) -> Iterator[Dict[str, np.ndarray]]: @@ -395,103 +505,22 @@ def _predict_generator( if self.inference_model is None: self._initialize_inference_model() - def process_batch(ex): - # Run inference on current batch. - preds = self.inference_model.predict_on_batch(ex, numpy=True) - - # Add model outputs to the input data example. - ex.update(preds) - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - # Adjust for potential SizeMatcher scaling. - offset_x = ex.get("offset_x", 0) - offset_y = ex.get("offset_y", 0) - ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) - ex["instance_peaks"] /= np.expand_dims( - np.expand_dims(ex["scale"], axis=1), axis=1 - ) - - return ex + # Compile loop examples before starting time to improve ETA + n_total = len(data_provider) + examples = self.pipeline.make_dataset() # Loop over data batches with optional progress reporting. 
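Editor's note: `_run_batch_json` above prints one JSON object per progress report, with the keys `n_processed`, `n_total`, `elapsed`, `rate`, and `eta`. A sketch of consuming that stream from a `sleap-track` run with `--verbosity json`; the model and video paths are placeholders:

```python
import json
import subprocess

proc = subprocess.Popen(
    ["sleap-track", "-m", "models/my_model", "--verbosity", "json", "input_video.mp4"],
    stdout=subprocess.PIPE,
    text=True,
)
for line in proc.stdout:
    try:
        msg = json.loads(line)
    except json.JSONDecodeError:
        continue  # skip any non-progress output
    print(f"{msg['n_processed']}/{msg['n_total']} frames "
          f"({msg['rate']:.1f} FPS, ETA {msg['eta']:.0f} s)")
proc.wait()
```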
if self.verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Predicting...", total=len(data_provider)) - last_report = time() - for ex in self.pipeline.make_dataset(): - ex = process_batch(ex) - progress.update(task, advance=len(ex["frame_ind"])) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() - - # Return results. - yield ex + for ex in self._run_batch_rich(examples, n_total=n_total): + yield ex elif self.verbosity == "json": - n_processed = 0 - n_total = len(data_provider) - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for ex in self.pipeline.make_dataset(): - # Process batch of examples. - ex = process_batch(ex) - - # Track timing and progress. - elapsed_batch = time() - t0_batch - t0_batch = time() - n_batch = len(ex["frame_ind"]) - n_processed += n_batch - elapsed_all = time() - t0_all - - # Compute recent rate. - n_recent.append(n_batch) - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - # Report. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() - - # Return results. + for ex in self._run_batch_json(examples, n_total=n_total): yield ex + else: - for ex in self.pipeline.make_dataset(): - yield process_batch(ex) + for ex in examples: + yield self._process_batch(ex) def predict( self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True @@ -4918,16 +4947,10 @@ def unpack_sleap_model(model_path): ) predictor.verbosity = progress_reporting if tracker is not None: - use_max_tracker = tracker_max_instances is not None - if use_max_tracker and not tracker.endswith("maxtracks"): - # Append maxtracks to the tracker name to use the right tracker variants. 
- tracker += "maxtracks" - predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, track_window=tracker_window, post_connect_single_breaks=True, - max_tracking=use_max_tracker, max_tracks=tracker_max_instances, # clean_instance_count=tracker_max_instances, ) @@ -5388,7 +5411,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: ) ) else: - provider_list.append(LabelsReader(labels)) + provider_list.append( + LabelsReader(labels, example_indices=frame_list(args.frames)) + ) data_path_list.append(file_path) @@ -5482,7 +5507,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: """ policy_args = make_scoped_dictionary(vars(args), exclude_nones=True) if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + tracker = Tracker.make_tracker_by_name( + progress_reporting=args.verbosity, + **policy_args["tracking"], + ) return tracker return None @@ -5649,14 +5677,16 @@ def main(args: Optional[list] = None): data_path = data_path_list[0] # Load predictions - data_path = args.data_path print("Loading predictions...") - labels_pr = sleap.load_file(data_path) + labels_pr = sleap.load_file(data_path.as_posix()) frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + if provider.example_indices is not None: + # Convert indices to a set to search in O(1), otherwise it is much slower + index_set = set(provider.example_indices) + frames = list(filter(lambda lf: lf.frame_idx in index_set, frames)) print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) labels_pr = Labels(labeled_frames=frames) @@ -5677,7 +5707,7 @@ def main(args: Optional[list] = None): labels_pr.provenance["sleap_version"] = sleap.__version__ labels_pr.provenance["platform"] = platform.platform() labels_pr.provenance["command"] = " ".join(sys.argv) - labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["data_path"] = os.fspath(data_path) labels_pr.provenance["output_path"] = output_path labels_pr.provenance["total_elapsed"] = total_elapsed labels_pr.provenance["start_timestamp"] = start_timestamp diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index b2f35b21f..492c01d9f 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,6 +12,7 @@ """ + import operator from collections import defaultdict import logging @@ -23,6 +24,7 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +from sleap.instance import create_merged_instances logger = logging.getLogger(__name__) @@ -249,7 +251,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]: # keep looping while some indexes still remain in the indexes list while len(idxs) > 0: - # we want to add the best box which is the last box in sorted list picked_box_idx = idxs[-1] @@ -351,6 +352,8 @@ def cull_frame_instances( instances_list: List[InstanceType], instance_count: int, iou_threshold: Optional[float] = None, + merge_instances: bool = False, + merging_penalty: float = 0.2, ) -> List["LabeledFrame"]: """ Removes instances (for single frame) over instance per frame threshold. @@ -361,6 +364,9 @@ def cull_frame_instances( iou_threshold: Intersection over Union (IOU) threshold to use when removing overlapping instances over target count; if None, then only use score to determine which instances to remove. 
+ merge_instances: If True, allow merging instances with no overlapping + merging_penalty: a float between 0 and 1. All scores of the merged + instance are multplied by (1 - penalty). Returns: Updated list of frames, also modifies frames in place. @@ -368,6 +374,14 @@ def cull_frame_instances( if not instances_list: return + # Merge instances + if merge_instances: + logger.info("Merging instances with penalty: %f", merging_penalty) + merged_instances = create_merged_instances( + instances_list, penalty=merging_penalty + ) + instances_list.extend(merged_instances) + if len(instances_list) > instance_count: # List of instances which we'll pare down keep_instances = instances_list @@ -387,9 +401,10 @@ def cull_frame_instances( if len(keep_instances) > instance_count: # Sort by ascending score, get target number of instances # from the end of list (i.e., with highest score) - extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[ - :-instance_count - ] + extra_instances = sorted( + keep_instances, + key=operator.attrgetter("score"), + )[:-instance_count] # Remove the extra instances for inst in extra_instances: @@ -523,7 +538,6 @@ def from_candidate_instances( candidate_tracks = [] if candidate_instances: - # Group candidate instances by track. candidate_instances_by_track = defaultdict(list) for instance in candidate_instances: @@ -536,7 +550,6 @@ def from_candidate_instances( matching_similarities = np.full(dims, np.nan) for i, untracked_instance in enumerate(untracked_instances): - for j, candidate_track in enumerate(candidate_tracks): # Compute similarity between untracked instance and all track # candidates. diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 2b02839de..b32f9b8a0 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,11 +1,17 @@ """Tracking tools for linking grouped instances over time.""" -from collections import deque, defaultdict import abc +import json +import operator +import sys +from collections import deque +from time import time +from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple + import attr -import numpy as np import cv2 -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +import numpy as np +import rich.progress from sleap import Track, LabeledFrame, Skeleton @@ -24,8 +30,13 @@ Match, ) from sleap.nn.tracker.kalman import BareKalmanTracker - from sleap.nn.data.normalization import ensure_int +from sleap.util import RateColumn + +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property @attr.s(eq=False, slots=True, auto_attribs=True) @@ -64,7 +75,6 @@ def from_instance( shift_score: float = 0.0, with_skeleton: bool = False, ): - points_array = new_points_array if points_array is None: points_array = ref_instance.points_array @@ -105,7 +115,23 @@ class MatchedShiftedFrameInstances: @attr.s(auto_attribs=True) -class FlowCandidateMaker: +class CandidateMaker(abc.ABC): + """Abstract base class for candidate maker.""" + + @abc.abstractmethod + def get_candidates( + self, + track_matching_queue: Deque[MatchedFrameInstances], + t: int, + img: np.ndarray, + *args, + **kwargs, + ) -> List[ShiftedInstance]: + pass + + +@attr.s(auto_attribs=True) +class FlowCandidateMaker(CandidateMaker): """Class for producing optical flow shift matching candidates. 
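Editor's note: the `merge_instances`/`merging_penalty` arguments added to `cull_frame_instances` above are also exposed through the `pre_cull_merge_instances` and `pre_cull_merging_penalty` tracker options added later in this diff. A sketch of calling the culling step directly, assuming `instances` is a list of `PredictedInstance` objects for one frame:

```python
from sleap.nn.tracker.components import cull_frame_instances

cull_frame_instances(
    instances,
    instance_count=2,       # keep at most two instances for this frame
    iou_threshold=0.8,      # remove overlapping duplicates before score-based culling
    merge_instances=True,   # try merging complementary partial instances first
    merging_penalty=0.2,    # merged scores are scaled by (1 - 0.2)
)
# The list is modified in place: merged instances are added first, then the
# lowest-scoring instances are removed until `instance_count` remain.
```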
Attributes: @@ -129,7 +155,7 @@ class FlowCandidateMaker: img_scale: float = 1.0 of_window_size: int = 21 of_max_levels: int = 3 - save_shifted_instances: bool = False + save_shifted_instances: bool = True track_window: int = 5 shifted_instances: Dict[ @@ -200,37 +226,6 @@ def get_shifted_instances( return shifted_instances - def get_candidates( - self, - track_matching_queue: Deque[MatchedFrameInstances], - t: int, - img: np.ndarray, - ) -> List[ShiftedInstance]: - candidate_instances = [] - - # Prune old shifted instances to save time and memory - self.prune_shifted_instances(t) - - for matched_item in track_matching_queue: - ref_t, ref_img, ref_instances = ( - matched_item.t, - matched_item.img_t, - matched_item.instances_t, - ) - - # Check if shifted instance was computed at earlier time - if self.save_shifted_instances: - ref_img, ref_instances = self.get_shifted_instances_from_earlier_time( - ref_t, ref_img, ref_instances, t - ) - - if len(ref_instances) > 0: - candidate_instances.extend( - self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t) - ) - - return candidate_instances - def prune_shifted_instances(self, t: int): """Prune the shifted instances older than `self.track_window`. @@ -354,45 +349,9 @@ def flow_shift_instances( return shifted_instances - -@attr.s(auto_attribs=True) -class FlowMaxTracksCandidateMaker(FlowCandidateMaker): - """Class for producing optical flow shift matching candidates with maximum tracks. - - Attributes: - max_tracks: The maximum number of tracks to avoid redundant tracks. - - """ - - max_tracks: int = None - - @staticmethod - def get_ref_instances( - ref_t: int, - ref_img: np.ndarray, - track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], - ) -> List[InstanceType]: - """Generates a list of instances based on the reference time and image. - - Args: - ref_t: Previous frame time instance. - ref_img: Previous frame image as a numpy array. - track_matching_queue_dict: A dictionary of mapping between the tracks - and the corresponding instances associated with the track. - """ - instances = [] - for track, matched_items in track_matching_queue_dict.items(): - instances += [ - item.instance_t - for item in matched_items - if item.t == ref_t and np.all(item.img_t == ref_img) - ] - return instances - def get_candidates( self, - track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], - max_tracking: bool, + track_matching_queue: Deque[MatchedFrameInstances], t: int, img: np.ndarray, *args, @@ -402,42 +361,33 @@ def get_candidates( # Prune old shifted instances to save time and memory self.prune_shifted_instances(t) - # Storing the tracks from the dictionary for counting purpose. 
- tracks = [] - - for track, matched_items in track_matching_queue_dict.items(): - if not max_tracking or len(tracks) < self.max_tracks: - tracks.append(track) - for matched_item in matched_items: - ref_t, ref_img = ( - matched_item.t, - matched_item.img_t, - ) - ref_instances = self.get_ref_instances( - ref_t, ref_img, track_matching_queue_dict - ) - # Check if shifted instance was computed at earlier time - if self.save_shifted_instances: - ( - ref_img, - ref_instances, - ) = self.get_shifted_instances_from_earlier_time( - ref_t, ref_img, ref_instances, t - ) - - if len(ref_instances) > 0: - candidate_instances.extend( - self.get_shifted_instances( - ref_instances, ref_img, ref_t, img, t - ) - ) + for matched_item in track_matching_queue: + ref_t, ref_img, ref_instances = ( + matched_item.t, + matched_item.img_t, + matched_item.instances_t, + ) + + # Check if shifted instance was computed at earlier time + if self.save_shifted_instances: + ( + ref_img, + ref_instances, + ) = self.get_shifted_instances_from_earlier_time( + ref_t, ref_img, ref_instances, t + ) + + if len(ref_instances) > 0: + candidate_instances.extend( + self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t) + ) return candidate_instances @attr.s(auto_attribs=True) -class SimpleCandidateMaker: +class SimpleCandidateMaker(CandidateMaker): """Class for producing list of matching candidates from prior frames.""" min_points: int = 0 @@ -447,48 +397,25 @@ def uses_image(self): return False def get_candidates( - self, track_matching_queue: Deque[MatchedFrameInstances], *args, **kwargs + self, + track_matching_queue: Deque[MatchedFrameInstances], + *args, + **kwargs, ) -> List[InstanceType]: - # Build a pool of matchable candidate instances. + # Create set of matchable candidate instances from each track. candidate_instances = [] for matched_item in track_matching_queue: ref_t, ref_instances = matched_item.t, matched_item.instances_t for ref_instance in ref_instances: if ref_instance.n_visible_points >= self.min_points: candidate_instances.append(ref_instance) - return candidate_instances - -@attr.s(auto_attribs=True) -class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): - """Class to generate instances with maximum number of tracks from prior frames.""" - - max_tracks: int = None - - def get_candidates( - self, - track_matching_queue_dict: Dict, - max_tracking: bool, - *args, - **kwargs, - ) -> List[InstanceType]: - # Create set of matchable candidate instances from each track. 
- candidate_instances = [] - tracks = [] - for track, matched_instances in track_matching_queue_dict.items(): - if not max_tracking or len(tracks) < self.max_tracks: - tracks.append(track) - for ref_instance in matched_instances: - if ref_instance.instance_t.n_visible_points >= self.min_points: - candidate_instances.append(ref_instance.instance_t) return candidate_instances tracker_policies = dict( simple=SimpleCandidateMaker, flow=FlowCandidateMaker, - simplemaxtracks=SimpleMaxTracksCandidateMaker, - flowmaxtracks=FlowMaxTracksCandidateMaker, ) similarity_policies = dict( @@ -508,10 +435,149 @@ def get_candidates( class BaseTracker(abc.ABC): """Abstract base class for tracker.""" + verbosity: str + report_rate: float + @property def is_valid(self): return False + @cached_property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + return 1.0 / self.report_rate + + def run_step(self, lf: LabeledFrame) -> LabeledFrame: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, t=lf.frame_idx) + if self.uses_image: + track_args["img"] = lf.video[lf.frame_idx] + else: + track_args["img"] = None + + return LabeledFrame( + frame_idx=lf.frame_idx, + video=lf.video, + instances=self.track(**track_args), + ) + + def _run_tracker_json( + self, + frames: List[LabeledFrame], + max_length: int = 30, + ) -> Iterator[LabeledFrame]: + n_total = len(frames) + n_processed = 0 + n_batch = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + + for lf in frames: + new_lf = self.run_step(lf) + + # Track timing and progress + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report + if time() > last_report + self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + yield new_lf + + def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + yield new_lf + + def run_tracker( + self, + frames: List[LabeledFrame], + *, + verbosity: Optional[str] = None, + final_pass: bool = True, + ) -> List[LabeledFrame]: + """Run the tracker on a set of labeled frames. + + Args: + frames: A list of labeled frames with instances. + + Returns: + The input frames with the new tracks assigned. If the frames already had tracks, + they will be cleared if the tracker has been re-initialized. 
+ """ + # Return original frames if we aren't retracking + if not self.is_valid: + return frames + + verbosity = verbosity or self.verbosity + + # Run tracking on every frame + if verbosity == "rich": + new_lfs = list(self._run_tracker_rich(frames)) + + elif verbosity == "json": + new_lfs = list(self._run_tracker_json(frames)) + + else: + new_lfs = list(self.run_step(lf) for lf in frames) + + # Run final_pass + if final_pass: + self.final_pass(new_lfs) + + return new_lfs + @abc.abstractmethod def track( self, @@ -560,15 +626,20 @@ class Tracker(BaseTracker): use a robust quantile similarity score for the track. If the value is 1, use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. - max_tracking: Max tracking is incorporated when this is set to true. + max_tracks: Maximum number of tracks. No limit if set to -1. + verbosity: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during tracking -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. """ - max_tracks: int = None track_window: int = 5 similarity_function: Optional[Callable] = instance_similarity matching_function: Callable = greedy_matching - candidate_maker: object = attr.ib(factory=FlowCandidateMaker) - max_tracking: bool = False # To enable maximum tracking. + candidate_maker: CandidateMaker = attr.ib(factory=FlowCandidateMaker) + max_tracks: int = -1 cleaner: Optional[Callable] = None # TODO: deprecate target_instance_count: int = 0 @@ -577,15 +648,23 @@ class Tracker(BaseTracker): robust_best_instance: float = 1.0 min_new_track_points: int = 0 + prefer_reassigning_track: bool = False + allow_reassigning_track: bool = False + + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + #: Hold frames with matched instances as deque of length `track_window`. track_matching_queue: Deque[MatchedFrameInstances] = attr.ib() - # Hold track, instances with instances as a deque with length as track_window. 
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib( - factory=dict - ) spawned_tracks: List[Track] = attr.ib(factory=list) + #: Found tracks with last time an instance was found + found_tracks: Dict[Track, InstanceType] = attr.ib(factory=dict) + save_tracked_instances: bool = False tracked_instances: Dict[int, List[InstanceType]] = attr.ib( factory=dict @@ -593,49 +672,50 @@ class Tracker(BaseTracker): last_matches: Optional[FrameMatches] = None + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): return self.similarity_function is not None @track_matching_queue.default - def _init_matching_queue(self): + def _init_track_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) - @property - def has_max_tracking(self) -> bool: - return isinstance( - self.candidate_maker, - (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), - ) - def reset_candidates(self): - if self.has_max_tracking: - for track in self.track_matching_queue_dict: - self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) - else: - self.track_matching_queue = deque(maxlen=self.track_window) + self.track_matching_queue = deque(maxlen=self.track_window) @property def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" - - unique_tracks = set() - if self.has_max_tracking: - for track in self.track_matching_queue_dict.keys(): - unique_tracks.add(track) - - else: - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - return list(unique_tracks) + return { + instance.track + for item in self.track_matching_queue + for instance in item.instances_t + } @property def uses_image(self): return getattr(self.candidate_maker, "uses_image", False) + def infer_next_timestep(self, t: Optional[int] = None) -> int: + """Infer timestep if not provided.""" + # Timestep was provided + if t is not None: + return t + + # Default to last timestep + 1 if available. + if len(self.track_matching_queue) > 0: + return self.track_matching_queue[-1].t + 1 + + # Default to 0 + return 0 + def track( self, untracked_instances: List[InstanceType], @@ -657,58 +737,22 @@ def track( return untracked_instances # Infer timestep if not provided. - if t is None: - if self.has_max_tracking: - if len(self.track_matching_queue_dict) > 0: - - # Default to last timestep + 1 if available. - # Here we find the track that has the most instances. - track_with_max_instances = max( - self.track_matching_queue_dict, - key=lambda track: len(self.track_matching_queue_dict[track]), - ) - t = ( - self.track_matching_queue_dict[track_with_max_instances][-1].t - + 1 - ) - - else: - t = 0 - else: - if len(self.track_matching_queue) > 0: - - # Default to last timestep + 1 if available. - t = self.track_matching_queue[-1].t + 1 - - else: - t = 0 + t = self.infer_next_timestep(t) # Initialize containers for tracked instances at the current timestep. tracked_instances = [] - # Make cache so similarity function doesn't have to recompute everything. - # similarity_cache = dict() - # Process untracked instances. if untracked_instances: - if self.pre_cull_function: self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. 
- if self.has_max_tracking: - candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue_dict=self.track_matching_queue_dict, - max_tracking=self.max_tracking, - t=t, - img=img, - ) - else: - candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue=self.track_matching_queue, - t=t, - img=img, - ) + candidate_instances = self.candidate_maker.get_candidates( + track_matching_queue=self.track_matching_queue, + t=t, + img=img, + ) # Determine matches for untracked instances in current frame. frame_matches = FrameMatches.from_candidate_instances( @@ -724,37 +768,18 @@ def track( # Set track for each of the matched instances. tracked_instances.extend( - self.update_matched_instance_tracks(frame_matches.matches) + self.update_matched_instance_tracks(frame_matches.matches, t) ) - # Spawn a new track for each remaining untracked instance. + # Assign unmatched instances to new tracks or already existing tracks tracked_instances.extend( self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t) ) - # Add the tracked instances to the dictionary of matched instances. - if self.has_max_tracking: - for tracked_instance in tracked_instances: - if tracked_instance.track in self.track_matching_queue_dict: - self.track_matching_queue_dict[tracked_instance.track].append( - MatchedFrameInstance(t, tracked_instance, img) - ) - elif ( - not self.max_tracking - or len(self.track_matching_queue_dict) < self.max_tracks - ): - self.track_matching_queue_dict[tracked_instance.track] = deque( - maxlen=self.track_window - ) - self.track_matching_queue_dict[tracked_instance.track].append( - MatchedFrameInstance(t, tracked_instance, img) - ) - - else: - # Add the tracked instances to the matching buffer. - self.track_matching_queue.append( - MatchedFrameInstances(t, tracked_instances, img) - ) + # Add the tracked instances to the matching buffer. + self.track_matching_queue.append( + MatchedFrameInstances(t, tracked_instances, img) + ) # Save tracked instances internally. if self.save_tracked_instances: @@ -762,8 +787,11 @@ def track( return tracked_instances - @staticmethod - def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]: + def update_matched_instance_tracks( + self, + matches: List[Match], + t: int, + ) -> List[InstanceType]: inst_list = [] for match in matches: # Assign to track and save. @@ -774,32 +802,108 @@ def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]: tracking_score=match.score, ) ) + # Keep the last instance for this track + self.found_tracks[match.track] = t return inst_list + def spawn_new_track(self, inst: InstanceType, t: int) -> Optional[InstanceType]: + """Try spawning a new track for instance.""" + if self.max_tracks >= 0 and len(self.found_tracks) >= self.max_tracks: + return None + + # Spawn new track. + new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}") + self.spawned_tracks.append(new_track) + + # After setting track, keep the last instance for this track + self.found_tracks[new_track] = t + + # Assign instance to the new track and save. + return attr.evolve(inst, track=new_track) + + def assign_to_track( + self, + inst: InstanceType, + t: int, + not_assigned_tracks: List[Track], + ) -> Optional[InstanceType]: + """Try assigning instance to a track from a candidate list (best last). + + `not_assigned_tracks` will be modified. 
+ """ + if len(not_assigned_tracks) == 0: + return None + + existing_track = not_assigned_tracks.pop() + + # After setting track, keep the last instance for this track + self.found_tracks[existing_track] = t + + # Assign instance to the existing track and save. + return attr.evolve(inst, track=existing_track, tracking_score=0) + def spawn_for_untracked_instances( - self, unmatched_instances: List[InstanceType], t: int + self, + unmatched_instances: List[InstanceType], + t: int, ) -> List[InstanceType]: - results = [] - for inst in unmatched_instances: + """Assign an existing track or spawn a new track for untracked instances.""" + # Early return + if len(unmatched_instances) == 0: + return [] + + # Use the tracks that have not been assigned an instance since track_window + not_assigned_tracks_dict = { + track: last_t + for track, last_t in self.found_tracks.items() + if last_t < t - self.track_window + } + # Sort tracks by last used last, so we can pop the last used track + not_assigned_tracks = sorted( + not_assigned_tracks_dict, + key=self.found_tracks.get, + ) - # Skip if this instance is too small to spawn a new track with. - if inst.n_visible_points < self.min_new_track_points: - continue + # Tracks left to assign (all if max_tracks is negative + n_remaining_tracks = ( + max(0, self.max_tracks - len(not_assigned_tracks)) + if self.max_tracks >= 0 + else len(unmatched_instances) + ) - # Skip if we've reached the maximum number of tracks. - if ( - self.has_max_tracking - and self.max_tracking - and len(self.track_matching_queue_dict) >= self.max_tracks - ): - break + # Sort instances by descending instance-level grouping score + sorted_instances = sorted( + unmatched_instances, + key=operator.attrgetter("score"), + reverse=True, + )[:n_remaining_tracks] - # Spawn new track. - new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}") - self.spawned_tracks.append(new_track) + results = [] + for inst in sorted_instances: + # Skip if not enough visible nodes to assign to a track or spawn a new track + if inst.n_visible_points < self.min_new_track_points: + continue - # Assign instance to the new track and save. 
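Editor's note: the unmatched-instance handling introduced here uses two new options: `allow_reassigning_track` lets an unmatched instance claim a track that has gone unused for longer than `track_window`, and `prefer_reassigning_track` controls whether that is attempted before spawning a new track. A sketch of enabling them when building a tracker (the values shown are illustrative, not recommendations):

```python
from sleap.nn.tracking import Tracker

tracker = Tracker.make_tracker_by_name(
    tracker="simple",
    track_window=5,
    max_tracks=2,
    allow_reassigning_track=True,    # reuse stale tracks for unmatched instances
    prefer_reassigning_track=False,  # try spawning a new track first
)
```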
- results.append(attr.evolve(inst, track=new_track)) + # Try spawning new track (if this is preferred, before reassigning track) + if not self.prefer_reassigning_track: + matched_inst = self.spawn_new_track(inst, t) + if matched_inst is not None: + results.append(matched_inst) + continue + + # Try assigning to an existing track + if self.allow_reassigning_track: + matched_inst = self.assign_to_track(inst, t, not_assigned_tracks) + if matched_inst is not None: + results.append(matched_inst) + continue + + # Try spawning new track (if this is not preferred, after reassigning track) + if self.prefer_reassigning_track: + matched_inst = self.spawn_new_track(inst, t) + if matched_inst is not None: + results.append(matched_inst) + continue return results @@ -843,6 +947,8 @@ def make_tracker_by_name( target_instance_count: int = 0, pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, + pre_cull_merge_instances: bool = False, + pre_cull_merging_penalty: float = 0.2, # Post-tracking options to connect broken tracks post_connect_single_breaks: bool = False, # TODO: deprecate these post-tracking cleaning options @@ -853,18 +959,18 @@ def make_tracker_by_name( kf_node_indices: Optional[list] = None, # Max tracking options max_tracks: Optional[int] = None, - max_tracking: bool = False, + prefer_reassigning_track: bool = False, + allow_reassigning_track: bool = False, # Object keypoint similarity options oks_errors: Optional[list] = None, oks_score_weighting: bool = False, oks_normalization: str = "all", + progress_reporting: str = "rich", + report_rate: float = 2.0, **kwargs, ) -> BaseTracker: - # Parse max_tracking arguments, only True if max_tracks is not None and > 0 - max_tracking = max_tracking if max_tracks else False - if max_tracking and tracker in ("simple", "flow"): - # Force a candidate maker of 'maxtracks' type - tracker += "maxtracks" + # Parse max_tracks, set to -1 if None + max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1 if tracker.lower() == "none": candidate_maker = None @@ -900,9 +1006,6 @@ def make_tracker_by_name( candidate_maker.save_shifted_instances = save_shifted_instances candidate_maker.track_window = track_window - if tracker == "simplemaxtracks" or tracker == "flowmaxtracks": - candidate_maker.max_tracks = max_tracks - cleaner = None if clean_instance_count: cleaner = TrackCleaner( @@ -910,28 +1013,33 @@ def make_tracker_by_name( ) pre_cull_function = None - if target_instance_count and pre_cull_to_target: + if (target_instance_count and pre_cull_to_target) or pre_cull_merge_instances: def pre_cull_function(inst_list): cull_frame_instances( inst_list, instance_count=target_instance_count, iou_threshold=pre_cull_iou_threshold, + merge_instances=pre_cull_merge_instances, + merging_penalty=pre_cull_merging_penalty, ) tracker_obj = cls( track_window=track_window, robust_best_instance=robust, min_new_track_points=min_new_track_points, + max_tracks=max_tracks, similarity_function=similarity_function, matching_function=matching_function, candidate_maker=candidate_maker, cleaner=cleaner, pre_cull_function=pre_cull_function, - max_tracking=max_tracking, - max_tracks=max_tracks, target_instance_count=target_instance_count, + allow_reassigning_track=allow_reassigning_track, + prefer_reassigning_track=prefer_reassigning_track, post_connect_single_breaks=post_connect_single_breaks, + verbosity=progress_reporting, + report_rate=report_rate, ) if target_instance_count and kf_init_frame_count: @@ -951,7 +1059,6 @@ def 
pre_cull_function(inst_list): @classmethod def get_by_name_factory_options(cls): - options = [] option = dict(name="tracker", default="None") @@ -961,17 +1068,12 @@ def get_by_name_factory_options(cls): ] options.append(option) - option = dict(name="max_tracking", default=False) - option["type"] = bool - option["help"] = ( - "If true then the tracker will cap the max number of tracks. " - "Falls back to false if `max_tracks` is not defined or 0." - ) - options.append(option) - option = dict(name="max_tracks", default=None) option["type"] = int - option["help"] = "Maximum number of tracks to be tracked by the tracker." + option["help"] = ( + "Maximum number of tracks to be tracked by the tracker. " + "No maximum if set to -1." + ) options.append(option) option = dict(name="target_instance_count", default=0) @@ -996,6 +1098,22 @@ def get_by_name_factory_options(cls): ) options.append(option) + option = dict(name="pre_cull_merge_instances", default=False) + option["type"] = bool + option["help"] = ( + "If True, allow merging instances with non-overlapping visible nodes " + "to create new instances *before* tracking." + ) + options.append(option) + + option = dict(name="pre_cull_merging_penalty", default=0.2) + option["type"] = float + option["help"] = ( + "A float between 0 and 1. All scores of the merged instances " + "are multplied by (1 - penalty)." + ) + options.append(option) + option = dict(name="post_connect_single_breaks", default=0) option["type"] = int option["help"] = ( @@ -1043,6 +1161,21 @@ def get_by_name_factory_options(cls): option["help"] = "Minimum number of instance points for spawning new track" options.append(option) + option = dict(name="allow_reassigning_track", default=False) + option["type"] = bool + option[ + "help" + ] = "Allow assigning existing but unused track to unmatched instances." + options.append(option) + + option = dict(name="prefer_reassigning_track", default=False) + option["type"] = bool + option["help"] = ( + "Try first to reassign to an existing track before trying to " + "spawn a new track with unmatched instances." + ) + options.append(option) + option = dict(name="min_match_points", default=0) option["type"] = int option["help"] = "Minimum points for match candidates" @@ -1066,11 +1199,12 @@ def get_by_name_factory_options(cls): option["help"] = "For optical-flow: Number of pyramid scale levels to consider" options.append(option) - option = dict(name="save_shifted_instances", default=0) + option = dict(name="save_shifted_instances", default=1) option["type"] = int option["help"] = ( "If non-zero and tracking.tracker is set to flow, save the shifted " - "instances between elapsed frames" + "instances between elapsed frames. It uses a bit more of memory but gives " + "better instance matches." ) options.append(option) @@ -1084,9 +1218,10 @@ def int_list_func(s): option = dict(name="kf_init_frame_count", default="0") option["type"] = int - option[ - "help" - ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." + option["help"] = ( + "For Kalman filter: Number of frames to track with other tracker. " + "0 means no Kalman filters will be used." + ) options.append(option) def float_list_func(s): @@ -1107,9 +1242,10 @@ def float_list_func(s): option = dict(name="oks_score_weighting", default="0") option["type"] = int option["help"] = ( - "For Object Keypoint similarity: if 0 (default), only the distance between the reference " - "and query keypoint is used to compute the similarity. 
If 1, each distance is weighted " - "by the prediction scores of the reference and query keypoint." + "For Object Keypoint similarity: if 0 (default), only the distance " + "between the reference and query keypoint is used to compute the " + "similarity. If 1, each distance is weighted by the prediction scores " + "of the reference and query keypoint." ) options.append(option) @@ -1117,10 +1253,10 @@ def float_list_func(s): option["type"] = str option["options"] = ["all", "ref", "union"] option["help"] = ( - "For Object Keypoint similarity: Determine how to normalize similarity score. " + "Object Keypoint similarity: Determine how to normalize similarity score. " "If 'all', similarity score is normalized by number of reference points. " - "If 'ref', similarity score is normalized by number of visible reference points. " - "If 'union', similarity score is normalized by number of points both visible " + "If 'ref', score is normalized by number of visible reference points. " + "If 'union', score is normalized by number of points both visible " "in query and reference instance." ) options.append(option) @@ -1140,11 +1276,21 @@ def add_cli_parser_args(cls, parser, arg_scope: str = ""): else: arg_name = arg["name"] - parser.add_argument( - f"--{arg_name}", - type=arg["type"], - help=help_string, - ) + if arg["name"] == "tracker": + # If default is defined for "tracking.tracker", we cannot detect + # mal-formed command line. + parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + ) + else: + parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + default=arg["default"], + ) @attr.s(auto_attribs=True) @@ -1156,19 +1302,6 @@ class FlowTracker(Tracker): candidate_maker: object = attr.ib(factory=FlowCandidateMaker) -attr.s(auto_attribs=True) - - -class FlowMaxTracker(Tracker): - """Pre-configured tracker to use optical flow shifted candidates with max tracks.""" - - max_tracks: int = attr.ib(kw_only=True) - similarity_function: Callable = instance_similarity - matching_function: Callable = greedy_matching - candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker) - max_tracking: bool = True - - @attr.s(auto_attribs=True) class SimpleTracker(Tracker): """A Tracker pre-configured to use simple, non-image-based candidates.""" @@ -1178,17 +1311,6 @@ class SimpleTracker(Tracker): candidate_maker: object = attr.ib(factory=SimpleCandidateMaker) -@attr.s(auto_attribs=True) -class SimpleMaxTracker(Tracker): - """Pre-configured tracker to use simple, non-image-based candidates with max tracks.""" - - max_tracks: int = attr.ib(kw_only=True) - similarity_function: Callable = instance_iou - matching_function: Callable = hungarian_matching - candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker) - max_tracking: bool = True - - @attr.s(auto_attribs=True) class KalmanInitSet: init_frame_count: int @@ -1220,7 +1342,6 @@ def add_frame_instances( # "usuable" instances—i.e., instances with the nodes that we'll track # using Kalman filters. 
elif frame_match.has_only_first_choice_matches: - good_instances = [ inst for inst in instances if self.is_usable_instance(inst) ] @@ -1311,6 +1432,12 @@ class KalmanTracker(BaseTracker): last_t: int = 0 last_init_t: int = 0 + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): """Do we have everything we need to run tracking?""" @@ -1434,7 +1561,6 @@ def track( # Check whether we've been getting good results from the Kalman filters. # First, has it been a while since the filters were initialized? if self.init_done and (t - self.last_init_t) > self.re_init_cooldown: - # If it's been a while, then see if it's also been a while since # the filters successfully matched tracks to the instances. if self.kalman_tracker.last_frame_with_tracks < t - self.re_init_after: @@ -1491,46 +1617,6 @@ def run(self, frames: List[LabeledFrame]): connect_single_track_breaks(frames, self.instance_count) -def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[LabeledFrame]: - """Run a tracker on a set of labeled frames. - - Args: - frames: A list of labeled frames with instances. - tracker: An initialized Tracker. - - Returns: - The input frames with the new tracks assigned. If the frames already had tracks, - they will be cleared if the tracker has been re-initialized. - """ - # Return original frames if we aren't retracking - if not tracker.is_valid: - return frames - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if tracker.uses_image: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args), - ) - new_lfs.append(new_lf) - - return new_lfs - - def retrack(): import argparse import operator @@ -1568,8 +1654,7 @@ def retrack(): print(f"Done loading predictions in {time.time() - t0} seconds.") print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) new_labels = Labels(labeled_frames=frames) diff --git a/sleap/util.py b/sleap/util.py index eef762ff4..1e59ea237 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -1,8 +1,10 @@ -"""A miscellaneous set of utility functions. +"""A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ +from __future__ import annotations + import base64 import json import os @@ -11,25 +13,40 @@ from collections import defaultdict from io import BytesIO from pathlib import Path -from typing import Any, Dict, Hashable, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse from urllib.request import url2pathname +try: + from importlib.resources import files # New in 3.9+ +except ImportError: + from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. + import attr import h5py as h5 import numpy as np import psutil import rapidjson +import rich.progress import yaml - -try: - from importlib.resources import files # New in 3.9+ -except ImportError: - from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. 
diff --git a/sleap/util.py b/sleap/util.py
index eef762ff4..1e59ea237 100644
--- a/sleap/util.py
+++ b/sleap/util.py
@@ -1,8 +1,10 @@
-"""A miscellaneous set of utility functions.
+"""A miscellaneous set of utility functions.

 Try not to put things in here unless they really have no other place.
 """

+from __future__ import annotations
+
 import base64
 import json
 import os
@@ -11,25 +13,40 @@
 from collections import defaultdict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, Hashable, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname

+try:
+    from importlib.resources import files  # New in 3.9+
+except ImportError:
+    from importlib_resources import files  # TODO(LM): Upgrade to importlib.resources.
+
 import attr
 import h5py as h5
 import numpy as np
 import psutil
 import rapidjson
+import rich.progress
 import yaml
-
-try:
-    from importlib.resources import files  # New in 3.9+
-except ImportError:
-    from importlib_resources import files  # TODO(LM): Upgrade to importlib.resources.

 from PIL import Image

 import sleap.version as sleap_version

+if TYPE_CHECKING:
+    from rich.progress import Task
+
+
+class RateColumn(rich.progress.ProgressColumn):
+    """Renders the progress rate."""
+
+    def render(self, task: Task) -> rich.progress.Text:
+        """Show progress rate."""
+        speed = task.speed
+        if speed is None:
+            return rich.progress.Text("?", style="progress.data.speed")
+        return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed")
+

 def json_loads(json_str: str) -> Dict:
     """A simple wrapper around the JSON decoder we are using.
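For context on the new `RateColumn`: it slots into a `rich` progress bar like any other column and reports the task speed in frames per second. A small usage sketch; the description text and totals here are illustrative only:

```python
import time

import rich.progress

from sleap.util import RateColumn

with rich.progress.Progress(
    "{task.description}",
    rich.progress.BarColumn(),
    RateColumn(),  # renders the task speed, e.g. "12.3 FPS"
) as progress:
    task = progress.add_task("Tracking frames...", total=50)
    for _ in range(50):
        time.sleep(0.01)  # stand-in for per-frame work
        progress.advance(task)
```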
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index fd615ea81..a3852ec20 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -58,7 +58,6 @@
 from sleap.nn.tracking import (
     MatchedFrameInstance,
     FlowCandidateMaker,
-    FlowMaxTracksCandidateMaker,
     Tracker,
 )
 from sleap.instance import Track
@@ -1273,6 +1272,7 @@ def test_make_export_cli():
     assert args.max_instances == max_instances


+@pytest.mark.slow()
 def test_topdown_predictor_save(
     min_centroid_model_path, min_centered_instance_model_path, tmp_path
 ):
@@ -1315,6 +1315,7 @@ def test_topdown_predictor_save(
 )


+@pytest.mark.slow()
 def test_topdown_id_predictor_save(
     min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
 ):
@@ -1361,9 +1362,7 @@ def test_topdown_id_predictor_save(
     "output_path,tracker_method",
     [
         ("not_default", "flow"),
-        ("not_default", "flowmaxtracks"),
         (None, "simple"),
-        (None, "simplemaxtracks"),
     ],
 )
 def test_retracking(
@@ -1380,7 +1379,6 @@ def test_retracking(
     if tracker_method == "flow":
         cmd += " --tracking.save_shifted_instances 1"
     elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks":
-        cmd += " --tracking.max_tracking 1"
         cmd += " --tracking.max_tracks 2"
     if output_path == "not_default":
         output_path = Path(tmpdir, "tracked_slp.slp")
@@ -1944,8 +1942,8 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
 @pytest.mark.parametrize(
     "max_tracks, trackername",
     [
-        (2, "flowmaxtracks"),
-        (2, "simplemaxtracks"),
+        (2, "flow"),
+        (2, "simple"),
     ],
 )
 def test_max_tracks_matching_queue(
@@ -1953,7 +1951,6 @@ def test_max_tracks_matching_queue(
 ):
     """Test flow max tracks instance generation."""
     labels: Labels = centered_pair_predictions
-    max_tracking = True
     track_window = 5

     # Setup flow max tracker
@@ -1961,11 +1958,10 @@
         tracker=trackername,
         track_window=track_window,
         save_shifted_instances=True,
-        max_tracking=max_tracking,
         max_tracks=max_tracks,
     )

-    tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker)
+    tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)

     # Run tracking
     frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
@@ -1978,18 +1974,17 @@
         track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
         tracker.track(**track_args)

-        if trackername == "flowmaxtracks":
+        if trackername == "flow":
             # Check that saved instances are pruned to track window
-            for key in tracker.candidate_maker.shifted_instances.keys():
+            for key in tracker.candidate_maker.shifted_instances:
                 assert lf.frame_idx - key[0] <= track_window  # Keys are pruned
                 assert abs(key[0] - key[1]) <= track_window

     # Check if the length of each of the tracks is not more than the track window
-    for track in tracker.track_matching_queue_dict.keys():
-        assert len(tracker.track_matching_queue_dict[track]) <= track_window
+    assert len(tracker.track_matching_queue) <= track_window

     # Check if number of tracks that are generated are not more than the maximum tracks
-    assert len(tracker.track_matching_queue_dict) <= max_tracks
+    assert len(tracker.unique_tracks_in_queue) <= max_tracks


 def test_movenet_inference(movenet_video):
@@ -2012,6 +2007,7 @@ def test_movenet_inference(movenet_video):
     assert preds["instance_peaks"].shape == (1, 1, 17, 2)


+@pytest.mark.slow()
 def test_movenet_predictor(min_dance_labels, movenet_video):
     predictor = MoveNetPredictor.from_trained_models("thunder")
     predictor.verbosity = "none"
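The `@pytest.mark.slow()` markers added above let the heavyweight predictor-save and MoveNet tests be deselected during quick local runs, assuming the `slow` marker is registered in the project's pytest configuration. For example:

```python
import pytest

# Equivalent to running `pytest -m "not slow" tests/nn/test_inference.py`.
pytest.main(["-m", "not slow", "tests/nn/test_inference.py"])
```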
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index 5786945fb..be5789dc9 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -15,53 +15,50 @@
 from sleap.skeleton import Skeleton


-def tracker_by_name(frames=None, **kwargs):
-    t = Tracker.make_tracker_by_name(**kwargs)
-    print(kwargs)
-    print(t.candidate_maker)
-    if frames is None:
-        t.track([])
-        t.final_pass([])
-        return
-
-    for lf in frames:
-
-        # Clear the tracks
-        for inst in lf.instances:
-            inst.track = None
-
-        track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
-        t.track(**track_args)
-        t.final_pass(frames)
-
-
-@pytest.mark.parametrize(
-    "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs):
+    # Create tracker
+    t = Tracker.make_tracker_by_name(verbosity="none", **kwargs)
+    # Update img_scale
+    if img_scale:
+        if hasattr(t, "candidate_maker") and hasattr(t.candidate_maker, "img_scale"):
+            t.candidate_maker.img_scale = img_scale
+        else:
+            # Do not even run tracking as it can be slow
+            pytest.skip("img_scale is not defined for this tracker")
+            return
+
+    # Run tracking
+    new_frames = t.run_tracker(frames or [])
+    assert len(new_frames) == len(frames)
+
+
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
 @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
 @pytest.mark.parametrize("match", ["greedy", "hungarian"])
+@pytest.mark.parametrize("img_scale", [0, 1, 0.25])
 @pytest.mark.parametrize("count", [0, 2])
 def test_tracker_by_name(
     centered_pair_predictions_sorted,
     tracker,
     similarity,
     match,
+    img_scale,
     count,
 ):
     # This is slow, so limit to 5 time points
     frames = centered_pair_predictions_sorted[:5]

-    tracker_by_name(
+    run_tracker_by_name(
         frames=frames,
         tracker=tracker,
         similarity=similarity,
         match=match,
+        img_scale=img_scale,
         max_tracks=count,
     )


-@pytest.mark.parametrize(
-    "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
 @pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
 @pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
 def test_oks_tracker_by_name(
@@ -73,7 +70,7 @@ def test_oks_tracker_by_name(
     # This is slow, so limit to 5 time points
     frames = centered_pair_predictions_sorted[:5]

-    tracker_by_name(
+    run_tracker_by_name(
         frames=frames,
         tracker=tracker,
         similarity="object_keypoint",
@@ -244,7 +241,7 @@ def make_inst(x, y):
     return insts


-def test_max_tracking_large_gap_single_track():
+def test_max_tracks_large_gap_single_track():
     # Track 2 instances with gap > window size
     preds = make_insts(
         [
@@ -279,11 +276,9 @@ def test_max_tracks_large_gap_single_track():

     tracker = Tracker.make_tracker_by_name(
         tracker="simple",
-        # tracker="simplemaxtracks",
         match="hungarian",
         track_window=2,
-        # max_tracks=2,
-        # max_tracking=True,
+        max_tracks=-1,
     )

     tracked = []
@@ -295,12 +290,10 @@ def test_max_tracks_large_gap_single_track():
     assert len(all_tracks) == 3

     tracker = Tracker.make_tracker_by_name(
-        # tracker="simple",
-        tracker="simplemaxtracks",
+        tracker="simple",
         match="hungarian",
         track_window=2,
         max_tracks=2,
-        max_tracking=True,
     )

     tracked = []
@@ -312,7 +305,7 @@ def test_max_tracks_large_gap_single_track():
     assert len(all_tracks) == 2


-def test_max_tracking_small_gap_on_both_tracks():
+def test_max_tracks_small_gap_on_both_tracks():
     # Test 2 instances with both tracks with gap > window size
     preds = make_insts(
         [
@@ -343,11 +336,9 @@ def test_max_tracks_small_gap_on_both_tracks():

     tracker = Tracker.make_tracker_by_name(
         tracker="simple",
-        # tracker="simplemaxtracks",
         match="hungarian",
         track_window=2,
-        # max_tracks=2,
-        # max_tracking=True,
+        max_tracks=-1,
     )

     tracked = []
@@ -359,12 +350,10 @@ def test_max_tracks_small_gap_on_both_tracks():
     assert len(all_tracks) == 4

     tracker = Tracker.make_tracker_by_name(
-        # tracker="simple",
-        tracker="simplemaxtracks",
+        tracker="simple",
         match="hungarian",
         track_window=2,
         max_tracks=2,
-        max_tracking=True,
     )

     tracked = []
@@ -376,7 +365,7 @@ def test_max_tracks_small_gap_on_both_tracks():
     assert len(all_tracks) == 2


-def test_max_tracking_extra_detections():
+def test_max_tracks_extra_detections():
     # Test having more than 2 detected instances in a frame
     preds = make_insts(
         [
@@ -412,11 +401,9 @@ def test_max_tracks_extra_detections():

     tracker = Tracker.make_tracker_by_name(
         tracker="simple",
-        # tracker="simplemaxtracks",
         match="hungarian",
         track_window=2,
-        # max_tracks=2,
-        # max_tracking=True,
+        max_tracks=-1,
     )

     tracked = []
@@ -428,12 +415,10 @@ def test_max_tracks_extra_detections():
     assert len(all_tracks) == 4

     tracker = Tracker.make_tracker_by_name(
-        # tracker="simple",
-        tracker="simplemaxtracks",
+        tracker="simple",
         match="hungarian",
         track_window=2,
         max_tracks=2,
-        max_tracking=True,
     )

     tracked = []
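The renamed tests above capture the contract that replaces the removed `simplemaxtracks`/`flowmaxtracks` variants: `max_tracks=-1` (or `None`) leaves the number of tracks uncapped, while a positive value caps it on the plain trackers. A hedged sketch of that contract; frame preparation is elided:

```python
from sleap.nn.tracking import Tracker

# Uncapped: new tracks may be spawned whenever matching fails within the window.
uncapped = Tracker.make_tracker_by_name(
    tracker="simple", match="hungarian", track_window=2, max_tracks=-1
)

# Capped: at most two distinct tracks are ever assigned.
capped = Tracker.make_tracker_by_name(
    tracker="simple", match="hungarian", track_window=2, max_tracks=2
)

# After running both over the same frames (as the tests above do), the capped
# tracker should never report more than two unique tracks, e.g.:
#     assert len(capped.unique_tracks_in_queue) <= 2
```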
diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py
index a6592dc4d..ab8d52b87 100644
--- a/tests/nn/test_tracking_integration.py
+++ b/tests/nn/test_tracking_integration.py
@@ -19,13 +19,13 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
     inference_cli(cli.split(" "))

     labels = sleap.load_file(f"{tmpdir}/simpletracks.slp")
-    assert len(labels.tracks) == 27
+    assert len(labels.tracks) == 8


-def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path):
+def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path):
     cli = (
-        "--tracking.tracker simplemaxtracks "
-        "--tracking.max_tracking 1 --tracking.max_tracks 2 "
+        "--tracking.tracker simple "
+        "--tracking.max_tracks 2 "
         "--frames 200-300 "
         f"-o {tmpdir}/simplemaxtracks.slp "
         f"{centered_pair_predictions_slp_path}"
@@ -37,18 +37,19 @@
 # TODO: Refactor the below things into a real test suite.
+# running an equivalent to `make_ground_truth` is done as a test in tests/nn/test_tracker_components.py


 def make_ground_truth(frames, tracker, gt_filename):
     t0 = time.time()
-    new_labels = run_tracker(frames, tracker)
+    new_labels = tracker.run_tracker(frames, verbosity="none")

     print(f"{gt_filename}\t{len(tracker.spawned_tracks)}\t{time.time()-t0}")
     Labels.save_file(new_labels, gt_filename)


 def compare_ground_truth(frames, tracker, gt_filename):
     t0 = time.time()
-    new_labels = run_tracker(frames, tracker)
+    new_labels = tracker.run_tracker(frames, verbosity="none")

     print(f"{gt_filename}\t{time.time() - t0}")

     does_match = check_tracks(new_labels, gt_filename)
@@ -78,43 +79,6 @@ def check_tracks(labels, gt_filename, limit=None):
     return True


-def run_tracker(frames, tracker):
-    sig = inspect.signature(tracker.track)
-    takes_img = "img" in sig.parameters
-
-    # t0 = time.time()
-
-    new_lfs = []
-
-    # Run tracking on every frame
-    for lf in frames:
-
-        # Clear the tracks
-        for inst in lf.instances:
-            inst.track = None
-
-        track_args = dict(untracked_instances=lf.instances)
-        if takes_img:
-            track_args["img"] = lf.video[lf.frame_idx]
-        else:
-            track_args["img"] = None
-
-        new_lf = LabeledFrame(
-            frame_idx=lf.frame_idx,
-            video=lf.video,
-            instances=tracker.track(**track_args),
-        )
-        new_lfs.append(new_lf)
-
-        # if lf.frame_idx % 100 == 0: print(lf.frame_idx, time.time()-t0)
-
-    # print(time.time() - t0)
-
-    new_labels = Labels()
-    new_labels.extend(new_lfs)
-    return new_labels
-
-
 def main(f, dir):
     filename = "tests/data/json_format_v2/centered_pair_predictions.json"
@@ -127,8 +91,6 @@ def main(f, dir):
     trackers = dict(
         simple=sleap.nn.tracker.simple.SimpleTracker,
         flow=sleap.nn.tracker.flow.FlowTracker,
-        simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker,
-        flowmaxtracks=sleap.nn.tracker.FlowMaxTracker,
     )
     matchers = dict(
         hungarian=sleap.nn.tracker.components.hungarian_matching,
@@ -144,27 +106,21 @@
         0.25,
     )

-    def make_tracker(
-        tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0
-    ):
-        if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks":
-            tracker = trackers[tracker_name](
-                matching_function=matchers[matcher_name],
-                similarity_function=similarities[sim_name],
-                max_tracks=max_tracks,
-                max_tracking=max_tracking,
-            )
-        else:
-            tracker = trackers[tracker_name](
-                matching_function=matchers[matcher_name],
-                similarity_function=similarities[sim_name],
-            )
+    def make_tracker(tracker_name, matcher_name, sim_name, max_tracks, scale=0):
+        tracker = trackers[tracker_name](
+            matching_function=matchers[matcher_name],
+            similarity_function=similarities[sim_name],
+            max_tracks=max_tracks,
+        )
         if scale:
             tracker.candidate_maker.img_scale = scale
         return tracker

     def make_filename(tracker_name, matcher_name, sim_name, scale=0):
-        return f"{dir}{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5"
+        return os.path.join(
+            dir,
+            f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5",
+        )

     def make_tracker_and_filename(*args, **kwargs):
         tracker = make_tracker(*args, **kwargs)
@@ -178,43 +134,22 @@ def make_tracker_and_filename(*args, **kwargs):
     for tracker_name in trackers.keys():
         for matcher_name in matchers.keys():
             for sim_name in similarities.keys():
-
                 if tracker_name == "flow":
                     # If this tracker supports scale, try multiple scales
                     for scale in scales:
                         tracker, gt_filename = make_tracker_and_filename(
                             tracker_name=tracker_name,
                             matcher_name=matcher_name,
-                            sim_name=sim_name,
-                            scale=scale,
-                        )
-                        f(frames, tracker, gt_filename)
-                elif tracker_name == "flowmaxtracks":
-                    # If this tracker supports scale, try multiple scales
-                    for scale in scales:
-                        tracker, gt_filename = make_tracker_and_filename(
-                            tracker_name=tracker_name,
-                            matcher_name=matcher_name,
-                            sim_name=sim_name,
                             max_tracks=2,
-                            max_tracking=True,
+                            sim_name=sim_name,
                             scale=scale,
                         )
                         f(frames, tracker, gt_filename)
-                elif tracker_name == "simplemaxtracks":
-                    tracker, gt_filename = make_tracker_and_filename(
-                        tracker_name=tracker_name,
-                        matcher_name=matcher_name,
-                        sim_name=sim_name,
-                        max_tracks=2,
-                        max_tracking=True,
-                        scale=0,
-                    )
-                    f(frames, tracker, gt_filename)
                 else:
                     tracker, gt_filename = make_tracker_and_filename(
                         tracker_name=tracker_name,
                         matcher_name=matcher_name,
+                        max_tracks=2,
                         sim_name=sim_name,
                         scale=0,
                     )