Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the max_tracking code #1896

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
12 changes: 5 additions & 7 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
Expand Down Expand Up @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [-
[--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT]
[--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO]
[--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING]
[-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER]
[--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT]
[--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD]
[--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS]
Expand Down Expand Up @@ -187,10 +187,8 @@ optional arguments:
Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None.
--tracking.tracker TRACKING.TRACKER
Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None)
--tracking.max_tracking TRACKING.MAX_TRACKING
If true then the tracker will cap the max number of tracks. (default: False)
--tracking.max_tracks TRACKING.MAX_TRACKS
Maximum number of tracks to be tracked by the tracker. (default: None)
Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None)
--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT
Target number of instances to track per frame. (default: 0)
--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET
Expand Down Expand Up @@ -264,13 +262,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction
**5. Inference with max tracks limit:**

```none
sleap-track -m "models/my_model" --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4"
```

**6. Re-tracking without pose inference:**

```none
sleap-track --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp"
```

**7. Select GPU for pose inference:**
Expand Down
18 changes: 6 additions & 12 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,6 @@ inference:
This tracker "shifts" instances from previous frames using optical flow
before matching instances in each frame to the <i>shifted</i> instances from
prior frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
Expand Down Expand Up @@ -459,10 +455,12 @@ inference:
none_label: Use max (non-robust)
range: 0,1
default: 0.95
# - name: tracking.save_shifted_instances
# label: Save shifted instances
# type: bool
# default: false
- name: tracking.save_shifted_instances
label: Save shifted instances
help: 'Save the flow-shifted instances between elapsed frames. It improves
instance matching at the cost of using a bit more of memory.'
type: bool
default: true
- type: text
text: '<b>Kalman filter-based tracking</b>:<br />
Uses the above tracking options to track instances for an initial
Expand Down Expand Up @@ -523,10 +521,6 @@ inference:
text: '<b>Tracking</b>:<br />
This tracker assigns track identities by matching instances from prior
frames to instances on subsequent frames.'
# - name: tracking.max_tracking
# label: Limit max number of tracks
# type: bool
# default: false
- name: tracking.max_tracks
label: Max number of tracks
type: optional_int
Expand Down
2 changes: 1 addition & 1 deletion sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def run(self):
# count < 0 means there was an error and we didn't get any results.
if new_counts is not None and new_counts >= 0:
total_count = items_for_inference.total_frame_count
no_result_count = total_count - new_counts
no_result_count = max(0, total_count - new_counts)

