diff --git a/docs/guides/cli.md b/docs/guides/cli.md
index 03b806903..c59048ecd 100644
--- a/docs/guides/cli.md
+++ b/docs/guides/cli.md
@@ -36,7 +36,7 @@ optional arguments:
```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
- [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
+ [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
@@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
[--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
[--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
[--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
- [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
+ [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER]
[--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
[--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
@@ -187,10 +187,8 @@ optional arguments:
Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
- --tracking.max_tracking TRACKING.MAX_TRACKING
- If true then the tracker will cap the max number of tracks. (default: False)
--tracking.max_tracks TRACKING.MAX_TRACKS
- Maximum number of tracks to be tracked by the tracker. (default: None)
+ Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
@@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction
**5. Inference with max tracks limit:**
```none
-sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
+sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
```
**6. Re-tracking without pose inference:**
```none
-sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
+sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
```
**7. Select GPU for pose inference:**
diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml
index c34faea55..406a02ea8 100644
--- a/sleap/config/pipeline_form.yaml
+++ b/sleap/config/pipeline_form.yaml
@@ -521,10 +521,6 @@ inference:
text: 'Tracking:
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
- # - name: tracking.max_tracking
- # label: Limit max number of tracks
- # type: bool
- # default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py
index d0bb1f3ba..641e2c4ab 100644
--- a/sleap/gui/learning/runners.py
+++ b/sleap/gui/learning/runners.py
@@ -204,7 +204,6 @@ def make_predict_cli_call(
# Make path where we'll save predictions (if not specified)
if output_path is None:
-
if self.labels_filename:
# Make a predictions directory next to the labels dataset file
predictions_dir = os.path.join(
@@ -239,15 +238,12 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]
- # Setting max_tracks to True means we want to use the max_tracking mode.
- if "tracking.max_tracks" in self.inference_params:
- self.inference_params["tracking.max_tracking"] = True
-
- # Hacky: Update the tracker name to include "maxtracks" suffix.
- if self.inference_params["tracking.tracker"] in ("simple", "flow"):
- self.inference_params["tracking.tracker"] = (
- self.inference_params["tracking.tracker"] + "maxtracks"
- )
+ # Backwards compatibility with tracker names that use the "maxtracks" suffix.
+ if "tracking.tracker" in self.inference_params:
+ compat_trackers = ("simplemaxtracks", "flowmaxtracks")
+ if self.inference_params["tracking.tracker"] in compat_trackers:
+ tname = self.inference_params["tracking.tracker"][: -len("maxtracks")]
+ self.inference_params["tracking.tracker"] = tname
# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
@@ -257,10 +253,11 @@ def make_predict_cli_call(
bool_items_as_ints = (
"tracking.pre_cull_to_target",
- "tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
+ "tracking.prefer_reassigning_track",
+ "tracking.allow_reassigning_track",
)
for key in bool_items_as_ints:
@@ -303,10 +300,8 @@ def predict_subprocess(
# Run inference CLI capturing output.
with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc:
-
# Poll until finished.
while proc.poll() is None:
-
# Read line.
line = proc.stdout.readline()
line = line.decode().rstrip()
@@ -635,7 +630,6 @@ def run_gui_training(
for config_info in config_info_list:
if config_info.dont_retrain:
-
if not config_info.has_trained_model:
raise ValueError(
"Config is set to not retrain but no trained model found: "
@@ -849,7 +843,6 @@ def train_subprocess(
success = False
with tempfile.TemporaryDirectory() as temp_dir:
-
# Write a temporary file of the TrainingJob so that we can respect
# any changed made to the job attributes after it was loaded.
temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json"
diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index a1a083ba5..e5cb1dc49 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -4947,16 +4947,10 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
- use_max_tracker = tracker_max_instances is not None
- if use_max_tracker and not tracker.endswith("maxtracks"):
- # Append maxtracks to the tracker name to use the right tracker variants.
- tracker += "maxtracks"
-
predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
- max_tracking=use_max_tracker,
max_tracks=tracker_max_instances,
# clean_instance_count=tracker_max_instances,
)
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index a99aab8b5..770b337c6 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -2,6 +2,7 @@
import abc
import json
+import operator
import sys
from collections import deque
from time import time
@@ -114,7 +115,23 @@ class MatchedShiftedFrameInstances:
@attr.s(auto_attribs=True)
-class FlowCandidateMaker:
+class CandidateMaker(abc.ABC):
+ """Abstract base class for candidate maker."""
+
+ @abc.abstractmethod
+ def get_candidates(
+ self,
+ track_matching_queue: Deque[MatchedFrameInstances],
+ t: int,
+ img: np.ndarray,
+ *args,
+ **kwargs,
+ ) -> List[ShiftedInstance]:
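+ """Return a list of candidate instances to match against at timestep `t`."""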
+ pass
+
+
+@attr.s(auto_attribs=True)
+class FlowCandidateMaker(CandidateMaker):
"""Class for producing optical flow shift matching candidates.
Attributes:
@@ -209,37 +226,6 @@ def get_shifted_instances(
return shifted_instances
- def get_candidates(
- self,
- track_matching_queue: Deque[MatchedFrameInstances],
- t: int,
- img: np.ndarray,
- ) -> List[ShiftedInstance]:
- candidate_instances = []
-
- # Prune old shifted instances to save time and memory
- self.prune_shifted_instances(t)
-
- for matched_item in track_matching_queue:
- ref_t, ref_img, ref_instances = (
- matched_item.t,
- matched_item.img_t,
- matched_item.instances_t,
- )
-
- # Check if shifted instance was computed at earlier time
- if self.save_shifted_instances:
- ref_img, ref_instances = self.get_shifted_instances_from_earlier_time(
- ref_t, ref_img, ref_instances, t
- )
-
- if len(ref_instances) > 0:
- candidate_instances.extend(
- self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t)
- )
-
- return candidate_instances
-
def prune_shifted_instances(self, t: int):
"""Prune the shifted instances older than `self.track_window`.
@@ -363,45 +349,9 @@ def flow_shift_instances(
return shifted_instances
-
-@attr.s(auto_attribs=True)
-class FlowMaxTracksCandidateMaker(FlowCandidateMaker):
- """Class for producing optical flow shift matching candidates with maximum tracks.
-
- Attributes:
- max_tracks: The maximum number of tracks to avoid redundant tracks.
-
- """
-
- max_tracks: int = None
-
- @staticmethod
- def get_ref_instances(
- ref_t: int,
- ref_img: np.ndarray,
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
- ) -> List[InstanceType]:
- """Generates a list of instances based on the reference time and image.
-
- Args:
- ref_t: Previous frame time instance.
- ref_img: Previous frame image as a numpy array.
- track_matching_queue_dict: A dictionary of mapping between the tracks
- and the corresponding instances associated with the track.
- """
- instances = []
- for track, matched_items in track_matching_queue_dict.items():
- instances += [
- item.instance_t
- for item in matched_items
- if item.t == ref_t and np.all(item.img_t == ref_img)
- ]
- return instances
-
def get_candidates(
self,
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
- max_tracking: bool,
+ track_matching_queue: Deque[MatchedFrameInstances],
t: int,
img: np.ndarray,
*args,
@@ -411,42 +361,33 @@ def get_candidates(
# Prune old shifted instances to save time and memory
self.prune_shifted_instances(t)
- # Storing the tracks from the dictionary for counting purpose.
- tracks = []
-
- for track, matched_items in track_matching_queue_dict.items():
- if not max_tracking or len(tracks) < self.max_tracks:
- tracks.append(track)
- for matched_item in matched_items:
- ref_t, ref_img = (
- matched_item.t,
- matched_item.img_t,
- )
- ref_instances = self.get_ref_instances(
- ref_t, ref_img, track_matching_queue_dict
- )
- # Check if shifted instance was computed at earlier time
- if self.save_shifted_instances:
- (
- ref_img,
- ref_instances,
- ) = self.get_shifted_instances_from_earlier_time(
- ref_t, ref_img, ref_instances, t
- )
-
- if len(ref_instances) > 0:
- candidate_instances.extend(
- self.get_shifted_instances(
- ref_instances, ref_img, ref_t, img, t
- )
- )
+ for matched_item in track_matching_queue:
+ ref_t, ref_img, ref_instances = (
+ matched_item.t,
+ matched_item.img_t,
+ matched_item.instances_t,
+ )
+
+ # Check if shifted instance was computed at earlier time
+ if self.save_shifted_instances:
+ (
+ ref_img,
+ ref_instances,
+ ) = self.get_shifted_instances_from_earlier_time(
+ ref_t, ref_img, ref_instances, t
+ )
+
+ if len(ref_instances) > 0:
+ candidate_instances.extend(
+ self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t)
+ )
return candidate_instances
@attr.s(auto_attribs=True)
-class SimpleCandidateMaker:
+class SimpleCandidateMaker(CandidateMaker):
"""Class for producing list of matching candidates from prior frames."""
min_points: int = 0
@@ -456,48 +397,25 @@ def uses_image(self):
return False
def get_candidates(
- self, track_matching_queue: Deque[MatchedFrameInstances], *args, **kwargs
+ self,
+ track_matching_queue: Deque[MatchedFrameInstances],
+ *args,
+ **kwargs,
) -> List[InstanceType]:
- # Build a pool of matchable candidate instances.
+ # Build a pool of matchable candidate instances from the matching queue
candidate_instances = []
for matched_item in track_matching_queue:
ref_t, ref_instances = matched_item.t, matched_item.instances_t
for ref_instance in ref_instances:
if ref_instance.n_visible_points >= self.min_points:
candidate_instances.append(ref_instance)
- return candidate_instances
-
-
-@attr.s(auto_attribs=True)
-class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker):
- """Class to generate instances with maximum number of tracks from prior frames."""
-
- max_tracks: int = None
- def get_candidates(
- self,
- track_matching_queue_dict: Dict,
- max_tracking: bool,
- *args,
- **kwargs,
- ) -> List[InstanceType]:
- # Create set of matchable candidate instances from each track.
- candidate_instances = []
- tracks = []
- for track, matched_instances in track_matching_queue_dict.items():
- if not max_tracking or len(tracks) < self.max_tracks:
- tracks.append(track)
- for ref_instance in matched_instances:
- if ref_instance.instance_t.n_visible_points >= self.min_points:
- candidate_instances.append(ref_instance.instance_t)
return candidate_instances
tracker_policies = dict(
simple=SimpleCandidateMaker,
flow=FlowCandidateMaker,
- simplemaxtracks=SimpleMaxTracksCandidateMaker,
- flowmaxtracks=FlowMaxTracksCandidateMaker,
)
similarity_policies = dict(
@@ -708,7 +626,7 @@ class Tracker(BaseTracker):
use a robust quantile similarity score for the track. If the value is 1,
use the max similarity (non-robust). For selecting a robust score,
0.95 is a good value.
- max_tracking: Max tracking is incorporated when this is set to true.
+ max_tracks: Maximum number of tracks. No limit if set to -1.
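+ allow_reassigning_track: If True, unmatched instances may be assigned to
+ existing tracks that have not been used within the track window.
+ prefer_reassigning_track: If True, try reassigning an unmatched instance to an
+ existing track before spawning a new track.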
verbosity: Mode of inference progress reporting. If `"rich"` (the
default), an updating progress bar is displayed in the console or notebook.
If `"json"`, a JSON-serialized message is printed out which can be captured
@@ -717,12 +635,11 @@ class Tracker(BaseTracker):
machines where the output is captured to a log file.
"""
- max_tracks: int = None
track_window: int = 5
similarity_function: Optional[Callable] = instance_similarity
matching_function: Callable = greedy_matching
- candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
- max_tracking: bool = False # To enable maximum tracking.
+ candidate_maker: CandidateMaker = attr.ib(factory=FlowCandidateMaker)
+ max_tracks: int = -1
cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
@@ -731,15 +648,23 @@ class Tracker(BaseTracker):
robust_best_instance: float = 1.0
min_new_track_points: int = 0
+ prefer_reassigning_track: bool = False
+ allow_reassigning_track: bool = False
+ verbosity: str = attr.ib(
+ validator=attr.validators.in_(["none", "rich", "json"]),
+ default="none",
+ )
+ report_rate: float = 2.0
+
+ #: Holds frames with matched instances as a deque of length `track_window`.
track_matching_queue: Deque[MatchedFrameInstances] = attr.ib()
- # Hold track, instances with instances as a deque with length as track_window.
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib(
- factory=dict
- )
spawned_tracks: List[Track] = attr.ib(factory=list)
+ #: Tracks found so far, mapped to the last timestep an instance was assigned to them
+ found_tracks: Dict[Track, int] = attr.ib(factory=dict)
+
save_tracked_instances: bool = False
tracked_instances: Dict[int, List[InstanceType]] = attr.ib(
factory=dict
@@ -758,44 +683,39 @@ def is_valid(self):
return self.similarity_function is not None
@track_matching_queue.default
- def _init_matching_queue(self):
+ def _init_track_matching_queue(self):
"""Factory for instantiating default matching queue with specified size."""
return deque(maxlen=self.track_window)
- @property
- def has_max_tracking(self) -> bool:
- return isinstance(
- self.candidate_maker,
- (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker),
- )
-
def reset_candidates(self):
- if self.has_max_tracking:
- for track in self.track_matching_queue_dict:
- self.track_matching_queue_dict[track] = deque(maxlen=self.track_window)
- else:
- self.track_matching_queue = deque(maxlen=self.track_window)
+ self.track_matching_queue = deque(maxlen=self.track_window)
@property
def unique_tracks_in_queue(self) -> List[Track]:
"""Returns the unique tracks in the matching queue."""
-
- unique_tracks = set()
- if self.has_max_tracking:
- for track in self.track_matching_queue_dict.keys():
- unique_tracks.add(track)
-
- else:
- for match_item in self.track_matching_queue:
- for instance in match_item.instances_t:
- unique_tracks.add(instance.track)
-
- return list(unique_tracks)
+ unique_tracks = {
+ instance.track
+ for item in self.track_matching_queue
+ for instance in item.instances_t
+ }
+ return list(unique_tracks)
@property
def uses_image(self):
return getattr(self.candidate_maker, "uses_image", False)
+ def infer_next_timestep(self, t: Optional[int] = None) -> int:
+ """Infer timestep if not provided."""
+ # Timestep was provided
+ if t is not None:
+ return t
+
+ # Default to last timestep + 1 if available.
+ if len(self.track_matching_queue) > 0:
+ return self.track_matching_queue[-1].t + 1
+
+ # Default to 0
+ return 0
+
def track(
self,
untracked_instances: List[InstanceType],
@@ -817,55 +737,22 @@ def track(
return untracked_instances
# Infer timestep if not provided.
- if t is None:
- if self.has_max_tracking:
- if len(self.track_matching_queue_dict) > 0:
- # Default to last timestep + 1 if available.
- # Here we find the track that has the most instances.
- track_with_max_instances = max(
- self.track_matching_queue_dict,
- key=lambda track: len(self.track_matching_queue_dict[track]),
- )
- t = (
- self.track_matching_queue_dict[track_with_max_instances][-1].t
- + 1
- )
-
- else:
- t = 0
- else:
- if len(self.track_matching_queue) > 0:
- # Default to last timestep + 1 if available.
- t = self.track_matching_queue[-1].t + 1
-
- else:
- t = 0
+ t = self.infer_next_timestep(t)
# Initialize containers for tracked instances at the current timestep.
tracked_instances = []
- # Make cache so similarity function doesn't have to recompute everything.
- # similarity_cache = dict()
-
# Process untracked instances.
if untracked_instances:
if self.pre_cull_function:
self.pre_cull_function(untracked_instances)
# Build a pool of matchable candidate instances.
- if self.has_max_tracking:
- candidate_instances = self.candidate_maker.get_candidates(
- track_matching_queue_dict=self.track_matching_queue_dict,
- max_tracking=self.max_tracking,
- t=t,
- img=img,
- )
- else:
- candidate_instances = self.candidate_maker.get_candidates(
- track_matching_queue=self.track_matching_queue,
- t=t,
- img=img,
- )
+ candidate_instances = self.candidate_maker.get_candidates(
+ track_matching_queue=self.track_matching_queue,
+ t=t,
+ img=img,
+ )
# Determine matches for untracked instances in current frame.
frame_matches = FrameMatches.from_candidate_instances(
@@ -881,37 +768,18 @@ def track(
# Set track for each of the matched instances.
tracked_instances.extend(
- self.update_matched_instance_tracks(frame_matches.matches)
+ self.update_matched_instance_tracks(frame_matches.matches, t)
)
- # Spawn a new track for each remaining untracked instance.
+ # Assign unmatched instances to new tracks or reassign them to unused existing tracks
tracked_instances.extend(
self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t)
)
- # Add the tracked instances to the dictionary of matched instances.
- if self.has_max_tracking:
- for tracked_instance in tracked_instances:
- if tracked_instance.track in self.track_matching_queue_dict:
- self.track_matching_queue_dict[tracked_instance.track].append(
- MatchedFrameInstance(t, tracked_instance, img)
- )
- elif (
- not self.max_tracking
- or len(self.track_matching_queue_dict) < self.max_tracks
- ):
- self.track_matching_queue_dict[tracked_instance.track] = deque(
- maxlen=self.track_window
- )
- self.track_matching_queue_dict[tracked_instance.track].append(
- MatchedFrameInstance(t, tracked_instance, img)
- )
-
- else:
- # Add the tracked instances to the matching buffer.
- self.track_matching_queue.append(
- MatchedFrameInstances(t, tracked_instances, img)
- )
+ # Add the tracked instances to the matching buffer.
+ self.track_matching_queue.append(
+ MatchedFrameInstances(t, tracked_instances, img)
+ )
# Save tracked instances internally.
if self.save_tracked_instances:
@@ -919,8 +787,11 @@ def track(
return tracked_instances
- @staticmethod
- def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]:
+ def update_matched_instance_tracks(
+ self,
+ matches: List[Match],
+ t: int,
+ ) -> List[InstanceType]:
inst_list = []
for match in matches:
# Assign to track and save.
@@ -931,31 +802,108 @@ def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]:
tracking_score=match.score,
)
)
+ # Record the timestep at which this track was last matched
+ self.found_tracks[match.track] = t
return inst_list
+ def spawn_new_track(self, inst: InstanceType, t: int) -> Optional[InstanceType]:
+ """Try spawning a new track for instance."""
+ if self.max_tracks >= 0 and len(self.found_tracks) >= self.max_tracks:
+ return None
+
+ # Spawn new track.
+ new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}")
+ self.spawned_tracks.append(new_track)
+
+ # Record the timestep at which the new track was spawned
+ self.found_tracks[new_track] = t
+
+ # Assign instance to the new track and save.
+ return attr.evolve(inst, track=new_track)
+
+ def assign_to_track(
+ self,
+ inst: InstanceType,
+ t: int,
+ not_assigned_tracks: List[Track],
+ ) -> Optional[InstanceType]:
+ """Try assigning instance to a track from a candidate list (best last).
+
+ `not_assigned_tracks` will be modified.
+ """
+ if len(not_assigned_tracks) == 0:
+ return None
+
+ existing_track = not_assigned_tracks.pop()
+
+ # Record the timestep at which this track was reassigned
+ self.found_tracks[existing_track] = t
+
+ # Assign instance to the existing track and save.
+ return attr.evolve(inst, track=existing_track, tracking_score=0)
+
def spawn_for_untracked_instances(
- self, unmatched_instances: List[InstanceType], t: int
+ self,
+ unmatched_instances: List[InstanceType],
+ t: int,
) -> List[InstanceType]:
+ """Assign an existing track or spawn a new track for untracked instances."""
+ # Early return
+ if len(unmatched_instances) == 0:
+ return []
+
+ # Collect tracks that have not been assigned an instance within the last track_window frames
+ not_assigned_tracks_dict = {
+ track: last_t
+ for track, last_t in self.found_tracks.items()
+ if last_t < t - self.track_window
+ }
+ # Sort tracks so the most recently used comes last, since we pop from the end
+ not_assigned_tracks = sorted(
+ not_assigned_tracks_dict,
+ key=self.found_tracks.get,
+ )
+
+ # Number of tracks still available to assign (no limit if max_tracks is negative)
+ n_remaining_tracks = (
+ max(0, self.max_tracks - len(not_assigned_tracks))
+ if self.max_tracks >= 0
+ else len(unmatched_instances)
+ )
+
+ # Sort instances by descending instance-level grouping score
+ sorted_instances = sorted(
+ unmatched_instances,
+ key=operator.attrgetter("score"),
+ reverse=True,
+ )[:n_remaining_tracks]
+
results = []
- for inst in unmatched_instances:
- # Skip if this instance is too small to spawn a new track with.
+ for inst in sorted_instances:
+ # Skip if not enough visible nodes to assign to a track or spawn a new track
if inst.n_visible_points < self.min_new_track_points:
continue
- # Skip if we've reached the maximum number of tracks.
- if (
- self.has_max_tracking
- and self.max_tracking
- and len(self.track_matching_queue_dict) >= self.max_tracks
- ):
- break
-
- # Spawn new track.
- new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}")
- self.spawned_tracks.append(new_track)
-
- # Assign instance to the new track and save.
- results.append(attr.evolve(inst, track=new_track))
+ # Try spawning a new track first (when spawning is preferred over reassigning)
+ if not self.prefer_reassigning_track:
+ matched_inst = self.spawn_new_track(inst, t)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
+
+ # Try assigning to an existing track
+ if self.allow_reassigning_track:
+ matched_inst = self.assign_to_track(inst, t, not_assigned_tracks)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
+
+ # Fall back to spawning a new track (when reassigning is preferred but did not succeed)
+ if self.prefer_reassigning_track:
+ matched_inst = self.spawn_new_track(inst, t)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
return results
@@ -1009,7 +957,8 @@ def make_tracker_by_name(
kf_node_indices: Optional[list] = None,
# Max tracking options
max_tracks: Optional[int] = None,
- max_tracking: bool = False,
+ prefer_reassigning_track: bool = False,
+ allow_reassigning_track: bool = False,
# Object keypoint similarity options
oks_errors: Optional[list] = None,
oks_score_weighting: bool = False,
@@ -1018,11 +967,8 @@ def make_tracker_by_name(
report_rate: float = 2.0,
**kwargs,
) -> BaseTracker:
- # Parse max_tracking arguments, only True if max_tracks is not None and > 0
- max_tracking = max_tracking if max_tracks else False
- if max_tracking and tracker in ("simple", "flow"):
- # Force a candidate maker of 'maxtracks' type
- tracker += "maxtracks"
+ # Normalize max_tracks: use -1 (no limit) when None or negative
+ max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1
if tracker.lower() == "none":
candidate_maker = None
@@ -1058,9 +1004,6 @@ def make_tracker_by_name(
candidate_maker.save_shifted_instances = save_shifted_instances
candidate_maker.track_window = track_window
- if tracker == "simplemaxtracks" or tracker == "flowmaxtracks":
- candidate_maker.max_tracks = max_tracks
-
cleaner = None
if clean_instance_count:
cleaner = TrackCleaner(
@@ -1081,14 +1024,15 @@ def pre_cull_function(inst_list):
track_window=track_window,
robust_best_instance=robust,
min_new_track_points=min_new_track_points,
+ max_tracks=max_tracks,
similarity_function=similarity_function,
matching_function=matching_function,
candidate_maker=candidate_maker,
cleaner=cleaner,
pre_cull_function=pre_cull_function,
- max_tracking=max_tracking,
- max_tracks=max_tracks,
target_instance_count=target_instance_count,
+ allow_reassigning_track=allow_reassigning_track,
+ prefer_reassigning_track=prefer_reassigning_track,
post_connect_single_breaks=post_connect_single_breaks,
verbosity=progress_reporting,
report_rate=report_rate,
@@ -1120,17 +1064,12 @@ def get_by_name_factory_options(cls):
]
options.append(option)
- option = dict(name="max_tracking", default=False)
- option["type"] = bool
- option["help"] = (
- "If true then the tracker will cap the max number of tracks. "
- "Falls back to false if `max_tracks` is not defined or 0."
- )
- options.append(option)
-
option = dict(name="max_tracks", default=None)
option["type"] = int
- option["help"] = "Maximum number of tracks to be tracked by the tracker."
+ option["help"] = (
+ "Maximum number of tracks to be tracked by the tracker. "
+ "No maximum if set to -1."
+ )
options.append(option)
option = dict(name="target_instance_count", default=0)
@@ -1202,6 +1141,21 @@ def get_by_name_factory_options(cls):
option["help"] = "Minimum number of instance points for spawning new track"
options.append(option)
+ option = dict(name="allow_reassigning_track", default=False)
+ option["type"] = bool
+ option["help"] = (
+ "Allow assigning an existing but unused track to unmatched instances."
+ )
+ options.append(option)
+
+ option = dict(name="prefer_reassigning_track", default=False)
+ option["type"] = bool
+ option["help"] = (
+ "Try first to reassign to an existing track before trying to "
+ "spawn a new track with unmatched instances."
+ )
+ options.append(option)
+
option = dict(name="min_match_points", default=0)
option["type"] = int
option["help"] = "Minimum points for match candidates"
@@ -1328,19 +1282,6 @@ class FlowTracker(Tracker):
candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
-attr.s(auto_attribs=True)
-
-
-class FlowMaxTracker(Tracker):
- """Pre-configured tracker to use optical flow shifted candidates with max tracks."""
-
- max_tracks: int = attr.ib(kw_only=True)
- similarity_function: Callable = instance_similarity
- matching_function: Callable = greedy_matching
- candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker)
- max_tracking: bool = True
-
-
@attr.s(auto_attribs=True)
class SimpleTracker(Tracker):
"""A Tracker pre-configured to use simple, non-image-based candidates."""
@@ -1350,17 +1291,6 @@ class SimpleTracker(Tracker):
candidate_maker: object = attr.ib(factory=SimpleCandidateMaker)
-@attr.s(auto_attribs=True)
-class SimpleMaxTracker(Tracker):
- """Pre-configured tracker to use simple, non-image-based candidates with max tracks."""
-
- max_tracks: int = attr.ib(kw_only=True)
- similarity_function: Callable = instance_iou
- matching_function: Callable = hungarian_matching
- candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker)
- max_tracking: bool = True
-
-
@attr.s(auto_attribs=True)
class KalmanInitSet:
init_frame_count: int
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index fd615ea81..9f9332338 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -58,7 +58,6 @@
from sleap.nn.tracking import (
MatchedFrameInstance,
FlowCandidateMaker,
- FlowMaxTracksCandidateMaker,
Tracker,
)
from sleap.instance import Track
@@ -1273,6 +1272,7 @@ def test_make_export_cli():
assert args.max_instances == max_instances
+@pytest.mark.slow()
def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
):
@@ -1315,6 +1315,7 @@ def test_topdown_predictor_save(
)
+@pytest.mark.slow()
def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
):
@@ -1361,9 +1362,7 @@ def test_topdown_id_predictor_save(
"output_path,tracker_method",
[
("not_default", "flow"),
- ("not_default", "flowmaxtracks"),
(None, "simple"),
- (None, "simplemaxtracks"),
],
)
def test_retracking(
@@ -1380,7 +1379,6 @@ def test_retracking(
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks":
- cmd += " --tracking.max_tracking 1"
cmd += " --tracking.max_tracks 2"
if output_path == "not_default":
output_path = Path(tmpdir, "tracked_slp.slp")
@@ -1944,8 +1942,8 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
@pytest.mark.parametrize(
"max_tracks, trackername",
[
- (2, "flowmaxtracks"),
- (2, "simplemaxtracks"),
+ (2, "flow"),
+ (2, "simple"),
],
)
def test_max_tracks_matching_queue(
@@ -1953,7 +1951,6 @@ def test_max_tracks_matching_queue(
):
"""Test flow max tracks instance generation."""
labels: Labels = centered_pair_predictions
- max_tracking = True
track_window = 5
# Setup flow max tracker
@@ -1961,11 +1958,10 @@ def test_max_tracks_matching_queue(
tracker=trackername,
track_window=track_window,
save_shifted_instances=True,
- max_tracking=max_tracking,
max_tracks=max_tracks,
)
- tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker)
+ tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)
# Run tracking
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
@@ -1978,18 +1974,17 @@ def test_max_tracks_matching_queue(
track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
tracker.track(**track_args)
- if trackername == "flowmaxtracks":
+ if trackername == "flow":
# Check that saved instances are pruned to track window
for key in tracker.candidate_maker.shifted_instances.keys():
assert lf.frame_idx - key[0] <= track_window # Keys are pruned
assert abs(key[0] - key[1]) <= track_window
# Check if the length of each of the tracks is not more than the track window
- for track in tracker.track_matching_queue_dict.keys():
- assert len(tracker.track_matching_queue_dict[track]) <= track_window
+ assert len(tracker.track_matching_queue) <= track_window
# Check if number of tracks that are generated are not more than the maximum tracks
- assert len(tracker.track_matching_queue_dict) <= max_tracks
+ assert len(tracker.unique_tracks_in_queue) <= max_tracks
def test_movenet_inference(movenet_video):
@@ -2012,6 +2007,7 @@ def test_movenet_inference(movenet_video):
assert preds["instance_peaks"].shape == (1, 1, 17, 2)
+@pytest.mark.slow()
def test_movenet_predictor(min_dance_labels, movenet_video):
predictor = MoveNetPredictor.from_trained_models("thunder")
predictor.verbosity = "none"
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index ffdd35257..be5789dc9 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -32,9 +32,7 @@ def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs):
assert len(new_frames) == len(frames)
-@pytest.mark.parametrize(
- "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
@pytest.mark.parametrize("img_scale", [0, 1, 0.25])
@@ -60,9 +58,7 @@ def test_tracker_by_name(
)
-@pytest.mark.parametrize(
- "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
@pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
def test_oks_tracker_by_name(
@@ -245,7 +241,7 @@ def make_inst(x, y):
return insts
-def test_max_tracking_large_gap_single_track():
+def test_max_tracks_large_gap_single_track():
# Track 2 instances with gap > window size
preds = make_insts(
[
@@ -280,11 +276,9 @@ def test_max_tracking_large_gap_single_track():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -296,12 +290,10 @@ def test_max_tracking_large_gap_single_track():
assert len(all_tracks) == 3
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
@@ -313,7 +305,7 @@ def test_max_tracking_large_gap_single_track():
assert len(all_tracks) == 2
-def test_max_tracking_small_gap_on_both_tracks():
+def test_max_tracks_small_gap_on_both_tracks():
# Test 2 instances with both tracks with gap > window size
preds = make_insts(
[
@@ -344,11 +336,9 @@ def test_max_tracking_small_gap_on_both_tracks():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -360,12 +350,10 @@ def test_max_tracking_small_gap_on_both_tracks():
assert len(all_tracks) == 4
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
@@ -377,7 +365,7 @@ def test_max_tracking_small_gap_on_both_tracks():
assert len(all_tracks) == 2
-def test_max_tracking_extra_detections():
+def test_max_tracks_extra_detections():
# Test having more than 2 detected instances in a frame
preds = make_insts(
[
@@ -413,11 +401,9 @@ def test_max_tracking_extra_detections():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -429,12 +415,10 @@ def test_max_tracking_extra_detections():
assert len(all_tracks) == 4
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py
index 84707fd44..ab8d52b87 100644
--- a/tests/nn/test_tracking_integration.py
+++ b/tests/nn/test_tracking_integration.py
@@ -22,10 +22,10 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
assert len(labels.tracks) == 8
-def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path):
+def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path):
cli = (
- "--tracking.tracker simplemaxtracks "
- "--tracking.max_tracking 1 --tracking.max_tracks 2 "
+ "--tracking.tracker simple "
+ "--tracking.max_tracks 2 "
"--frames 200-300 "
f"-o {tmpdir}/simplemaxtracks.slp "
f"{centered_pair_predictions_slp_path}"
@@ -91,8 +91,6 @@ def main(f, dir):
trackers = dict(
simple=sleap.nn.tracker.simple.SimpleTracker,
flow=sleap.nn.tracker.flow.FlowTracker,
- simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker,
- flowmaxtracks=sleap.nn.tracker.FlowMaxTracker,
)
matchers = dict(
hungarian=sleap.nn.tracker.components.hungarian_matching,
@@ -108,21 +106,12 @@ def main(f, dir):
0.25,
)
- def make_tracker(
- tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0
- ):
- if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks":
- tracker = trackers[tracker_name](
- matching_function=matchers[matcher_name],
- similarity_function=similarities[sim_name],
- max_tracks=max_tracks,
- max_tracking=max_tracking,
- )
- else:
- tracker = trackers[tracker_name](
- matching_function=matchers[matcher_name],
- similarity_function=similarities[sim_name],
- )
+ def make_tracker(tracker_name, matcher_name, sim_name, max_tracks, scale=0):
+ tracker = trackers[tracker_name](
+ matching_function=matchers[matcher_name],
+ similarity_function=similarities[sim_name],
+ max_tracks=max_tracks,
+ )
if scale:
tracker.candidate_maker.img_scale = scale
return tracker
@@ -151,36 +140,16 @@ def make_tracker_and_filename(*args, **kwargs):
tracker, gt_filename = make_tracker_and_filename(
tracker_name=tracker_name,
matcher_name=matcher_name,
- sim_name=sim_name,
- scale=scale,
- )
- f(frames, tracker, gt_filename)
- elif tracker_name == "flowmaxtracks":
- # If this tracker supports scale, try multiple scales
- for scale in scales:
- tracker, gt_filename = make_tracker_and_filename(
- tracker_name=tracker_name,
- matcher_name=matcher_name,
- sim_name=sim_name,
max_tracks=2,
- max_tracking=True,
+ sim_name=sim_name,
scale=scale,
)
f(frames, tracker, gt_filename)
- elif tracker_name == "simplemaxtracks":
- tracker, gt_filename = make_tracker_and_filename(
- tracker_name=tracker_name,
- matcher_name=matcher_name,
- sim_name=sim_name,
- max_tracks=2,
- max_tracking=True,
- scale=0,
- )
- f(frames, tracker, gt_filename)
else:
tracker, gt_filename = make_tracker_and_filename(
tracker_name=tracker_name,
matcher_name=matcher_name,
+ max_tracks=2,
sim_name=sim_name,
scale=0,
)