diff --git a/docs/guides/cli.md b/docs/guides/cli.md
index 03b806903..c59048ecd 100644
--- a/docs/guides/cli.md
+++ b/docs/guides/cli.md
@@ -36,7 +36,7 @@ optional arguments:
```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
- [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
+ [--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
@@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
[--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
[--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
[--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
- [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
+ [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER]
[--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
[--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
@@ -187,10 +187,8 @@ optional arguments:
Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
-                        Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
+                        Options: simple, flow, None (default: None)
- --tracking.max_tracking TRACKING.MAX_TRACKING
- If true then the tracker will cap the max number of tracks. (default: False)
--tracking.max_tracks TRACKING.MAX_TRACKS
- Maximum number of tracks to be tracked by the tracker. (default: None)
+ Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
@@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction
**5. Inference with max tracks limit:**
```none
-sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
+sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
```
**6. Re-tracking without pose inference:**
```none
-sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
+sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
```
**7. Select GPU for pose inference:**
diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml
index d130b9cb9..406a02ea8 100644
--- a/sleap/config/pipeline_form.yaml
+++ b/sleap/config/pipeline_form.yaml
@@ -424,10 +424,6 @@ inference:
This tracker "shifts" instances from previous frames using optical flow
before matching instances in each frame to the shifted instances from
prior frames.'
- # - name: tracking.max_tracking
- # label: Limit max number of tracks
- # type: bool
-    #   default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
@@ -459,10 +455,12 @@ inference:
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- # - name: tracking.save_shifted_instances
- # label: Save shifted instances
- # type: bool
- # default: false
+ - name: tracking.save_shifted_instances
+ label: Save shifted instances
+ help: 'Save the flow-shifted instances between elapsed frames. It improves
+        instance matching at the cost of using a bit more memory.'
+ type: bool
+ default: true
- type: text
text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial
@@ -523,10 +521,6 @@ inference:
text: 'Tracking:
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
- # - name: tracking.max_tracking
- # label: Limit max number of tracks
- # type: bool
- # default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py
index 2c2617036..7a9c94358 100644
--- a/sleap/gui/learning/dialog.py
+++ b/sleap/gui/learning/dialog.py
@@ -733,7 +733,7 @@ def run(self):
# count < 0 means there was an error and we didn't get any results.
if new_counts is not None and new_counts >= 0:
total_count = items_for_inference.total_frame_count
- no_result_count = total_count - new_counts
+ no_result_count = max(0, total_count - new_counts)
message = (
f"Inference ran on {total_count} frames."
diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py
index d0bb1f3ba..fb2d799e0 100644
--- a/sleap/gui/learning/runners.py
+++ b/sleap/gui/learning/runners.py
@@ -204,7 +204,6 @@ def make_predict_cli_call(
# Make path where we'll save predictions (if not specified)
if output_path is None:
-
if self.labels_filename:
# Make a predictions directory next to the labels dataset file
predictions_dir = os.path.join(
@@ -239,15 +238,12 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]
- # Setting max_tracks to True means we want to use the max_tracking mode.
- if "tracking.max_tracks" in self.inference_params:
- self.inference_params["tracking.max_tracking"] = True
-
- # Hacky: Update the tracker name to include "maxtracks" suffix.
- if self.inference_params["tracking.tracker"] in ("simple", "flow"):
- self.inference_params["tracking.tracker"] = (
- self.inference_params["tracking.tracker"] + "maxtracks"
- )
+        # Backwards compatibility: strip the deprecated "maxtracks" suffix from the tracker name.
+ if "tracking.tracker" in self.inference_params:
+ compat_trackers = ("simplemaxtracks", "flowmaxtracks")
+ if self.inference_params["tracking.tracker"] in compat_trackers:
+ tname = self.inference_params["tracking.tracker"][: -len("maxtracks")]
+ self.inference_params["tracking.tracker"] = tname
# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
@@ -257,10 +253,12 @@ def make_predict_cli_call(
bool_items_as_ints = (
"tracking.pre_cull_to_target",
- "tracking.max_tracking",
+ "tracking.pre_cull_merge_instances",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
+ "tracking.prefer_reassigning_track",
+ "tracking.allow_reassigning_track",
)
for key in bool_items_as_ints:
@@ -303,10 +301,8 @@ def predict_subprocess(
# Run inference CLI capturing output.
with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc:
-
# Poll until finished.
while proc.poll() is None:
-
# Read line.
line = proc.stdout.readline()
line = line.decode().rstrip()
@@ -635,7 +631,6 @@ def run_gui_training(
for config_info in config_info_list:
if config_info.dont_retrain:
-
if not config_info.has_trained_model:
raise ValueError(
"Config is set to not retrain but no trained model found: "
@@ -849,7 +844,6 @@ def train_subprocess(
success = False
with tempfile.TemporaryDirectory() as temp_dir:
-
# Write a temporary file of the TrainingJob so that we can respect
# any changed made to the job attributes after it was loaded.
temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json"
diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py
index 949703020..4c2370e09 100644
--- a/sleap/gui/widgets/video.py
+++ b/sleap/gui/widgets/video.py
@@ -816,6 +816,8 @@ def __init__(self, state=None, player=None, *args, **kwargs):
self.click_mode = ""
self.in_zoom = False
+ self._down_pos = None
+
self.zoomFactor = 1
anchor_mode = QGraphicsView.AnchorUnderMouse
self.setTransformationAnchor(anchor_mode)
@@ -1039,7 +1041,7 @@ def mouseReleaseEvent(self, event):
scenePos = self.mapToScene(event.pos())
# check if mouse moved during click
- has_moved = event.pos() != self._down_pos
+ has_moved = self._down_pos is not None and event.pos() != self._down_pos
if event.button() == Qt.LeftButton:
if self.in_zoom:
diff --git a/sleap/instance.py b/sleap/instance.py
index 08a5c6ae6..90c7af78e 100644
--- a/sleap/instance.py
+++ b/sleap/instance.py
@@ -18,12 +18,14 @@
"""
import math
+from functools import reduce
+from itertools import chain, combinations
import numpy as np
import cattr
from copy import copy
-from typing import Dict, List, Optional, Union, Tuple, ForwardRef
+from typing import Dict, List, Optional, Union, Sequence, Tuple
from numpy.lib.recfunctions import structured_to_unstructured
@@ -1177,6 +1179,75 @@ def from_numpy(
)
+def all_disjoint(x: Sequence[Sequence]) -> bool:
+    """Return True if no two of the given sequences share any element."""
+    return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2))
+
+
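As a quick, self-contained sanity check of the disjointness test used before merging (the node names here are placeholders, not a real skeleton):

```python
from itertools import combinations
from typing import Sequence

def all_disjoint(x: Sequence[Sequence]) -> bool:
    # Mirrors the helper above: True only when no two sequences share an element.
    return all(set(p0).isdisjoint(set(p1)) for p0, p1 in combinations(x, 2))

assert all_disjoint([["head", "thorax"], ["abdomen", "tail"]])        # fully disjoint
assert not all_disjoint([["head", "thorax"], ["thorax", "abdomen"]])  # "thorax" is shared
assert all_disjoint([["head"]])                                       # single sequence, trivially disjoint
```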
+def create_merged_instances(
+ instances: List[PredictedInstance],
+ penalty: float = 0.2,
+) -> List[PredictedInstance]:
+ """Create merged instances from the list of PredictedInstance.
+
+ Only instances with non-overlapping visible nodes are merged.
+
+ Args:
+ instances: a list of original PredictedInstances to try to merge.
+ penalty: a float between 0 and 1. All scores of the merged instance
+            are multiplied by (1 - penalty).
+
+ Returns:
+        A list of new PredictedInstance objects created by merging.
+ """
+ # Ensure same skeleton
+ skeletons = {inst.skeleton for inst in instances}
+ if len(skeletons) != 1:
+ return []
+ skeleton = list(skeletons)[0]
+
+ # Ensure same track
+ tracks = {inst.track for inst in instances}
+ if len(tracks) != 1:
+ return []
+ track = list(tracks)[0]
+
+ # Ensure non-intersecting visible nodes
+ merged_instances = []
+ instance_subsets = (
+ combinations(instances, n) for n in range(2, len(instances) + 1)
+ )
+ instance_subsets = chain.from_iterable(instance_subsets)
+ for subset in instance_subsets:
+ nodes = [s.nodes for s in subset]
+ if not all_disjoint(nodes):
+ continue
+
+ nodes_points_gen = chain.from_iterable(
+ instance.nodes_points for instance in subset
+ )
+ predicted_points = {node: point for node, point in nodes_points_gen}
+
+ instance_score = reduce(lambda x, y: x * y, [s.score for s in subset])
+
+ # Penalize scores of merged instances
+ if 0 < penalty <= 1:
+ factor = 1 - penalty
+ instance_score *= factor
+ for point in predicted_points.values():
+ point.score *= factor
+
+ merged_instance = PredictedInstance(
+ points=predicted_points,
+ skeleton=skeleton,
+ score=instance_score,
+ track=track,
+ )
+
+ merged_instances.append(merged_instance)
+
+ return merged_instances
+
+
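A minimal sketch of the score arithmetic applied when instances are merged; the scores and penalty below are made up, but the formula matches the code above:

```python
from functools import reduce

subset_scores = [0.9, 0.8]   # hypothetical grouping scores of two disjoint instances
penalty = 0.2                # default merging penalty

instance_score = reduce(lambda x, y: x * y, subset_scores)  # product of scores -> 0.72
instance_score *= 1 - penalty                               # apply (1 - penalty) -> 0.576
print(round(instance_score, 3))                             # 0.576; point scores are scaled the same way
```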
def make_instance_cattr() -> cattr.Converter:
"""Create a cattr converter for Lists of Instances/PredictedInstances.
diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index 421378d56..e5cb1dc49 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -46,6 +46,11 @@
from threading import Thread
from queue import Queue
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:  # functools.cached_property is only available in Python >= 3.8
+ cached_property = property
+
import tensorflow as tf
import numpy as np
@@ -54,7 +59,7 @@
from sleap.nn.config import TrainingJobConfig, DataConfig
from sleap.nn.data.resizing import SizeMatcher
from sleap.nn.model import Model
-from sleap.nn.tracking import Tracker, run_tracker
+from sleap.nn.tracking import Tracker
from sleap.nn.paf_grouping import PAFScorer
from sleap.nn.data.pipelines import (
Provider,
@@ -69,7 +74,7 @@
)
from sleap.nn.utils import reset_input_layer
from sleap.io.dataset import Labels
-from sleap.util import frame_list, make_scoped_dictionary
+from sleap.util import frame_list, make_scoped_dictionary, RateColumn
from sleap.instance import PredictedInstance, LabeledFrame
from tensorflow.python.framework.convert_to_constants import (
@@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str:
return os.path.join(path, "best_model.h5")
-class RateColumn(rich.progress.ProgressColumn):
- """Renders the progress rate."""
-
- def render(self, task: "Task") -> rich.progress.Text:
- """Show progress rate."""
- speed = task.speed
- if speed is None:
- return rich.progress.Text("?", style="progress.data.speed")
- return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed")
-
-
@attr.s(auto_attribs=True)
class Predictor(ABC):
"""Base interface class for predictors."""
@@ -167,7 +161,7 @@ class Predictor(ABC):
report_rate: float = attr.ib(default=2.0, kw_only=True)
model_paths: List[str] = attr.ib(factory=list, kw_only=True)
- @property
+ @cached_property
def report_period(self) -> float:
"""Time between progress reports in seconds."""
return 1.0 / self.report_rate
@@ -374,6 +368,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
def _initialize_inference_model(self):
pass
+ def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """Run prediction model on batch.
+
+ This method handles running inference on a batch and postprocessing.
+
+ Args:
+ ex: a dictionary holding the input for inference.
+
+ Returns:
+ The input dictionary updated with the predictions.
+ """
+ # Skip inference if model is not loaded
+ if self.inference_model is None:
+ return ex
+
+ # Run inference on current batch.
+ preds = self.inference_model.predict_on_batch(ex, numpy=True)
+
+ # Add model outputs to the input data example.
+ ex.update(preds)
+
+ # Convert to numpy arrays if not already.
+ if isinstance(ex["video_ind"], tf.Tensor):
+ ex["video_ind"] = ex["video_ind"].numpy().flatten()
+ if isinstance(ex["frame_ind"], tf.Tensor):
+ ex["frame_ind"] = ex["frame_ind"].numpy().flatten()
+
+ # Adjust for potential SizeMatcher scaling.
+ offset_x = ex.get("offset_x", 0)
+ offset_y = ex.get("offset_y", 0)
+ ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2])
+ ex["instance_peaks"] /= np.expand_dims(
+ np.expand_dims(ex["scale"], axis=1), axis=1
+ )
+
+ return ex
+
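To make the peak adjustment above concrete, here is a standalone numeric sketch; the offset and scale values are made up, standing in for what the `SizeMatcher` preprocessing would report:

```python
import numpy as np

# Peaks in the padded/rescaled image space: (batch, instances, nodes, xy).
instance_peaks = np.array([[[[200.0, 120.0]]]])
offset_x, offset_y, scale = 16.0, 0.0, 0.5  # hypothetical padding offsets and image scale

# Same arithmetic as _process_batch: undo the padding offset, then undo the scaling.
instance_peaks -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2])
instance_peaks /= np.expand_dims(np.expand_dims(np.array([scale]), axis=1), axis=1)

print(instance_peaks)  # [[[[368. 240.]]]] -> coordinates in the original image
```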
+ def _run_batch_json(
+ self,
+ examples: List[Dict[str, np.ndarray]],
+ n_total: int,
+ max_length: int = 30,
+ ) -> Iterator[Dict[str, np.ndarray]]:
+ n_processed = 0
+ n_recent = deque(maxlen=max_length)
+ elapsed_recent = deque(maxlen=max_length)
+ last_report = time()
+ t0_all = time()
+ t0_batch = time()
+ for ex in examples:
+ # Process batch of examples.
+ ex = self._process_batch(ex)
+
+ # Track timing and progress.
+ elapsed_batch = time() - t0_batch
+ t0_batch = time()
+ n_batch = len(ex["frame_ind"])
+ n_processed += n_batch
+ elapsed_all = time() - t0_all
+
+ # Compute recent rate.
+ n_recent.append(n_batch)
+ elapsed_recent.append(elapsed_batch)
+ rate = sum(n_recent) / sum(elapsed_recent)
+ eta = (n_total - n_processed) / rate
+
+ # Report.
+ if time() > last_report + self.report_period:
+ print(
+ json.dumps(
+ {
+ "n_processed": n_processed,
+ "n_total": n_total,
+ "elapsed": elapsed_all,
+ "rate": rate,
+ "eta": eta,
+ }
+ ),
+ flush=True,
+ )
+ last_report = time()
+
+ # Return results.
+ yield ex
+
+ def _run_batch_rich(
+ self,
+ examples: List[Dict[str, np.ndarray]],
+ n_total: int,
+ ) -> Iterator[Dict[str, np.ndarray]]:
+ with rich.progress.Progress(
+ "{task.description}",
+ rich.progress.BarColumn(),
+ "[progress.percentage]{task.percentage:>3.0f}%",
+ "ETA:",
+ rich.progress.TimeRemainingColumn(),
+ RateColumn(),
+ auto_refresh=False,
+ refresh_per_second=self.report_rate,
+ speed_estimate_period=5,
+ ) as progress:
+ task = progress.add_task("Predicting...", total=n_total)
+ last_report = time()
+ for ex in examples:
+ ex = self._process_batch(ex)
+
+ progress.update(task, advance=len(ex["frame_ind"]))
+
+ # Handle refreshing manually to support notebooks.
+ if time() > last_report + self.report_period:
+ progress.refresh()
+ last_report = time()
+
+ # Return results.
+ yield ex
+
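For orientation, this is roughly how the `json` progress stream emitted by `_run_batch_json` can be consumed from another process. The command mirrors the `--verbosity json` flag from the CLI usage earlier in this diff; filtering on lines starting with `{` is an assumption to skip any non-JSON log output:

```python
import json
import subprocess

cmd = ["sleap-track", "-m", "models/my_model", "--verbosity", "json", "input_video.mp4"]
with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
    for raw in proc.stdout:
        line = raw.decode().strip()
        if not line.startswith("{"):
            continue  # skip non-JSON log lines
        msg = json.loads(line)  # keys: n_processed, n_total, elapsed, rate, eta
        print(f"{msg['n_processed']}/{msg['n_total']} frames at {msg['rate']:.1f} FPS, ETA {msg['eta']:.0f} s")
```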
def _predict_generator(
self, data_provider: Provider
) -> Iterator[Dict[str, np.ndarray]]:
@@ -395,103 +505,22 @@ def _predict_generator(
if self.inference_model is None:
self._initialize_inference_model()
- def process_batch(ex):
- # Run inference on current batch.
- preds = self.inference_model.predict_on_batch(ex, numpy=True)
-
- # Add model outputs to the input data example.
- ex.update(preds)
-
- # Convert to numpy arrays if not already.
- if isinstance(ex["video_ind"], tf.Tensor):
- ex["video_ind"] = ex["video_ind"].numpy().flatten()
- if isinstance(ex["frame_ind"], tf.Tensor):
- ex["frame_ind"] = ex["frame_ind"].numpy().flatten()
-
- # Adjust for potential SizeMatcher scaling.
- offset_x = ex.get("offset_x", 0)
- offset_y = ex.get("offset_y", 0)
- ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2])
- ex["instance_peaks"] /= np.expand_dims(
- np.expand_dims(ex["scale"], axis=1), axis=1
- )
-
- return ex
+        # Build the dataset iterator before starting the clock so the ETA estimate is accurate.
+ n_total = len(data_provider)
+ examples = self.pipeline.make_dataset()
# Loop over data batches with optional progress reporting.
if self.verbosity == "rich":
- with rich.progress.Progress(
- "{task.description}",
- rich.progress.BarColumn(),
- "[progress.percentage]{task.percentage:>3.0f}%",
- "ETA:",
- rich.progress.TimeRemainingColumn(),
- RateColumn(),
- auto_refresh=False,
- refresh_per_second=self.report_rate,
- speed_estimate_period=5,
- ) as progress:
- task = progress.add_task("Predicting...", total=len(data_provider))
- last_report = time()
- for ex in self.pipeline.make_dataset():
- ex = process_batch(ex)
- progress.update(task, advance=len(ex["frame_ind"]))
-
- # Handle refreshing manually to support notebooks.
- elapsed_since_last_report = time() - last_report
- if elapsed_since_last_report > self.report_period:
- progress.refresh()
-
- # Return results.
- yield ex
+ for ex in self._run_batch_rich(examples, n_total=n_total):
+ yield ex
elif self.verbosity == "json":
- n_processed = 0
- n_total = len(data_provider)
- n_recent = deque(maxlen=30)
- elapsed_recent = deque(maxlen=30)
- last_report = time()
- t0_all = time()
- t0_batch = time()
- for ex in self.pipeline.make_dataset():
- # Process batch of examples.
- ex = process_batch(ex)
-
- # Track timing and progress.
- elapsed_batch = time() - t0_batch
- t0_batch = time()
- n_batch = len(ex["frame_ind"])
- n_processed += n_batch
- elapsed_all = time() - t0_all
-
- # Compute recent rate.
- n_recent.append(n_batch)
- elapsed_recent.append(elapsed_batch)
- rate = sum(n_recent) / sum(elapsed_recent)
- eta = (n_total - n_processed) / rate
-
- # Report.
- elapsed_since_last_report = time() - last_report
- if elapsed_since_last_report > self.report_period:
- print(
- json.dumps(
- {
- "n_processed": n_processed,
- "n_total": n_total,
- "elapsed": elapsed_all,
- "rate": rate,
- "eta": eta,
- }
- ),
- flush=True,
- )
- last_report = time()
-
- # Return results.
+ for ex in self._run_batch_json(examples, n_total=n_total):
yield ex
+
else:
- for ex in self.pipeline.make_dataset():
- yield process_batch(ex)
+ for ex in examples:
+ yield self._process_batch(ex)
def predict(
self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True
@@ -4918,16 +4947,10 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
- use_max_tracker = tracker_max_instances is not None
- if use_max_tracker and not tracker.endswith("maxtracks"):
- # Append maxtracks to the tracker name to use the right tracker variants.
- tracker += "maxtracks"
-
predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
- max_tracking=use_max_tracker,
max_tracks=tracker_max_instances,
# clean_instance_count=tracker_max_instances,
)
@@ -5388,7 +5411,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
)
)
else:
- provider_list.append(LabelsReader(labels))
+ provider_list.append(
+ LabelsReader(labels, example_indices=frame_list(args.frames))
+ )
data_path_list.append(file_path)
@@ -5482,7 +5507,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]:
"""
policy_args = make_scoped_dictionary(vars(args), exclude_nones=True)
if "tracking" in policy_args:
- tracker = Tracker.make_tracker_by_name(**policy_args["tracking"])
+ tracker = Tracker.make_tracker_by_name(
+ progress_reporting=args.verbosity,
+ **policy_args["tracking"],
+ )
return tracker
return None
@@ -5649,14 +5677,16 @@ def main(args: Optional[list] = None):
data_path = data_path_list[0]
# Load predictions
- data_path = args.data_path
print("Loading predictions...")
- labels_pr = sleap.load_file(data_path)
+ labels_pr = sleap.load_file(data_path.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
+ if provider.example_indices is not None:
+        # Convert indices to a set for O(1) membership checks; a list would be much slower.
+ index_set = set(provider.example_indices)
+ frames = list(filter(lambda lf: lf.frame_idx in index_set, frames))
print("Starting tracker...")
- frames = run_tracker(frames=frames, tracker=tracker)
- tracker.final_pass(frames)
+ frames = tracker.run_tracker(frames=frames)
labels_pr = Labels(labeled_frames=frames)
@@ -5677,7 +5707,7 @@ def main(args: Optional[list] = None):
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
- labels_pr.provenance["data_path"] = data_path
+ labels_pr.provenance["data_path"] = os.fspath(data_path)
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py
index b2f35b21f..492c01d9f 100644
--- a/sleap/nn/tracker/components.py
+++ b/sleap/nn/tracker/components.py
@@ -12,6 +12,7 @@
"""
+
import operator
from collections import defaultdict
import logging
@@ -23,6 +24,7 @@
from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils
+from sleap.instance import create_merged_instances
logger = logging.getLogger(__name__)
@@ -249,7 +251,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]:
# keep looping while some indexes still remain in the indexes list
while len(idxs) > 0:
-
# we want to add the best box which is the last box in sorted list
picked_box_idx = idxs[-1]
@@ -351,6 +352,8 @@ def cull_frame_instances(
instances_list: List[InstanceType],
instance_count: int,
iou_threshold: Optional[float] = None,
+ merge_instances: bool = False,
+ merging_penalty: float = 0.2,
) -> List["LabeledFrame"]:
"""
Removes instances (for single frame) over instance per frame threshold.
@@ -361,6 +364,9 @@ def cull_frame_instances(
iou_threshold: Intersection over Union (IOU) threshold to use when
removing overlapping instances over target count; if None, then
only use score to determine which instances to remove.
+        merge_instances: If True, allow merging instances with non-overlapping
+            visible nodes into new candidate instances before culling.
+ merging_penalty: a float between 0 and 1. All scores of the merged
+            instance are multiplied by (1 - penalty).
Returns:
Updated list of frames, also modifies frames in place.
@@ -368,6 +374,14 @@ def cull_frame_instances(
if not instances_list:
return
+ # Merge instances
+ if merge_instances:
+ logger.info("Merging instances with penalty: %f", merging_penalty)
+ merged_instances = create_merged_instances(
+ instances_list, penalty=merging_penalty
+ )
+ instances_list.extend(merged_instances)
+
if len(instances_list) > instance_count:
# List of instances which we'll pare down
keep_instances = instances_list
@@ -387,9 +401,10 @@ def cull_frame_instances(
if len(keep_instances) > instance_count:
# Sort by ascending score, get target number of instances
# from the end of list (i.e., with highest score)
- extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[
- :-instance_count
- ]
+ extra_instances = sorted(
+ keep_instances,
+ key=operator.attrgetter("score"),
+ )[:-instance_count]
# Remove the extra instances
for inst in extra_instances:
@@ -523,7 +538,6 @@ def from_candidate_instances(
candidate_tracks = []
if candidate_instances:
-
# Group candidate instances by track.
candidate_instances_by_track = defaultdict(list)
for instance in candidate_instances:
@@ -536,7 +550,6 @@ def from_candidate_instances(
matching_similarities = np.full(dims, np.nan)
for i, untracked_instance in enumerate(untracked_instances):
-
for j, candidate_track in enumerate(candidate_tracks):
# Compute similarity between untracked instance and all track
# candidates.
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index 2b02839de..b32f9b8a0 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -1,11 +1,17 @@
"""Tracking tools for linking grouped instances over time."""
-from collections import deque, defaultdict
import abc
+import json
+import operator
+import sys
+from collections import deque
+from time import time
+from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple
+
import attr
-import numpy as np
import cv2
-from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
+import numpy as np
+import rich.progress
from sleap import Track, LabeledFrame, Skeleton
@@ -24,8 +30,13 @@
Match,
)
from sleap.nn.tracker.kalman import BareKalmanTracker
-
from sleap.nn.data.normalization import ensure_int
+from sleap.util import RateColumn
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:  # functools.cached_property is only available in Python >= 3.8
+ cached_property = property
@attr.s(eq=False, slots=True, auto_attribs=True)
@@ -64,7 +75,6 @@ def from_instance(
shift_score: float = 0.0,
with_skeleton: bool = False,
):
-
points_array = new_points_array
if points_array is None:
points_array = ref_instance.points_array
@@ -105,7 +115,23 @@ class MatchedShiftedFrameInstances:
@attr.s(auto_attribs=True)
-class FlowCandidateMaker:
+class CandidateMaker(abc.ABC):
+ """Abstract base class for candidate maker."""
+
+ @abc.abstractmethod
+ def get_candidates(
+ self,
+ track_matching_queue: Deque[MatchedFrameInstances],
+ t: int,
+ img: np.ndarray,
+ *args,
+ **kwargs,
+ ) -> List[ShiftedInstance]:
+ pass
+
+
+@attr.s(auto_attribs=True)
+class FlowCandidateMaker(CandidateMaker):
"""Class for producing optical flow shift matching candidates.
Attributes:
@@ -129,7 +155,7 @@ class FlowCandidateMaker:
img_scale: float = 1.0
of_window_size: int = 21
of_max_levels: int = 3
- save_shifted_instances: bool = False
+ save_shifted_instances: bool = True
track_window: int = 5
shifted_instances: Dict[
@@ -200,37 +226,6 @@ def get_shifted_instances(
return shifted_instances
- def get_candidates(
- self,
- track_matching_queue: Deque[MatchedFrameInstances],
- t: int,
- img: np.ndarray,
- ) -> List[ShiftedInstance]:
- candidate_instances = []
-
- # Prune old shifted instances to save time and memory
- self.prune_shifted_instances(t)
-
- for matched_item in track_matching_queue:
- ref_t, ref_img, ref_instances = (
- matched_item.t,
- matched_item.img_t,
- matched_item.instances_t,
- )
-
- # Check if shifted instance was computed at earlier time
- if self.save_shifted_instances:
- ref_img, ref_instances = self.get_shifted_instances_from_earlier_time(
- ref_t, ref_img, ref_instances, t
- )
-
- if len(ref_instances) > 0:
- candidate_instances.extend(
- self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t)
- )
-
- return candidate_instances
-
def prune_shifted_instances(self, t: int):
"""Prune the shifted instances older than `self.track_window`.
@@ -354,45 +349,9 @@ def flow_shift_instances(
return shifted_instances
-
-@attr.s(auto_attribs=True)
-class FlowMaxTracksCandidateMaker(FlowCandidateMaker):
- """Class for producing optical flow shift matching candidates with maximum tracks.
-
- Attributes:
- max_tracks: The maximum number of tracks to avoid redundant tracks.
-
- """
-
- max_tracks: int = None
-
- @staticmethod
- def get_ref_instances(
- ref_t: int,
- ref_img: np.ndarray,
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
- ) -> List[InstanceType]:
- """Generates a list of instances based on the reference time and image.
-
- Args:
- ref_t: Previous frame time instance.
- ref_img: Previous frame image as a numpy array.
- track_matching_queue_dict: A dictionary of mapping between the tracks
- and the corresponding instances associated with the track.
- """
- instances = []
- for track, matched_items in track_matching_queue_dict.items():
- instances += [
- item.instance_t
- for item in matched_items
- if item.t == ref_t and np.all(item.img_t == ref_img)
- ]
- return instances
-
def get_candidates(
self,
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
- max_tracking: bool,
+ track_matching_queue: Deque[MatchedFrameInstances],
t: int,
img: np.ndarray,
*args,
@@ -402,42 +361,33 @@ def get_candidates(
# Prune old shifted instances to save time and memory
self.prune_shifted_instances(t)
- # Storing the tracks from the dictionary for counting purpose.
- tracks = []
-
- for track, matched_items in track_matching_queue_dict.items():
- if not max_tracking or len(tracks) < self.max_tracks:
- tracks.append(track)
- for matched_item in matched_items:
- ref_t, ref_img = (
- matched_item.t,
- matched_item.img_t,
- )
- ref_instances = self.get_ref_instances(
- ref_t, ref_img, track_matching_queue_dict
- )
- # Check if shifted instance was computed at earlier time
- if self.save_shifted_instances:
- (
- ref_img,
- ref_instances,
- ) = self.get_shifted_instances_from_earlier_time(
- ref_t, ref_img, ref_instances, t
- )
-
- if len(ref_instances) > 0:
- candidate_instances.extend(
- self.get_shifted_instances(
- ref_instances, ref_img, ref_t, img, t
- )
- )
+ for matched_item in track_matching_queue:
+ ref_t, ref_img, ref_instances = (
+ matched_item.t,
+ matched_item.img_t,
+ matched_item.instances_t,
+ )
+
+ # Check if shifted instance was computed at earlier time
+ if self.save_shifted_instances:
+ (
+ ref_img,
+ ref_instances,
+ ) = self.get_shifted_instances_from_earlier_time(
+ ref_t, ref_img, ref_instances, t
+ )
+
+ if len(ref_instances) > 0:
+ candidate_instances.extend(
+ self.get_shifted_instances(ref_instances, ref_img, ref_t, img, t)
+ )
return candidate_instances
@attr.s(auto_attribs=True)
-class SimpleCandidateMaker:
+class SimpleCandidateMaker(CandidateMaker):
"""Class for producing list of matching candidates from prior frames."""
min_points: int = 0
@@ -447,48 +397,25 @@ def uses_image(self):
return False
def get_candidates(
- self, track_matching_queue: Deque[MatchedFrameInstances], *args, **kwargs
+ self,
+ track_matching_queue: Deque[MatchedFrameInstances],
+ *args,
+ **kwargs,
) -> List[InstanceType]:
- # Build a pool of matchable candidate instances.
+        # Build a pool of matchable candidate instances from prior frames in the matching queue.
candidate_instances = []
for matched_item in track_matching_queue:
ref_t, ref_instances = matched_item.t, matched_item.instances_t
for ref_instance in ref_instances:
if ref_instance.n_visible_points >= self.min_points:
candidate_instances.append(ref_instance)
- return candidate_instances
-
-@attr.s(auto_attribs=True)
-class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker):
- """Class to generate instances with maximum number of tracks from prior frames."""
-
- max_tracks: int = None
-
- def get_candidates(
- self,
- track_matching_queue_dict: Dict,
- max_tracking: bool,
- *args,
- **kwargs,
- ) -> List[InstanceType]:
- # Create set of matchable candidate instances from each track.
- candidate_instances = []
- tracks = []
- for track, matched_instances in track_matching_queue_dict.items():
- if not max_tracking or len(tracks) < self.max_tracks:
- tracks.append(track)
- for ref_instance in matched_instances:
- if ref_instance.instance_t.n_visible_points >= self.min_points:
- candidate_instances.append(ref_instance.instance_t)
return candidate_instances
tracker_policies = dict(
simple=SimpleCandidateMaker,
flow=FlowCandidateMaker,
- simplemaxtracks=SimpleMaxTracksCandidateMaker,
- flowmaxtracks=FlowMaxTracksCandidateMaker,
)
similarity_policies = dict(
@@ -508,10 +435,149 @@ def get_candidates(
class BaseTracker(abc.ABC):
"""Abstract base class for tracker."""
+ verbosity: str
+ report_rate: float
+
@property
def is_valid(self):
return False
+ @cached_property
+ def report_period(self) -> float:
+ """Time between progress reports in seconds."""
+ return 1.0 / self.report_rate
+
+    def run_step(self, lf: LabeledFrame) -> LabeledFrame:
+        """Clear any existing tracks and run tracking on a single labeled frame."""
+ # Clear the tracks
+ for inst in lf.instances:
+ inst.track = None
+
+ track_args = dict(untracked_instances=lf.instances, t=lf.frame_idx)
+ if self.uses_image:
+ track_args["img"] = lf.video[lf.frame_idx]
+ else:
+ track_args["img"] = None
+
+ return LabeledFrame(
+ frame_idx=lf.frame_idx,
+ video=lf.video,
+ instances=self.track(**track_args),
+ )
+
+ def _run_tracker_json(
+ self,
+ frames: List[LabeledFrame],
+ max_length: int = 30,
+ ) -> Iterator[LabeledFrame]:
+ n_total = len(frames)
+ n_processed = 0
+ n_batch = 0
+ n_recent = deque(maxlen=max_length)
+ elapsed_recent = deque(maxlen=max_length)
+ last_report = time()
+ t0_all = time()
+ t0_batch = time()
+
+ for lf in frames:
+ new_lf = self.run_step(lf)
+
+ # Track timing and progress
+ elapsed_all = time() - t0_all
+ n_processed += 1
+ n_batch += 1
+
+ # Report
+ if time() > last_report + self.report_period:
+ elapsed_batch = time() - t0_batch
+ t0_batch = time()
+
+ # Compute recent rate
+ n_recent.append(n_batch)
+ n_batch = 0
+ elapsed_recent.append(elapsed_batch)
+ rate = sum(n_recent) / sum(elapsed_recent)
+ eta = (n_total - n_processed) / rate
+
+ print(
+ json.dumps(
+ {
+ "n_processed": n_processed,
+ "n_total": n_total,
+ "elapsed": elapsed_all,
+ "rate": rate,
+ "eta": eta,
+ }
+ ),
+ flush=True,
+ )
+ last_report = time()
+
+ yield new_lf
+
+ def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]:
+ with rich.progress.Progress(
+ "{task.description}",
+ rich.progress.BarColumn(),
+ "[progress.percentage]{task.percentage:>3.0f}%",
+ "ETA:",
+ rich.progress.TimeRemainingColumn(),
+ RateColumn(),
+ auto_refresh=False,
+ refresh_per_second=self.report_rate,
+ speed_estimate_period=5,
+ ) as progress:
+ task = progress.add_task("Tracking...", total=len(frames))
+ last_report = time()
+ for lf in frames:
+ new_lf = self.run_step(lf)
+
+ progress.update(task, advance=1)
+
+ # Handle refreshing manually to support notebooks.
+ if time() > last_report + self.report_period:
+ progress.refresh()
+ last_report = time()
+
+ yield new_lf
+
+ def run_tracker(
+ self,
+ frames: List[LabeledFrame],
+ *,
+ verbosity: Optional[str] = None,
+ final_pass: bool = True,
+ ) -> List[LabeledFrame]:
+ """Run the tracker on a set of labeled frames.
+
+        Args:
+            frames: A list of labeled frames with instances.
+            verbosity: Optional override of the progress reporting mode
+                ("none", "rich" or "json"). Defaults to the tracker's `verbosity`.
+            final_pass: If True (the default), run `final_pass` on the tracked
+                frames before returning them.
+
+        Returns:
+            A list of new labeled frames with track assignments. Any tracks already
+            present on the input instances are cleared before tracking.
+        """
+ # Return original frames if we aren't retracking
+ if not self.is_valid:
+ return frames
+
+ verbosity = verbosity or self.verbosity
+
+ # Run tracking on every frame
+ if verbosity == "rich":
+ new_lfs = list(self._run_tracker_rich(frames))
+
+ elif verbosity == "json":
+ new_lfs = list(self._run_tracker_json(frames))
+
+ else:
+ new_lfs = list(self.run_step(lf) for lf in frames)
+
+ # Run final_pass
+ if final_pass:
+ self.final_pass(new_lfs)
+
+ return new_lfs
+
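As a usage sketch, re-tracking saved predictions through the new entry point reduces to something like the snippet below (file names are placeholders; `make_tracker_by_name` arguments other than those shown keep their defaults, and `final_pass` runs inside `run_tracker` by default):

```python
import sleap
from sleap.nn.tracking import Tracker

labels = sleap.load_file("input_predictions.slp")  # existing predictions to re-track
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)

tracker = Tracker.make_tracker_by_name(
    tracker="simple",
    max_tracks=4,                # replaces the old max_tracking / "maxtracks" variants
    progress_reporting="none",
)
frames = tracker.run_tracker(frames=frames)  # clears any existing tracks, then tracks
sleap.Labels(labeled_frames=frames).save("retracked.slp")
```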
@abc.abstractmethod
def track(
self,
@@ -560,15 +626,20 @@ class Tracker(BaseTracker):
use a robust quantile similarity score for the track. If the value is 1,
use the max similarity (non-robust). For selecting a robust score,
0.95 is a good value.
- max_tracking: Max tracking is incorporated when this is set to true.
+ max_tracks: Maximum number of tracks. No limit if set to -1.
+        verbosity: Mode of tracking progress reporting. If `"rich"`, an updating
+            progress bar is displayed in the console or notebook. If `"json"`,
+            a JSON-serialized message is printed out which can be captured for
+            programmatic progress monitoring. If `"none"` (the default), nothing is
+            displayed during tracking -- this is recommended when running on clusters
+            or headless machines where the output is captured to a log file.
"""
- max_tracks: int = None
track_window: int = 5
similarity_function: Optional[Callable] = instance_similarity
matching_function: Callable = greedy_matching
- candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
- max_tracking: bool = False # To enable maximum tracking.
+ candidate_maker: CandidateMaker = attr.ib(factory=FlowCandidateMaker)
+ max_tracks: int = -1
cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
@@ -577,15 +648,23 @@ class Tracker(BaseTracker):
robust_best_instance: float = 1.0
min_new_track_points: int = 0
+ prefer_reassigning_track: bool = False
+ allow_reassigning_track: bool = False
+
+ verbosity: str = attr.ib(
+ validator=attr.validators.in_(["none", "rich", "json"]),
+ default="none",
+ )
+ report_rate: float = 2.0
+ #: Hold frames with matched instances as deque of length `track_window`.
track_matching_queue: Deque[MatchedFrameInstances] = attr.ib()
- # Hold track, instances with instances as a deque with length as track_window.
- track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]] = attr.ib(
- factory=dict
- )
spawned_tracks: List[Track] = attr.ib(factory=list)
+ #: Found tracks with last time an instance was found
+ found_tracks: Dict[Track, InstanceType] = attr.ib(factory=dict)
+
save_tracked_instances: bool = False
tracked_instances: Dict[int, List[InstanceType]] = attr.ib(
factory=dict
@@ -593,49 +672,50 @@ class Tracker(BaseTracker):
last_matches: Optional[FrameMatches] = None
+
@property
def is_valid(self):
return self.similarity_function is not None
@track_matching_queue.default
- def _init_matching_queue(self):
+ def _init_track_matching_queue(self):
"""Factory for instantiating default matching queue with specified size."""
return deque(maxlen=self.track_window)
- @property
- def has_max_tracking(self) -> bool:
- return isinstance(
- self.candidate_maker,
- (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker),
- )
-
def reset_candidates(self):
- if self.has_max_tracking:
- for track in self.track_matching_queue_dict:
- self.track_matching_queue_dict[track] = deque(maxlen=self.track_window)
- else:
- self.track_matching_queue = deque(maxlen=self.track_window)
+ self.track_matching_queue = deque(maxlen=self.track_window)
@property
def unique_tracks_in_queue(self) -> List[Track]:
"""Returns the unique tracks in the matching queue."""
-
- unique_tracks = set()
- if self.has_max_tracking:
- for track in self.track_matching_queue_dict.keys():
- unique_tracks.add(track)
-
- else:
- for match_item in self.track_matching_queue:
- for instance in match_item.instances_t:
- unique_tracks.add(instance.track)
-
- return list(unique_tracks)
+        return list(
+            {
+                instance.track
+                for item in self.track_matching_queue
+                for instance in item.instances_t
+            }
+        )
@property
def uses_image(self):
return getattr(self.candidate_maker, "uses_image", False)
+ def infer_next_timestep(self, t: Optional[int] = None) -> int:
+ """Infer timestep if not provided."""
+ # Timestep was provided
+ if t is not None:
+ return t
+
+ # Default to last timestep + 1 if available.
+ if len(self.track_matching_queue) > 0:
+ return self.track_matching_queue[-1].t + 1
+
+ # Default to 0
+ return 0
+
def track(
self,
untracked_instances: List[InstanceType],
@@ -657,58 +737,22 @@ def track(
return untracked_instances
# Infer timestep if not provided.
- if t is None:
- if self.has_max_tracking:
- if len(self.track_matching_queue_dict) > 0:
-
- # Default to last timestep + 1 if available.
- # Here we find the track that has the most instances.
- track_with_max_instances = max(
- self.track_matching_queue_dict,
- key=lambda track: len(self.track_matching_queue_dict[track]),
- )
- t = (
- self.track_matching_queue_dict[track_with_max_instances][-1].t
- + 1
- )
-
- else:
- t = 0
- else:
- if len(self.track_matching_queue) > 0:
-
- # Default to last timestep + 1 if available.
- t = self.track_matching_queue[-1].t + 1
-
- else:
- t = 0
+ t = self.infer_next_timestep(t)
# Initialize containers for tracked instances at the current timestep.
tracked_instances = []
- # Make cache so similarity function doesn't have to recompute everything.
- # similarity_cache = dict()
-
# Process untracked instances.
if untracked_instances:
-
if self.pre_cull_function:
self.pre_cull_function(untracked_instances)
# Build a pool of matchable candidate instances.
- if self.has_max_tracking:
- candidate_instances = self.candidate_maker.get_candidates(
- track_matching_queue_dict=self.track_matching_queue_dict,
- max_tracking=self.max_tracking,
- t=t,
- img=img,
- )
- else:
- candidate_instances = self.candidate_maker.get_candidates(
- track_matching_queue=self.track_matching_queue,
- t=t,
- img=img,
- )
+ candidate_instances = self.candidate_maker.get_candidates(
+ track_matching_queue=self.track_matching_queue,
+ t=t,
+ img=img,
+ )
# Determine matches for untracked instances in current frame.
frame_matches = FrameMatches.from_candidate_instances(
@@ -724,37 +768,18 @@ def track(
# Set track for each of the matched instances.
tracked_instances.extend(
- self.update_matched_instance_tracks(frame_matches.matches)
+ self.update_matched_instance_tracks(frame_matches.matches, t)
)
- # Spawn a new track for each remaining untracked instance.
+            # Assign unmatched instances to new tracks or to existing unused tracks
tracked_instances.extend(
self.spawn_for_untracked_instances(frame_matches.unmatched_instances, t)
)
- # Add the tracked instances to the dictionary of matched instances.
- if self.has_max_tracking:
- for tracked_instance in tracked_instances:
- if tracked_instance.track in self.track_matching_queue_dict:
- self.track_matching_queue_dict[tracked_instance.track].append(
- MatchedFrameInstance(t, tracked_instance, img)
- )
- elif (
- not self.max_tracking
- or len(self.track_matching_queue_dict) < self.max_tracks
- ):
- self.track_matching_queue_dict[tracked_instance.track] = deque(
- maxlen=self.track_window
- )
- self.track_matching_queue_dict[tracked_instance.track].append(
- MatchedFrameInstance(t, tracked_instance, img)
- )
-
- else:
- # Add the tracked instances to the matching buffer.
- self.track_matching_queue.append(
- MatchedFrameInstances(t, tracked_instances, img)
- )
+ # Add the tracked instances to the matching buffer.
+ self.track_matching_queue.append(
+ MatchedFrameInstances(t, tracked_instances, img)
+ )
# Save tracked instances internally.
if self.save_tracked_instances:
@@ -762,8 +787,11 @@ def track(
return tracked_instances
- @staticmethod
- def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]:
+ def update_matched_instance_tracks(
+ self,
+ matches: List[Match],
+ t: int,
+ ) -> List[InstanceType]:
inst_list = []
for match in matches:
# Assign to track and save.
@@ -774,32 +802,108 @@ def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]:
tracking_score=match.score,
)
)
+            # Record the latest timestep at which this track was matched
+ self.found_tracks[match.track] = t
return inst_list
+ def spawn_new_track(self, inst: InstanceType, t: int) -> Optional[InstanceType]:
+ """Try spawning a new track for instance."""
+ if self.max_tracks >= 0 and len(self.found_tracks) >= self.max_tracks:
+ return None
+
+ # Spawn new track.
+ new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}")
+ self.spawned_tracks.append(new_track)
+
+        # Record the timestep at which the new track was spawned
+ self.found_tracks[new_track] = t
+
+ # Assign instance to the new track and save.
+ return attr.evolve(inst, track=new_track)
+
+ def assign_to_track(
+ self,
+ inst: InstanceType,
+ t: int,
+ not_assigned_tracks: List[Track],
+ ) -> Optional[InstanceType]:
+        """Try assigning the instance to one of the tracks in a candidate list.
+
+        The candidate list is ordered with the best (most recently used) track last;
+        the chosen track is popped from `not_assigned_tracks` in place.
+        """
+ if len(not_assigned_tracks) == 0:
+ return None
+
+ existing_track = not_assigned_tracks.pop()
+
+        # Record the timestep at which this track was reassigned
+ self.found_tracks[existing_track] = t
+
+ # Assign instance to the existing track and save.
+ return attr.evolve(inst, track=existing_track, tracking_score=0)
+
def spawn_for_untracked_instances(
- self, unmatched_instances: List[InstanceType], t: int
+ self,
+ unmatched_instances: List[InstanceType],
+ t: int,
) -> List[InstanceType]:
- results = []
- for inst in unmatched_instances:
+ """Assign an existing track or spawn a new track for untracked instances."""
+ # Early return
+ if len(unmatched_instances) == 0:
+ return []
+
+        # Collect tracks that have not been assigned an instance within the last `track_window` frames
+ not_assigned_tracks_dict = {
+ track: last_t
+ for track, last_t in self.found_tracks.items()
+ if last_t < t - self.track_window
+ }
+        # Sort by last-used time (most recent last) so pop() returns the most recently freed track
+ not_assigned_tracks = sorted(
+ not_assigned_tracks_dict,
+ key=self.found_tracks.get,
+ )
- # Skip if this instance is too small to spawn a new track with.
- if inst.n_visible_points < self.min_new_track_points:
- continue
+        # Tracks left to assign (all unmatched instances are considered if max_tracks is negative)
+ n_remaining_tracks = (
+ max(0, self.max_tracks - len(not_assigned_tracks))
+ if self.max_tracks >= 0
+ else len(unmatched_instances)
+ )
- # Skip if we've reached the maximum number of tracks.
- if (
- self.has_max_tracking
- and self.max_tracking
- and len(self.track_matching_queue_dict) >= self.max_tracks
- ):
- break
+ # Sort instances by descending instance-level grouping score
+ sorted_instances = sorted(
+ unmatched_instances,
+ key=operator.attrgetter("score"),
+ reverse=True,
+ )[:n_remaining_tracks]
- # Spawn new track.
- new_track = Track(spawned_on=t, name=f"track_{len(self.spawned_tracks)}")
- self.spawned_tracks.append(new_track)
+ results = []
+ for inst in sorted_instances:
+ # Skip if not enough visible nodes to assign to a track or spawn a new track
+ if inst.n_visible_points < self.min_new_track_points:
+ continue
- # Assign instance to the new track and save.
- results.append(attr.evolve(inst, track=new_track))
+            # Try spawning a new track first, unless reassigning an existing track is preferred
+ if not self.prefer_reassigning_track:
+ matched_inst = self.spawn_new_track(inst, t)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
+
+ # Try assigning to an existing track
+ if self.allow_reassigning_track:
+ matched_inst = self.assign_to_track(inst, t, not_assigned_tracks)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
+
+            # Fall back to spawning a new track when reassignment was preferred but no track was free
+ if self.prefer_reassigning_track:
+ matched_inst = self.spawn_new_track(inst, t)
+ if matched_inst is not None:
+ results.append(matched_inst)
+ continue
return results
@@ -843,6 +947,8 @@ def make_tracker_by_name(
target_instance_count: int = 0,
pre_cull_to_target: bool = False,
pre_cull_iou_threshold: Optional[float] = None,
+ pre_cull_merge_instances: bool = False,
+ pre_cull_merging_penalty: float = 0.2,
# Post-tracking options to connect broken tracks
post_connect_single_breaks: bool = False,
# TODO: deprecate these post-tracking cleaning options
@@ -853,18 +959,18 @@ def make_tracker_by_name(
kf_node_indices: Optional[list] = None,
# Max tracking options
max_tracks: Optional[int] = None,
- max_tracking: bool = False,
+ prefer_reassigning_track: bool = False,
+ allow_reassigning_track: bool = False,
# Object keypoint similarity options
oks_errors: Optional[list] = None,
oks_score_weighting: bool = False,
oks_normalization: str = "all",
+ progress_reporting: str = "rich",
+ report_rate: float = 2.0,
**kwargs,
) -> BaseTracker:
- # Parse max_tracking arguments, only True if max_tracks is not None and > 0
- max_tracking = max_tracking if max_tracks else False
- if max_tracking and tracker in ("simple", "flow"):
- # Force a candidate maker of 'maxtracks' type
- tracker += "maxtracks"
+        # Normalize max_tracks: use -1 (no limit) when None or negative
+ max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1
if tracker.lower() == "none":
candidate_maker = None
@@ -900,9 +1006,6 @@ def make_tracker_by_name(
candidate_maker.save_shifted_instances = save_shifted_instances
candidate_maker.track_window = track_window
- if tracker == "simplemaxtracks" or tracker == "flowmaxtracks":
- candidate_maker.max_tracks = max_tracks
-
cleaner = None
if clean_instance_count:
cleaner = TrackCleaner(
@@ -910,28 +1013,33 @@ def make_tracker_by_name(
)
pre_cull_function = None
- if target_instance_count and pre_cull_to_target:
+ if (target_instance_count and pre_cull_to_target) or pre_cull_merge_instances:
def pre_cull_function(inst_list):
cull_frame_instances(
inst_list,
instance_count=target_instance_count,
iou_threshold=pre_cull_iou_threshold,
+ merge_instances=pre_cull_merge_instances,
+ merging_penalty=pre_cull_merging_penalty,
)
tracker_obj = cls(
track_window=track_window,
robust_best_instance=robust,
min_new_track_points=min_new_track_points,
+ max_tracks=max_tracks,
similarity_function=similarity_function,
matching_function=matching_function,
candidate_maker=candidate_maker,
cleaner=cleaner,
pre_cull_function=pre_cull_function,
- max_tracking=max_tracking,
- max_tracks=max_tracks,
target_instance_count=target_instance_count,
+ allow_reassigning_track=allow_reassigning_track,
+ prefer_reassigning_track=prefer_reassigning_track,
post_connect_single_breaks=post_connect_single_breaks,
+ verbosity=progress_reporting,
+ report_rate=report_rate,
)
if target_instance_count and kf_init_frame_count:
@@ -951,7 +1059,6 @@ def pre_cull_function(inst_list):
@classmethod
def get_by_name_factory_options(cls):
-
options = []
option = dict(name="tracker", default="None")
@@ -961,17 +1068,12 @@ def get_by_name_factory_options(cls):
]
options.append(option)
- option = dict(name="max_tracking", default=False)
- option["type"] = bool
- option["help"] = (
- "If true then the tracker will cap the max number of tracks. "
- "Falls back to false if `max_tracks` is not defined or 0."
- )
- options.append(option)
-
option = dict(name="max_tracks", default=None)
option["type"] = int
- option["help"] = "Maximum number of tracks to be tracked by the tracker."
+ option["help"] = (
+ "Maximum number of tracks to be tracked by the tracker. "
+ "No maximum if set to -1."
+ )
options.append(option)
option = dict(name="target_instance_count", default=0)
@@ -996,6 +1098,22 @@ def get_by_name_factory_options(cls):
)
options.append(option)
+ option = dict(name="pre_cull_merge_instances", default=False)
+ option["type"] = bool
+ option["help"] = (
+ "If True, allow merging instances with non-overlapping visible nodes "
+ "to create new instances *before* tracking."
+ )
+ options.append(option)
+
+ option = dict(name="pre_cull_merging_penalty", default=0.2)
+ option["type"] = float
+ option["help"] = (
+ "A float between 0 and 1. All scores of the merged instances "
+            "are multiplied by (1 - penalty)."
+ )
+ options.append(option)
+
option = dict(name="post_connect_single_breaks", default=0)
option["type"] = int
option["help"] = (
@@ -1043,6 +1161,21 @@ def get_by_name_factory_options(cls):
option["help"] = "Minimum number of instance points for spawning new track"
options.append(option)
+ option = dict(name="allow_reassigning_track", default=False)
+ option["type"] = bool
+        option["help"] = (
+            "Allow assigning an existing but unused track to unmatched instances."
+        )
+ options.append(option)
+
+ option = dict(name="prefer_reassigning_track", default=False)
+ option["type"] = bool
+ option["help"] = (
+            "For unmatched instances, first try to reassign an existing unused "
+            "track before spawning a new one."
+ )
+ options.append(option)
+
option = dict(name="min_match_points", default=0)
option["type"] = int
option["help"] = "Minimum points for match candidates"
@@ -1066,11 +1199,12 @@ def get_by_name_factory_options(cls):
option["help"] = "For optical-flow: Number of pyramid scale levels to consider"
options.append(option)
- option = dict(name="save_shifted_instances", default=0)
+ option = dict(name="save_shifted_instances", default=1)
option["type"] = int
option["help"] = (
"If non-zero and tracking.tracker is set to flow, save the shifted "
- "instances between elapsed frames"
+        "instances between elapsed frames. It uses a bit more memory but gives "
+ "better instance matches."
)
options.append(option)
@@ -1084,9 +1218,10 @@ def int_list_func(s):
option = dict(name="kf_init_frame_count", default="0")
option["type"] = int
- option[
- "help"
- ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used."
+ option["help"] = (
+ "For Kalman filter: Number of frames to track with other tracker. "
+ "0 means no Kalman filters will be used."
+ )
options.append(option)
def float_list_func(s):
@@ -1107,9 +1242,10 @@ def float_list_func(s):
option = dict(name="oks_score_weighting", default="0")
option["type"] = int
option["help"] = (
- "For Object Keypoint similarity: if 0 (default), only the distance between the reference "
- "and query keypoint is used to compute the similarity. If 1, each distance is weighted "
- "by the prediction scores of the reference and query keypoint."
+ "For Object Keypoint similarity: if 0 (default), only the distance "
+ "between the reference and query keypoint is used to compute the "
+ "similarity. If 1, each distance is weighted by the prediction scores "
+ "of the reference and query keypoint."
)
options.append(option)
@@ -1117,10 +1253,10 @@ def float_list_func(s):
option["type"] = str
option["options"] = ["all", "ref", "union"]
option["help"] = (
- "For Object Keypoint similarity: Determine how to normalize similarity score. "
+ "Object Keypoint similarity: Determine how to normalize similarity score. "
"If 'all', similarity score is normalized by number of reference points. "
- "If 'ref', similarity score is normalized by number of visible reference points. "
- "If 'union', similarity score is normalized by number of points both visible "
+ "If 'ref', score is normalized by number of visible reference points. "
+ "If 'union', score is normalized by number of points both visible "
"in query and reference instance."
)
options.append(option)
@@ -1140,11 +1276,21 @@ def add_cli_parser_args(cls, parser, arg_scope: str = ""):
else:
arg_name = arg["name"]
- parser.add_argument(
- f"--{arg_name}",
- type=arg["type"],
- help=help_string,
- )
+ if arg["name"] == "tracker":
+            # Leave "tracking.tracker" without a default; otherwise a
+            # malformed command line cannot be detected.
+ parser.add_argument(
+ f"--{arg_name}",
+ type=arg["type"],
+ help=help_string,
+ )
+ else:
+ parser.add_argument(
+ f"--{arg_name}",
+ type=arg["type"],
+ help=help_string,
+ default=arg["default"],
+ )
@attr.s(auto_attribs=True)
@@ -1156,19 +1302,6 @@ class FlowTracker(Tracker):
candidate_maker: object = attr.ib(factory=FlowCandidateMaker)
-attr.s(auto_attribs=True)
-
-
-class FlowMaxTracker(Tracker):
- """Pre-configured tracker to use optical flow shifted candidates with max tracks."""
-
- max_tracks: int = attr.ib(kw_only=True)
- similarity_function: Callable = instance_similarity
- matching_function: Callable = greedy_matching
- candidate_maker: object = attr.ib(factory=FlowMaxTracksCandidateMaker)
- max_tracking: bool = True
-
-
@attr.s(auto_attribs=True)
class SimpleTracker(Tracker):
"""A Tracker pre-configured to use simple, non-image-based candidates."""
@@ -1178,17 +1311,6 @@ class SimpleTracker(Tracker):
candidate_maker: object = attr.ib(factory=SimpleCandidateMaker)
-@attr.s(auto_attribs=True)
-class SimpleMaxTracker(Tracker):
- """Pre-configured tracker to use simple, non-image-based candidates with max tracks."""
-
- max_tracks: int = attr.ib(kw_only=True)
- similarity_function: Callable = instance_iou
- matching_function: Callable = hungarian_matching
- candidate_maker: object = attr.ib(factory=SimpleMaxTracksCandidateMaker)
- max_tracking: bool = True
-
-
@attr.s(auto_attribs=True)
class KalmanInitSet:
init_frame_count: int
@@ -1220,7 +1342,6 @@ def add_frame_instances(
# "usuable" instances—i.e., instances with the nodes that we'll track
# using Kalman filters.
elif frame_match.has_only_first_choice_matches:
-
good_instances = [
inst for inst in instances if self.is_usable_instance(inst)
]
@@ -1311,6 +1432,12 @@ class KalmanTracker(BaseTracker):
last_t: int = 0
last_init_t: int = 0
+ verbosity: str = attr.ib(
+ validator=attr.validators.in_(["none", "rich", "json"]),
+ default="none",
+ )
+ report_rate: float = 2.0
+
@property
def is_valid(self):
"""Do we have everything we need to run tracking?"""
@@ -1434,7 +1561,6 @@ def track(
# Check whether we've been getting good results from the Kalman filters.
# First, has it been a while since the filters were initialized?
if self.init_done and (t - self.last_init_t) > self.re_init_cooldown:
-
# If it's been a while, then see if it's also been a while since
# the filters successfully matched tracks to the instances.
if self.kalman_tracker.last_frame_with_tracks < t - self.re_init_after:
@@ -1491,46 +1617,6 @@ def run(self, frames: List[LabeledFrame]):
connect_single_track_breaks(frames, self.instance_count)
-def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[LabeledFrame]:
- """Run a tracker on a set of labeled frames.
-
- Args:
- frames: A list of labeled frames with instances.
- tracker: An initialized Tracker.
-
- Returns:
- The input frames with the new tracks assigned. If the frames already had tracks,
- they will be cleared if the tracker has been re-initialized.
- """
- # Return original frames if we aren't retracking
- if not tracker.is_valid:
- return frames
-
- new_lfs = []
-
- # Run tracking on every frame
- for lf in frames:
-
- # Clear the tracks
- for inst in lf.instances:
- inst.track = None
-
- track_args = dict(untracked_instances=lf.instances)
- if tracker.uses_image:
- track_args["img"] = lf.video[lf.frame_idx]
- else:
- track_args["img"] = None
-
- new_lf = LabeledFrame(
- frame_idx=lf.frame_idx,
- video=lf.video,
- instances=tracker.track(**track_args),
- )
- new_lfs.append(new_lf)
-
- return new_lfs
-
-
def retrack():
import argparse
import operator
@@ -1568,8 +1654,7 @@ def retrack():
print(f"Done loading predictions in {time.time() - t0} seconds.")
print("Starting tracker...")
- frames = run_tracker(frames=frames, tracker=tracker)
- tracker.final_pass(frames)
+ frames = tracker.run_tracker(frames=frames)
new_labels = Labels(labeled_frames=frames)
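
With the module-level `run_tracker` helper removed, `retrack` now delegates to the tracker's own `run_tracker` method. A hedged sketch of the equivalent programmatic re-tracking flow, assuming the `run_tracker(frames=...)` signature and the `make_tracker_by_name` options shown in this diff:

```python
import sleap
from sleap import Labels
from sleap.nn.tracking import Tracker

# Sketch of re-tracking existing predictions programmatically; file names are
# placeholders and the option values are illustrative.
labels = sleap.load_file("input_predictions.slp")
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)

tracker = Tracker.make_tracker_by_name(
    tracker="simple",
    match="hungarian",
    track_window=5,
    max_tracks=2,  # no limit if None or -1
)

tracked_frames = tracker.run_tracker(frames=frames)
Labels.save_file(Labels(labeled_frames=tracked_frames), "retracked.slp")
```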
diff --git a/sleap/util.py b/sleap/util.py
index eef762ff4..1e59ea237 100644
--- a/sleap/util.py
+++ b/sleap/util.py
@@ -1,8 +1,10 @@
-"""A miscellaneous set of utility functions.
+"""A miscellaneous set of utility functions.
Try not to put things in here unless they really have no other place.
"""
+from __future__ import annotations
+
import base64
import json
import os
@@ -11,25 +13,40 @@
from collections import defaultdict
from io import BytesIO
from pathlib import Path
-from typing import Any, Dict, Hashable, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname
+try:
+ from importlib.resources import files # New in 3.9+
+except ImportError:
+ from importlib_resources import files # TODO(LM): Upgrade to importlib.resources.
+
import attr
import h5py as h5
import numpy as np
import psutil
import rapidjson
+import rich.progress
import yaml
-
-try:
- from importlib.resources import files # New in 3.9+
-except ImportError:
- from importlib_resources import files # TODO(LM): Upgrade to importlib.resources.
from PIL import Image
import sleap.version as sleap_version
+if TYPE_CHECKING:
+ from rich.progress import Task
+
+
+class RateColumn(rich.progress.ProgressColumn):
+ """Renders the progress rate."""
+
+ def render(self, task: Task) -> rich.progress.Text:
+ """Show progress rate."""
+ speed = task.speed
+ if speed is None:
+ return rich.progress.Text("?", style="progress.data.speed")
+ return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed")
+
def json_loads(json_str: str) -> Dict:
"""A simple wrapper around the JSON decoder we are using.
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index fd615ea81..a3852ec20 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -58,7 +58,6 @@
from sleap.nn.tracking import (
MatchedFrameInstance,
FlowCandidateMaker,
- FlowMaxTracksCandidateMaker,
Tracker,
)
from sleap.instance import Track
@@ -1273,6 +1272,7 @@ def test_make_export_cli():
assert args.max_instances == max_instances
+@pytest.mark.slow()
def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
):
@@ -1315,6 +1315,7 @@ def test_topdown_predictor_save(
)
+@pytest.mark.slow()
def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
):
@@ -1361,9 +1362,7 @@ def test_topdown_id_predictor_save(
"output_path,tracker_method",
[
("not_default", "flow"),
- ("not_default", "flowmaxtracks"),
(None, "simple"),
- (None, "simplemaxtracks"),
],
)
def test_retracking(
@@ -1380,7 +1379,6 @@ def test_retracking(
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks":
- cmd += " --tracking.max_tracking 1"
cmd += " --tracking.max_tracks 2"
if output_path == "not_default":
output_path = Path(tmpdir, "tracked_slp.slp")
@@ -1944,8 +1942,8 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
@pytest.mark.parametrize(
"max_tracks, trackername",
[
- (2, "flowmaxtracks"),
- (2, "simplemaxtracks"),
+ (2, "flow"),
+ (2, "simple"),
],
)
def test_max_tracks_matching_queue(
@@ -1953,7 +1951,6 @@ def test_max_tracks_matching_queue(
):
"""Test flow max tracks instance generation."""
labels: Labels = centered_pair_predictions
- max_tracking = True
track_window = 5
# Setup flow max tracker
@@ -1961,11 +1958,10 @@ def test_max_tracks_matching_queue(
tracker=trackername,
track_window=track_window,
save_shifted_instances=True,
- max_tracking=max_tracking,
max_tracks=max_tracks,
)
- tracker.candidate_maker = cast(FlowMaxTracksCandidateMaker, tracker.candidate_maker)
+ tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)
# Run tracking
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
@@ -1978,18 +1974,17 @@ def test_max_tracks_matching_queue(
track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
tracker.track(**track_args)
- if trackername == "flowmaxtracks":
+ if trackername == "flow":
# Check that saved instances are pruned to track window
- for key in tracker.candidate_maker.shifted_instances.keys():
+ for key in tracker.candidate_maker.shifted_instances:
assert lf.frame_idx - key[0] <= track_window # Keys are pruned
assert abs(key[0] - key[1]) <= track_window
# Check if the length of each of the tracks is not more than the track window
- for track in tracker.track_matching_queue_dict.keys():
- assert len(tracker.track_matching_queue_dict[track]) <= track_window
+ assert len(tracker.track_matching_queue) <= track_window
# Check if number of tracks that are generated are not more than the maximum tracks
- assert len(tracker.track_matching_queue_dict) <= max_tracks
+ assert len(tracker.unique_tracks_in_queue) <= max_tracks
def test_movenet_inference(movenet_video):
@@ -2012,6 +2007,7 @@ def test_movenet_inference(movenet_video):
assert preds["instance_peaks"].shape == (1, 1, 17, 2)
+@pytest.mark.slow()
def test_movenet_predictor(min_dance_labels, movenet_video):
predictor = MoveNetPredictor.from_trained_models("thunder")
predictor.verbosity = "none"
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index 5786945fb..be5789dc9 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -15,53 +15,50 @@
from sleap.skeleton import Skeleton
-def tracker_by_name(frames=None, **kwargs):
- t = Tracker.make_tracker_by_name(**kwargs)
- print(kwargs)
- print(t.candidate_maker)
- if frames is None:
- t.track([])
- t.final_pass([])
- return
-
- for lf in frames:
- # Clear the tracks
- for inst in lf.instances:
- inst.track = None
-
- track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
- t.track(**track_args)
- t.final_pass(frames)
-
-
-@pytest.mark.parametrize(
- "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs):
+ # Create tracker
+ t = Tracker.make_tracker_by_name(verbosity="none", **kwargs)
+ # Update img_scale
+ if img_scale:
+ if hasattr(t, "candidate_maker") and hasattr(t.candidate_maker, "img_scale"):
+ t.candidate_maker.img_scale = img_scale
+ else:
+ # Do not even run tracking as it can be slow
+ pytest.skip("img_scale is not defined for this tracker")
+ return
+
+ # Run tracking
+ new_frames = t.run_tracker(frames or [])
+ assert len(new_frames) == len(frames)
+
+
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
+@pytest.mark.parametrize("img_scale", [0, 1, 0.25])
@pytest.mark.parametrize("count", [0, 2])
def test_tracker_by_name(
centered_pair_predictions_sorted,
tracker,
similarity,
match,
+ img_scale,
count,
):
# This is slow, so limit to 5 time points
frames = centered_pair_predictions_sorted[:5]
- tracker_by_name(
+ run_tracker_by_name(
frames=frames,
tracker=tracker,
similarity=similarity,
match=match,
+ img_scale=img_scale,
max_tracks=count,
)
-@pytest.mark.parametrize(
- "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
-)
+@pytest.mark.parametrize("tracker", ["simple", "flow"])
@pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
def test_oks_tracker_by_name(
@@ -73,7 +70,7 @@ def test_oks_tracker_by_name(
# This is slow, so limit to 5 time points
frames = centered_pair_predictions_sorted[:5]
- tracker_by_name(
+ run_tracker_by_name(
frames=frames,
tracker=tracker,
similarity="object_keypoint",
@@ -244,7 +241,7 @@ def make_inst(x, y):
return insts
-def test_max_tracking_large_gap_single_track():
+def test_max_tracks_large_gap_single_track():
# Track 2 instances with gap > window size
preds = make_insts(
[
@@ -279,11 +276,9 @@ def test_max_tracking_large_gap_single_track():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -295,12 +290,10 @@ def test_max_tracking_large_gap_single_track():
assert len(all_tracks) == 3
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
@@ -312,7 +305,7 @@ def test_max_tracking_large_gap_single_track():
assert len(all_tracks) == 2
-def test_max_tracking_small_gap_on_both_tracks():
+def test_max_tracks_small_gap_on_both_tracks():
# Test 2 instances with both tracks with gap > window size
preds = make_insts(
[
@@ -343,11 +336,9 @@ def test_max_tracking_small_gap_on_both_tracks():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -359,12 +350,10 @@ def test_max_tracking_small_gap_on_both_tracks():
assert len(all_tracks) == 4
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
@@ -376,7 +365,7 @@ def test_max_tracking_small_gap_on_both_tracks():
assert len(all_tracks) == 2
-def test_max_tracking_extra_detections():
+def test_max_tracks_extra_detections():
# Test having more than 2 detected instances in a frame
preds = make_insts(
[
@@ -412,11 +401,9 @@ def test_max_tracking_extra_detections():
tracker = Tracker.make_tracker_by_name(
tracker="simple",
- # tracker="simplemaxtracks",
match="hungarian",
track_window=2,
- # max_tracks=2,
- # max_tracking=True,
+ max_tracks=-1,
)
tracked = []
@@ -428,12 +415,10 @@ def test_max_tracking_extra_detections():
assert len(all_tracks) == 4
tracker = Tracker.make_tracker_by_name(
- # tracker="simple",
- tracker="simplemaxtracks",
+ tracker="simple",
match="hungarian",
track_window=2,
max_tracks=2,
- max_tracking=True,
)
tracked = []
diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py
index a6592dc4d..ab8d52b87 100644
--- a/tests/nn/test_tracking_integration.py
+++ b/tests/nn/test_tracking_integration.py
@@ -19,13 +19,13 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
inference_cli(cli.split(" "))
labels = sleap.load_file(f"{tmpdir}/simpletracks.slp")
- assert len(labels.tracks) == 27
+ assert len(labels.tracks) == 8
-def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path):
+def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path):
cli = (
- "--tracking.tracker simplemaxtracks "
- "--tracking.max_tracking 1 --tracking.max_tracks 2 "
+ "--tracking.tracker simple "
+ "--tracking.max_tracks 2 "
"--frames 200-300 "
f"-o {tmpdir}/simplemaxtracks.slp "
f"{centered_pair_predictions_slp_path}"
@@ -37,18 +37,19 @@ def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path):
# TODO: Refactor the below things into a real test suite.
+# Running an equivalent of `make_ground_truth` is covered by a test in tests/nn/test_tracker_components.py
def make_ground_truth(frames, tracker, gt_filename):
t0 = time.time()
- new_labels = run_tracker(frames, tracker)
+ new_labels = tracker.run_tracker(frames, verbosity="none")
print(f"{gt_filename}\t{len(tracker.spawned_tracks)}\t{time.time()-t0}")
Labels.save_file(new_labels, gt_filename)
def compare_ground_truth(frames, tracker, gt_filename):
t0 = time.time()
- new_labels = run_tracker(frames, tracker)
+ new_labels = tracker.run_tracker(frames, verbosity="none")
print(f"{gt_filename}\t{time.time() - t0}")
does_match = check_tracks(new_labels, gt_filename)
@@ -78,43 +79,6 @@ def check_tracks(labels, gt_filename, limit=None):
return True
-def run_tracker(frames, tracker):
- sig = inspect.signature(tracker.track)
- takes_img = "img" in sig.parameters
-
- # t0 = time.time()
-
- new_lfs = []
-
- # Run tracking on every frame
- for lf in frames:
-
- # Clear the tracks
- for inst in lf.instances:
- inst.track = None
-
- track_args = dict(untracked_instances=lf.instances)
- if takes_img:
- track_args["img"] = lf.video[lf.frame_idx]
- else:
- track_args["img"] = None
-
- new_lf = LabeledFrame(
- frame_idx=lf.frame_idx,
- video=lf.video,
- instances=tracker.track(**track_args),
- )
- new_lfs.append(new_lf)
-
- # if lf.frame_idx % 100 == 0: print(lf.frame_idx, time.time()-t0)
-
- # print(time.time() - t0)
-
- new_labels = Labels()
- new_labels.extend(new_lfs)
- return new_labels
-
-
def main(f, dir):
filename = "tests/data/json_format_v2/centered_pair_predictions.json"
@@ -127,8 +91,6 @@ def main(f, dir):
trackers = dict(
simple=sleap.nn.tracker.simple.SimpleTracker,
flow=sleap.nn.tracker.flow.FlowTracker,
- simplemaxtracks=sleap.nn.tracker.SimpleMaxTracker,
- flowmaxtracks=sleap.nn.tracker.FlowMaxTracker,
)
matchers = dict(
hungarian=sleap.nn.tracker.components.hungarian_matching,
@@ -144,27 +106,21 @@ def main(f, dir):
0.25,
)
- def make_tracker(
- tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0
- ):
- if tracker_name == "simplemaxtracks" or tracker_name == "flowmaxtracks":
- tracker = trackers[tracker_name](
- matching_function=matchers[matcher_name],
- similarity_function=similarities[sim_name],
- max_tracks=max_tracks,
- max_tracking=max_tracking,
- )
- else:
- tracker = trackers[tracker_name](
- matching_function=matchers[matcher_name],
- similarity_function=similarities[sim_name],
- )
+ def make_tracker(tracker_name, matcher_name, sim_name, max_tracks, scale=0):
+ tracker = trackers[tracker_name](
+ matching_function=matchers[matcher_name],
+ similarity_function=similarities[sim_name],
+ max_tracks=max_tracks,
+ )
if scale:
tracker.candidate_maker.img_scale = scale
return tracker
def make_filename(tracker_name, matcher_name, sim_name, scale=0):
- return f"{dir}{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5"
+ return os.path.join(
+ dir,
+ f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5",
+ )
def make_tracker_and_filename(*args, **kwargs):
tracker = make_tracker(*args, **kwargs)
@@ -178,43 +134,22 @@ def make_tracker_and_filename(*args, **kwargs):
for tracker_name in trackers.keys():
for matcher_name in matchers.keys():
for sim_name in similarities.keys():
-
if tracker_name == "flow":
# If this tracker supports scale, try multiple scales
for scale in scales:
tracker, gt_filename = make_tracker_and_filename(
tracker_name=tracker_name,
matcher_name=matcher_name,
- sim_name=sim_name,
- scale=scale,
- )
- f(frames, tracker, gt_filename)
- elif tracker_name == "flowmaxtracks":
- # If this tracker supports scale, try multiple scales
- for scale in scales:
- tracker, gt_filename = make_tracker_and_filename(
- tracker_name=tracker_name,
- matcher_name=matcher_name,
- sim_name=sim_name,
max_tracks=2,
- max_tracking=True,
+ sim_name=sim_name,
scale=scale,
)
f(frames, tracker, gt_filename)
- elif tracker_name == "simplemaxtracks":
- tracker, gt_filename = make_tracker_and_filename(
- tracker_name=tracker_name,
- matcher_name=matcher_name,
- sim_name=sim_name,
- max_tracks=2,
- max_tracking=True,
- scale=0,
- )
- f(frames, tracker, gt_filename)
else:
tracker, gt_filename = make_tracker_and_filename(
tracker_name=tracker_name,
matcher_name=matcher_name,
+ max_tracks=2,
sim_name=sim_name,
scale=0,
)