message = (
f"Inference ran on {total_count} frames."
Expand Down
24 changes: 9 additions & 15 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def make_predict_cli_call(

# Make path where we'll save predictions (if not specified)
if output_path is None:

if self.labels_filename:
# Make a predictions directory next to the labels dataset file
predictions_dir = os.path.join(
Expand Down Expand Up @@ -239,15 +238,12 @@ def make_predict_cli_call(
if key in self.inference_params and self.inference_params[key] is None:
del self.inference_params[key]

# Setting max_tracks to True means we want to use the max_tracking mode.
if "tracking.max_tracks" in self.inference_params:
self.inference_params["tracking.max_tracking"] = True

# Hacky: Update the tracker name to include "maxtracks" suffix.
if self.inference_params["tracking.tracker"] in ("simple", "flow"):
self.inference_params["tracking.tracker"] = (
self.inference_params["tracking.tracker"] + "maxtracks"
)
# Compatibility with using the "maxtracks" suffix to the tracker name.
if "tracking.tracker" in self.inference_params:
compat_trackers = ("simplemaxtracks", "flowmaxtracks")
if self.inference_params["tracking.tracker"] in compat_trackers:
tname = self.inference_params["tracking.tracker"][: -len("maxtracks")]
self.inference_params["tracking.tracker"] = tname

# --tracking.kf_init_frame_count enables the kalman filter tracking
# so if not set, then remove other (unused) args
Expand All @@ -257,10 +253,12 @@ def make_predict_cli_call(

bool_items_as_ints = (
"tracking.pre_cull_to_target",
"tracking.max_tracking",
"tracking.pre_cull_merge_instances",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
"tracking.prefer_reassigning_track",
"tracking.allow_reassigning_track",
)

for key in bool_items_as_ints:
Expand Down Expand Up @@ -303,10 +301,8 @@ def predict_subprocess(

# Run inference CLI capturing output.
with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc:

# Poll until finished.
while proc.poll() is None:

# Read line.
line = proc.stdout.readline()
line = line.decode().rstrip()
Expand Down Expand Up @@ -635,7 +631,6 @@ def run_gui_training(

for config_info in config_info_list:
if config_info.dont_retrain:

if not config_info.has_trained_model:
raise ValueError(
"Config is set to not retrain but no trained model found: "
Expand Down Expand Up @@ -849,7 +844,6 @@ def train_subprocess(
success = False

with tempfile.TemporaryDirectory() as temp_dir:

# Write a temporary file of the TrainingJob so that we can respect
# any changed made to the job attributes after it was loaded.
temp_filename = datetime.now().strftime("%y%m%d_%H%M%S") + "_training_job.json"
Expand Down
4 changes: 3 additions & 1 deletion sleap/gui/widgets/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,8 @@ def __init__(self, state=None, player=None, *args, **kwargs):
self.click_mode = ""
self.in_zoom = False

self._down_pos = None

self.zoomFactor = 1
anchor_mode = QGraphicsView.AnchorUnderMouse
self.setTransformationAnchor(anchor_mode)
Expand Down Expand Up @@ -1039,7 +1041,7 @@ def mouseReleaseEvent(self, event):
scenePos = self.mapToScene(event.pos())

# check if mouse moved during click
has_moved = event.pos() != self._down_pos
has_moved = self._down_pos is not None and event.pos() != self._down_pos
if event.button() == Qt.LeftButton:

if self.in_zoom:
Expand Down
73 changes: 72 additions & 1 deletion sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
"""

import math
from functools import reduce
from itertools import chain, combinations

import numpy as np
import cattr

from copy import copy
from typing import Dict, List, Optional, Union, Tuple, ForwardRef
from typing import Dict, List, Optional, Union, Sequence, Tuple

from numpy.lib.recfunctions import structured_to_unstructured

Expand Down Expand Up @@ -1177,6 +1179,75 @@ def from_numpy(
)


def all_disjoint(x: Sequence[Sequence]) -> bool:
return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2))


def create_merged_instances(
instances: List[PredictedInstance],
penalty: float = 0.2,
) -> List[PredictedInstance]:
"""Create merged instances from the list of PredictedInstance.

Only instances with non-overlapping visible nodes are merged.

Args:
instances: a list of original PredictedInstances to try to merge.
penalty: a float between 0 and 1. All scores of the merged instance
are multplied by (1 - penalty).

Returns:
a list of PredictedInstance that were merged.
"""
# Ensure same skeleton
skeletons = {inst.skeleton for inst in instances}
if len(skeletons) != 1:
return []
skeleton = list(skeletons)[0]

# Ensure same track
tracks = {inst.track for inst in instances}
if len(tracks) != 1:
return []
track = list(tracks)[0]

# Ensure non-intersecting visible nodes
merged_instances = []
instance_subsets = (
combinations(instances, n) for n in range(2, len(instances) + 1)
)
instance_subsets = chain.from_iterable(instance_subsets)
for subset in instance_subsets:
nodes = [s.nodes for s in subset]
if not all_disjoint(nodes):
continue

nodes_points_gen = chain.from_iterable(
instance.nodes_points for instance in subset
)
predicted_points = {node: point for node, point in nodes_points_gen}

instance_score = reduce(lambda x, y: x * y, [s.score for s in subset])

# Penalize scores of merged instances
if 0 < penalty <= 1:
factor = 1 - penalty
instance_score *= factor
for point in predicted_points.values():
point.score *= factor

merged_instance = PredictedInstance(
points=predicted_points,
skeleton=skeleton,
score=instance_score,
track=track,
)

merged_instances.append(merged_instance)

return merged_instances


def make_instance_cattr() -> cattr.Converter:
"""Create a cattr converter for Lists of Instances/PredictedInstances.

Expand Down
Loading
Loading