Skip to content

Commit

Permalink
simplify max_tracks tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 4, 2024
1 parent 25cd34b commit a7b3a0a
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 401 deletions.
12 changes: 5 additions & 7 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
Expand Down Expand Up @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
[--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
[--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
[--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER]
[--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
[--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
Expand Down Expand Up @@ -187,10 +187,8 @@ optional arguments:
Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
--tracking.max_tracking TRACKING.MAX_TRACKING
If true then the tracker will cap the max number of tracks. (default: False)
--tracking.max_tracks TRACKING.MAX_TRACKS
Maximum number of tracks to be tracked by the tracker. (default: None)
Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
Expand Down Expand Up @@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction
**5. Inference with max tracks limit:**

```none
sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
```

**6. Re-tracking without pose inference:**

```none
sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
```

**7. Select GPU for pose inference:**
Expand Down
4 changes: 0 additions & 4 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -521,10 +521,6 @@ inference:
text: '<b>Tracking</b>:<br />
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
# default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
Expand Down
23 changes: 8 additions & 15 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def make_predict_cli_call(

# Make path where we'll save predictions (if not specified)
if output_path is None:

if self.labels_filename:
# Make a predictions directory next to the labels dataset file
predictions_dir = os.path.join(
Expand Down Expand Up @@ -239,15 +238,12 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]

# Setting max_tracks to True means we want to use the max_tracking mode.
if "tracking.max_tracks" in self.inference_params:
self.inference_params["tracking.max_tracking"] = True

# Hacky: Update the tracker name to include "maxtracks" suffix.
if self.inference_params["tracking.tracker"] in ("simple", "flow"):
self.inference_params["tracking.tracker"] = (
self.inference_params["tracking.tracker"] + "maxtracks"
)
# Compatibility with using the "maxtracks" suffix to the tracker name.
if "tracking.tracker" in self.inference_params:
compat_trackers = ("simplemaxtracks", "flowmaxtracks")
if self.inference_params["tracking.tracker"] in compat_trackers:
tname = self.inference_params["tracking.tracker"][: -len("maxtracks")]
self.inference_params["tracking.tracker"] = tname

# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
Expand All @@ -257,10 +253,11 @@ def make_predict_cli_call(

bool_items_as_ints = (
"tracking.pre_cull_to_target",
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
"tracking.prefer_reassigning_track",
"tracking.allow_reassigning_track",
)

for key in bool_items_as_ints:
Expand Down Expand Up @@ -303,10 +300,8 @@ def predict_subprocess(

# Run inference CLI capturing output.
with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc:

# Poll until finished.
while proc.poll() is None:

# Read line.
line = proc.stdout.readline()
line = line.decode().rstrip()
Expand Down Expand Up @@ -635,7 +630,6 @@ def run_gui_training(

for config_info in config_info_list:
if config_info.dont_retrain:

if not config_info.has_trained_model:
raise ValueError(
"Config is set to not retrain but no trained model found: "
Expand Down Expand Up @@ -849,7 +843,6 @@ def train_subprocess(
success = False

with tempfile.TemporaryDirectory() as temp_dir:

# Write a temporary file of the TrainingJob so that we can respect
# any changed made to the job attributes after it was loaded.
temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json"
Expand Down
6 changes: 0 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4947,16 +4947,10 @@ def unpack_sleap_model(model_path):
)
predictor.verbosity = progress_reporting
if tracker is not None:
use_max_tracker = tracker_max_instances is not None
if use_max_tracker and not tracker.endswith("maxtracks"):
# Append maxtracks to the tracker name to use the right tracker variants.
tracker += "maxtracks"

predictor.tracker = Tracker.make_tracker_by_name(
tracker=tracker,
track_window=tracker_window,
post_connect_single_breaks=True,
max_tracking=use_max_tracker,
max_tracks=tracker_max_instances,
# clean_instance_count=tracker_max_instances,
)
Expand Down
Loading

0 comments on commit a7b3a0a

Please sign in to comment.