From 0a8d5d285b33949eed820b6e9a0f8eacc5f6d132 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 31 Jul 2024 10:47:08 +0100 Subject: [PATCH 01/15] improve ETA precision --- sleap/nn/inference.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14e0d5c6f..57f53ad45 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -418,6 +418,9 @@ def process_batch(ex): return ex + # Compile loop examples before starting time to improve ETA + examples = self.pipeline.make_dataset() + # Loop over data batches with optional progress reporting. if self.verbosity == "rich": with rich.progress.Progress( @@ -433,7 +436,7 @@ def process_batch(ex): ) as progress: task = progress.add_task("Predicting...", total=len(data_provider)) last_report = time() - for ex in self.pipeline.make_dataset(): + for ex in examples: ex = process_batch(ex) progress.update(task, advance=len(ex["frame_ind"])) @@ -453,7 +456,7 @@ def process_batch(ex): last_report = time() t0_all = time() t0_batch = time() - for ex in self.pipeline.make_dataset(): + for ex in examples: # Process batch of examples. ex = process_batch(ex) @@ -490,7 +493,7 @@ def process_batch(ex): # Return results. yield ex else: - for ex in self.pipeline.make_dataset(): + for ex in examples: yield process_batch(ex) def predict( From 2b26078074f6be8a58deb0eb09873980069820d9 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 31 Jul 2024 11:32:34 +0100 Subject: [PATCH 02/15] add tracking progress reporting --- sleap/nn/inference.py | 27 ++-- sleap/nn/tracking.py | 225 +++++++++++++++++++------- sleap/util.py | 33 +++- tests/nn/test_tracker_components.py | 39 ++--- tests/nn/test_tracking_integration.py | 50 +----- 5 files changed, 235 insertions(+), 139 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 57f53ad45..4e969c16b 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -46,6 +46,11 @@ from threading import Thread from queue import Queue +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property + import tensorflow as tf import numpy as np @@ -54,7 +59,7 @@ from sleap.nn.config import TrainingJobConfig, DataConfig from sleap.nn.data.resizing import SizeMatcher from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker +from sleap.nn.tracking import Tracker from sleap.nn.paf_grouping import PAFScorer from sleap.nn.data.pipelines import ( Provider, @@ -69,7 +74,7 @@ ) from sleap.nn.utils import reset_input_layer from sleap.io.dataset import Labels -from sleap.util import frame_list, make_scoped_dictionary +from sleap.util import frame_list, make_scoped_dictionary, RateColumn from sleap.instance import PredictedInstance, LabeledFrame from tensorflow.python.framework.convert_to_constants import ( @@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str: return os.path.join(path, "best_model.h5") -class RateColumn(rich.progress.ProgressColumn): - """Renders the progress rate.""" - - def render(self, task: "Task") -> rich.progress.Text: - """Show progress rate.""" - speed = task.speed - if speed is None: - return rich.progress.Text("?", style="progress.data.speed") - return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") - - @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" @@ -167,7 +161,7 @@ class Predictor(ABC): report_rate: float = attr.ib(default=2.0, 
kw_only=True) model_paths: List[str] = attr.ib(factory=list, kw_only=True) - @property + @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" return 1.0 / self.report_rate @@ -5487,7 +5481,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: """ policy_args = make_scoped_dictionary(vars(args), exclude_nones=True) if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + tracker = Tracker.make_tracker_by_name( + progress_reporting=args.verbosity, + **policy_args["tracking"], + ) return tracker return None diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..6d7c423b5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,12 +1,16 @@ """Tracking tools for linking grouped instances over time.""" -from collections import deque, defaultdict import abc +import json +import sys +from collections import deque +from time import time +from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple + import attr -import numpy as np import cv2 -import functools -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +import numpy as np +import rich.progress from sleap import Track, LabeledFrame, Skeleton @@ -26,8 +30,13 @@ Match, ) from sleap.nn.tracker.kalman import BareKalmanTracker - from sleap.nn.data.normalization import ensure_int +from sleap.util import RateColumn + +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property @attr.s(eq=False, slots=True, auto_attribs=True) @@ -66,7 +75,6 @@ def from_instance( shift_score: float = 0.0, with_skeleton: bool = False, ): - points_array = new_points_array if points_array is None: points_array = ref_instance.points_array @@ -511,10 +519,142 @@ def get_candidates( class BaseTracker(abc.ABC): """Abstract base class for tracker.""" + verbosity: str + report_rate: float + @property def is_valid(self): return False + @cached_property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + return 1.0 / self.report_rate + + def run_step(self, lf: LabeledFrame) -> LabeledFrame: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, t=lf.frame_idx) + if self.uses_image: + track_args["img"] = lf.video[lf.frame_idx] + else: + track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] + + return LabeledFrame( + frame_idx=lf.frame_idx, + video=lf.video, + instances=self.track(**track_args), + ) + + def run_tracker( + self, + frames: List[LabeledFrame], + *, + verbosity: Optional[str] = None, + final_pass: bool = True, + ) -> List[LabeledFrame]: + """Run the tracker on a set of labeled frames. + + Args: + frames: A list of labeled frames with instances. + + Returns: + The input frames with the new tracks assigned. If the frames already had tracks, + they will be cleared if the tracker has been re-initialized. 
+ """ + # Return original frames if we aren't retracking + if not self.is_valid: + return frames + + verbosity = verbosity or self.verbosity + new_lfs = [] + + # Run tracking on every frame + if verbosity == "rich": + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + progress.refresh() + + elif verbosity == "json": + n_total = len(frames) + n_processed = 0 + n_batch = 0 + elapsed_all = 0 + n_recent = deque(maxlen=30) + elapsed_recent = deque(maxlen=30) + last_report = time() + t0_all = time() + t0_batch = time() + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + # Track timing and progress. + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate. + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + else: + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + # Run final_pass + if final_pass: + self.final_pass(new_lfs) + + return new_lfs + @abc.abstractmethod def track( self, @@ -564,6 +704,12 @@ class Tracker(BaseTracker): use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. max_tracking: Max tracking is incorporated when this is set to true. + verbosity: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during tracking -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. """ max_tracks: int = None @@ -596,6 +742,12 @@ class Tracker(BaseTracker): last_matches: Optional[FrameMatches] = None + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): return self.similarity_function is not None @@ -670,7 +822,6 @@ def track( if t is None: if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: - # Default to last timestep + 1 if available. # Here we find the track that has the most instances. track_with_max_instances = max( @@ -686,7 +837,6 @@ def track( t = 0 else: if len(self.track_matching_queue) > 0: - # Default to last timestep + 1 if available. t = self.track_matching_queue[-1].t + 1 @@ -701,7 +851,6 @@ def track( # Process untracked instances. 
if untracked_instances: - if self.pre_cull_function: self.pre_cull_function(untracked_instances) @@ -791,7 +940,6 @@ def spawn_for_untracked_instances( ) -> List[InstanceType]: results = [] for inst in unmatched_instances: - # Skip if this instance is too small to spawn a new track with. if inst.n_visible_points < self.min_new_track_points: continue @@ -868,6 +1016,8 @@ def make_tracker_by_name( oks_errors: Optional[list] = None, oks_score_weighting: bool = False, oks_normalization: str = "all", + progress_reporting: str = "rich", + report_rate: float = 2.0, **kwargs, ) -> BaseTracker: # Parse max_tracking arguments, only True if max_tracks is not None and > 0 @@ -942,6 +1092,8 @@ def pre_cull_function(inst_list): max_tracks=max_tracks, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, + verbosity=progress_reporting, + report_rate=report_rate, ) if target_instance_count and kf_init_frame_count: @@ -961,7 +1113,6 @@ def pre_cull_function(inst_list): @classmethod def get_by_name_factory_options(cls): - options = [] option = dict(name="tracker", default="None") @@ -1230,7 +1381,6 @@ def add_frame_instances( # "usuable" instances—i.e., instances with the nodes that we'll track # using Kalman filters. elif frame_match.has_only_first_choice_matches: - good_instances = [ inst for inst in instances if self.is_usable_instance(inst) ] @@ -1321,6 +1471,12 @@ class KalmanTracker(BaseTracker): last_t: int = 0 last_init_t: int = 0 + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): """Do we have everything we need to run tracking?""" @@ -1444,7 +1600,6 @@ def track( # Check whether we've been getting good results from the Kalman filters. # First, has it been a while since the filters were initialized? if self.init_done and (t - self.last_init_t) > self.re_init_cooldown: - # If it's been a while, then see if it's also been a while since # the filters successfully matched tracks to the instances. if self.kalman_tracker.last_frame_with_tracks < t - self.re_init_after: @@ -1501,47 +1656,6 @@ def run(self, frames: List[LabeledFrame]): connect_single_track_breaks(frames, self.instance_count) -def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[LabeledFrame]: - """Run a tracker on a set of labeled frames. - - Args: - frames: A list of labeled frames with instances. - tracker: An initialized Tracker. - - Returns: - The input frames with the new tracks assigned. If the frames already had tracks, - they will be cleared if the tracker has been re-initialized. 
- """ - # Return original frames if we aren't retracking - if not tracker.is_valid: - return frames - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if tracker.uses_image: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - track_args["img_hw"] = lf.image.shape[-3:-1] - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args), - ) - new_lfs.append(new_lf) - - return new_lfs - - def retrack(): import argparse import operator @@ -1579,8 +1693,7 @@ def retrack(): print(f"Done loading predictions in {time.time() - t0} seconds.") print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) new_labels = Labels(labeled_frames=frames) diff --git a/sleap/util.py b/sleap/util.py index bc3389b7d..b92606868 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -1,32 +1,51 @@ -"""A miscellaneous set of utility functions. +"""A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ +from __future__ import annotations + +import base64 import json import os import re import shutil from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Hashable, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse from urllib.request import url2pathname +try: + from importlib.resources import files # New in 3.9+ +except ImportError: + from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. + import attr import h5py as h5 import numpy as np import psutil import rapidjson +import rich.progress import yaml - -try: - from importlib.resources import files # New in 3.9+ -except ImportError: - from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. +from PIL import Image import sleap.version as sleap_version +if TYPE_CHECKING: + from rich.progress import Task + + +class RateColumn(rich.progress.ProgressColumn): + """Renders the progress rate.""" + + def render(self, task: Task) -> rich.progress.Text: + """Show progress rate.""" + speed = task.speed + if speed is None: + return rich.progress.Text("?", style="progress.data.speed") + return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. 
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 0c7ba2b0a..9d3b65b38 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -15,23 +15,21 @@ from sleap.skeleton import Skeleton -def tracker_by_name(frames=None, **kwargs): - t = Tracker.make_tracker_by_name(**kwargs) - print(kwargs) - print(t.candidate_maker) - if frames is None: - t.track([]) - t.final_pass([]) - return - - for lf in frames: - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) - t.track(**track_args, img_hw=(1, 1)) - t.final_pass(frames) +def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs): + # Create tracker + t = Tracker.make_tracker_by_name(verbosity="none", **kwargs) + # Update img_scale + if img_scale: + if hasattr(t, "candidate_maker") and hasattr(t.candidate_maker, "img_scale"): + t.candidate_maker.img_scale = img_scale + else: + # Do not even run tracking as it can be slow + pytest.skip("img_scale is not defined for this tracker") + return + + # Run tracking + new_frames = t.run_tracker(frames or []) + assert len(new_frames) == len(frames) @pytest.mark.parametrize( @@ -42,22 +40,25 @@ def tracker_by_name(frames=None, **kwargs): ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"], ) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) +@pytest.mark.parametrize("img_scale", [0, 1, 0.25]) @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name( centered_pair_predictions_sorted, tracker, similarity, match, + img_scale, count, ): # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity=similarity, match=match, + img_scale=img_scale, max_tracks=count, ) @@ -76,7 +77,7 @@ def test_oks_tracker_by_name( # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity="object_keypoint", diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..c7c25476d 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -19,7 +19,7 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): inference_cli(cli.split(" ")) labels = sleap.load_file(f"{tmpdir}/simpletracks.slp") - assert len(labels.tracks) == 27 + assert len(labels.tracks) == 8 def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): @@ -37,18 +37,19 @@ def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): # TODO: Refactor the below things into a real test suite. 
+# running an equivalent to `make_ground_truth` is done as a test in tests/nn/test_tracker_components.py def make_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{len(tracker.spawned_tracks)}\t{time.time()-t0}") Labels.save_file(new_labels, gt_filename) def compare_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{time.time() - t0}") does_match = check_tracks(new_labels, gt_filename) @@ -78,43 +79,6 @@ def check_tracks(labels, gt_filename, limit=None): return True -def run_tracker(frames, tracker): - sig = inspect.signature(tracker.track) - takes_img = "img" in sig.parameters - - # t0 = time.time() - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if takes_img: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args, img_hw=lf.image.shape[-3:-1]), - ) - new_lfs.append(new_lf) - - # if lf.frame_idx % 100 == 0: print(lf.frame_idx, time.time()-t0) - - # print(time.time() - t0) - - new_labels = Labels() - new_labels.extend(new_lfs) - return new_labels - - def main(f, dir): filename = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -166,7 +130,10 @@ def make_tracker( return tracker def make_filename(tracker_name, matcher_name, sim_name, scale=0): - return f"{dir}{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5" + return os.path.join( + dir, + f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5", + ) def make_tracker_and_filename(*args, **kwargs): tracker = make_tracker(*args, **kwargs) @@ -180,7 +147,6 @@ def make_tracker_and_filename(*args, **kwargs): for tracker_name in trackers.keys(): for matcher_name in matchers.keys(): for sim_name in similarities.keys(): - if tracker_name == "flow": # If this tracker supports scale, try multiple scales for scale in scales: From 135fc66051fcf375f11b016d8c8cdfc2eaeee51e Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 2 Aug 2024 00:04:34 +0100 Subject: [PATCH 03/15] make sure tracker cli defaults are passed --- sleap/nn/tracking.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 6d7c423b5..b97a3c5f0 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1301,11 +1301,21 @@ def add_cli_parser_args(cls, parser, arg_scope: str = ""): else: arg_name = arg["name"] - parser.add_argument( - f"--{arg_name}", - type=arg["type"], - help=help_string, - ) + if arg["name"] == "tracker": + # If default is defined for "tracking.tracker", we cannot detect + # mal-formed command line. 
+ parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + ) + else: + parser.add_argument( + f"--{arg_name}", + type=arg["type"], + help=help_string, + default=arg["default"], + ) @attr.s(auto_attribs=True) From 9262f9530ec3526adf1d1b40f294881ecbdf9038 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 31 Jul 2024 15:28:27 +0100 Subject: [PATCH 04/15] fix bug with video GUI --- sleap/gui/widgets/video.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 949703020..4c2370e09 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -816,6 +816,8 @@ def __init__(self, state=None, player=None, *args, **kwargs): self.click_mode = "" self.in_zoom = False + self._down_pos = None + self.zoomFactor = 1 anchor_mode = QGraphicsView.AnchorUnderMouse self.setTransformationAnchor(anchor_mode) @@ -1039,7 +1041,7 @@ def mouseReleaseEvent(self, event): scenePos = self.mapToScene(event.pos()) # check if mouse moved during click - has_moved = event.pos() != self._down_pos + has_moved = self._down_pos is not None and event.pos() != self._down_pos if event.button() == Qt.LeftButton: if self.in_zoom: From 57983df2d180f57145e5a554103927fe0693bdda Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 31 Jul 2024 15:28:52 +0100 Subject: [PATCH 05/15] only positive numbers --- sleap/gui/learning/dialog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 2c2617036..7a9c94358 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -733,7 +733,7 @@ def run(self): # count < 0 means there was an error and we didn't get any results. if new_counts is not None and new_counts >= 0: total_count = items_for_inference.total_frame_count - no_result_count = total_count - new_counts + no_result_count = max(0, total_count - new_counts) message = ( f"Inference ran on {total_count} frames." From f564c1f3553e86c2821724074d862028438ce700 Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 2 Aug 2024 13:15:45 +0100 Subject: [PATCH 06/15] access save_shifted_instances from GUI change default save_shifted_instances --- sleap/config/pipeline_form.yaml | 14 ++++++-------- sleap/nn/tracking.py | 27 +++++++++++++++------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index 1bb930e58..cbe00e397 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -424,10 +424,6 @@ inference: This tracker "shifts" instances from previous frames using optical flow before matching instances in each frame to the shifted instances from prior frames.' - # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int @@ -459,10 +455,12 @@ inference: none_label: Use max (non-robust) range: 0,1 default: 0.95 - # - name: tracking.save_shifted_instances - # label: Save shifted instances - # type: bool - # default: false + - name: tracking.save_shifted_instances + label: Save shifted instances + help: 'Save the flow-shifted instances between elapsed frames. It improves + instance matching at the cost of using a bit more of memory.' + type: bool + default: true - type: text text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index b97a3c5f0..68141767a 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -139,7 +139,7 @@ class FlowCandidateMaker: img_scale: float = 1.0 of_window_size: int = 21 of_max_levels: int = 3 - save_shifted_instances: bool = False + save_shifted_instances: bool = True track_window: int = 5 shifted_instances: Dict[ @@ -1227,11 +1227,12 @@ def get_by_name_factory_options(cls): option["help"] = "For optical-flow: Number of pyramid scale levels to consider" options.append(option) - option = dict(name="save_shifted_instances", default=0) + option = dict(name="save_shifted_instances", default=1) option["type"] = int option["help"] = ( "If non-zero and tracking.tracker is set to flow, save the shifted " - "instances between elapsed frames" + "instances between elapsed frames. It uses a bit more of memory but gives " + "better instance matches." ) options.append(option) @@ -1245,9 +1246,10 @@ def int_list_func(s): option = dict(name="kf_init_frame_count", default="0") option["type"] = int - option[ - "help" - ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." + option["help"] = ( + "For Kalman filter: Number of frames to track with other tracker. " + "0 means no Kalman filters will be used." + ) options.append(option) def float_list_func(s): @@ -1268,9 +1270,10 @@ def float_list_func(s): option = dict(name="oks_score_weighting", default="0") option["type"] = int option["help"] = ( - "For Object Keypoint similarity: if 0 (default), only the distance between the reference " - "and query keypoint is used to compute the similarity. If 1, each distance is weighted " - "by the prediction scores of the reference and query keypoint." + "For Object Keypoint similarity: if 0 (default), only the distance " + "between the reference and query keypoint is used to compute the " + "similarity. If 1, each distance is weighted by the prediction scores " + "of the reference and query keypoint." ) options.append(option) @@ -1278,10 +1281,10 @@ def float_list_func(s): option["type"] = str option["options"] = ["all", "ref", "union"] option["help"] = ( - "For Object Keypoint similarity: Determine how to normalize similarity score. " + "Object Keypoint similarity: Determine how to normalize similarity score. " "If 'all', similarity score is normalized by number of reference points. " - "If 'ref', similarity score is normalized by number of visible reference points. " - "If 'union', similarity score is normalized by number of points both visible " + "If 'ref', score is normalized by number of visible reference points. " + "If 'union', score is normalized by number of points both visible " "in query and reference instance." 
) options.append(option) From 074a90bd38729d49d721c2c833251b1599aaaeb2 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 31 Jul 2024 15:29:17 +0100 Subject: [PATCH 07/15] allow tracking only a range of frame indices --- sleap/nn/inference.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 4e969c16b..80423634e 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -5387,7 +5387,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: ) ) else: - provider_list.append(LabelsReader(labels)) + provider_list.append( + LabelsReader(labels, example_indices=frame_list(args.frames)) + ) data_path_list.append(file_path) @@ -5651,14 +5653,16 @@ def main(args: Optional[list] = None): data_path = data_path_list[0] # Load predictions - data_path = args.data_path print("Loading predictions...") - labels_pr = sleap.load_file(data_path) + labels_pr = sleap.load_file(data_path.as_posix()) frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + if provider.example_indices is not None: + # Convert indices to a set to search in O(1), otherwise it is much slower + index_set = set(provider.example_indices) + frames = list(filter(lambda lf: lf.frame_idx in index_set, frames)) print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) labels_pr = Labels(labeled_frames=frames) @@ -5679,7 +5683,7 @@ def main(args: Optional[list] = None): labels_pr.provenance["sleap_version"] = sleap.__version__ labels_pr.provenance["platform"] = platform.platform() labels_pr.provenance["command"] = " ".join(sys.argv) - labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["data_path"] = os.fspath(data_path) labels_pr.provenance["output_path"] = output_path labels_pr.provenance["total_elapsed"] = total_elapsed labels_pr.provenance["start_timestamp"] = start_timestamp From 1d4b92a25bb43e6b8bf3bb1a8195b19bf756683f Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 3 Sep 2024 16:34:44 +0100 Subject: [PATCH 08/15] refactor tracking progress --- sleap/nn/tracking.py | 152 +++++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 72 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 68141767a..1374cd5d6 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,7 +5,7 @@ import sys from collections import deque from time import time -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple import attr import cv2 @@ -549,6 +549,82 @@ def run_step(self, lf: LabeledFrame) -> LabeledFrame: instances=self.track(**track_args), ) + def _run_tracker_json( + self, + frames: List[LabeledFrame], + max_length: int = 30, + ) -> Iterator[LabeledFrame]: + n_total = len(frames) + n_processed = 0 + n_batch = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + + for lf in frames: + new_lf = self.run_step(lf) + + # Track timing and progress + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report + if time() > last_report + self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / 
sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + yield new_lf + + def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + yield new_lf + def run_tracker( self, frames: List[LabeledFrame], @@ -570,84 +646,16 @@ def run_tracker( return frames verbosity = verbosity or self.verbosity - new_lfs = [] # Run tracking on every frame if verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Tracking...", total=len(frames)) - last_report = time() - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) - - progress.update(task, advance=1) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() + new_lfs = list(self._run_tracker_rich(frames)) elif verbosity == "json": - n_total = len(frames) - n_processed = 0 - n_batch = 0 - elapsed_all = 0 - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) - - # Track timing and progress. - elapsed_all = time() - t0_all - n_processed += 1 - n_batch += 1 - - # Report. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - elapsed_batch = time() - t0_batch - t0_batch = time() - - # Compute recent rate. 
- n_recent.append(n_batch) - n_batch = 0 - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() + new_lfs = list(self._run_tracker_json(frames)) else: - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) + new_lfs = list(self.run_step(lf) for lf in frames) # Run final_pass if final_pass: From 7b7ad1b850b33a07387fda6560047a3d58966ac7 Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 3 Sep 2024 16:35:03 +0100 Subject: [PATCH 09/15] refactor inference progress --- sleap/nn/inference.py | 212 ++++++++++++++++++++++++------------------ 1 file changed, 122 insertions(+), 90 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 80423634e..1c3b2d619 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -368,6 +368,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: def _initialize_inference_model(self): pass + def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Run prediction model on batch. + + This method handles running inference on a batch and postprocessing. + + Args: + ex: a dictionary holding the input for inference. + + Returns: + The input dictionary updated with the predictions. + """ + # Skip inference if model is not loaded + if self.inference_model is None: + return ex + + # Run inference on current batch. + preds = self.inference_model.predict_on_batch(ex, numpy=True) + + # Add model outputs to the input data example. + ex.update(preds) + + # Convert to numpy arrays if not already. + if isinstance(ex["video_ind"], tf.Tensor): + ex["video_ind"] = ex["video_ind"].numpy().flatten() + if isinstance(ex["frame_ind"], tf.Tensor): + ex["frame_ind"] = ex["frame_ind"].numpy().flatten() + + # Adjust for potential SizeMatcher scaling. + offset_x = ex.get("offset_x", 0) + offset_y = ex.get("offset_y", 0) + ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) + ex["instance_peaks"] /= np.expand_dims( + np.expand_dims(ex["scale"], axis=1), axis=1 + ) + + return ex + + def _run_batch_json( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + max_length: int = 30, + ) -> Iterator[Dict[str, np.ndarray]]: + n_processed = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in examples: + # Process batch of examples. + ex = self._process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + if time() > last_report + self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + # Return results. 
+ yield ex + + def _run_batch_rich( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + ) -> Iterator[Dict[str, np.ndarray]]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=n_total) + last_report = time() + for ex in examples: + ex = self._process_batch(ex) + + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + # Return results. + yield ex + def _predict_generator( self, data_provider: Provider ) -> Iterator[Dict[str, np.ndarray]]: @@ -389,106 +505,22 @@ def _predict_generator( if self.inference_model is None: self._initialize_inference_model() - def process_batch(ex): - # Run inference on current batch. - preds = self.inference_model.predict_on_batch(ex, numpy=True) - - # Add model outputs to the input data example. - ex.update(preds) - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - # Adjust for potential SizeMatcher scaling. - offset_x = ex.get("offset_x", 0) - offset_y = ex.get("offset_y", 0) - ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) - ex["instance_peaks"] /= np.expand_dims( - np.expand_dims(ex["scale"], axis=1), axis=1 - ) - - return ex - # Compile loop examples before starting time to improve ETA + n_total = len(data_provider) examples = self.pipeline.make_dataset() # Loop over data batches with optional progress reporting. if self.verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Predicting...", total=len(data_provider)) - last_report = time() - for ex in examples: - ex = process_batch(ex) - progress.update(task, advance=len(ex["frame_ind"])) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() - - # Return results. - yield ex + for ex in self._run_batch_rich(examples, n_total=n_total): + yield ex elif self.verbosity == "json": - n_processed = 0 - n_total = len(data_provider) - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for ex in examples: - # Process batch of examples. - ex = process_batch(ex) - - # Track timing and progress. - elapsed_batch = time() - t0_batch - t0_batch = time() - n_batch = len(ex["frame_ind"]) - n_processed += n_batch - elapsed_all = time() - t0_all - - # Compute recent rate. - n_recent.append(n_batch) - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - # Report. 
- elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() - - # Return results. + for ex in self._run_batch_json(examples, n_total=n_total): yield ex + else: for ex in examples: - yield process_batch(ex) + yield self._process_batch(ex) def predict( self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True From b36855c79b3cb0c2b960a88d927e5f039c58bcf4 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 11 Sep 2024 12:49:01 +0100 Subject: [PATCH 10/15] default save_shifted_instances to False --- sleap/config/pipeline_form.yaml | 2 +- sleap/nn/tracking.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index cbe00e397..ed05b91f8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -460,7 +460,7 @@ inference: help: 'Save the flow-shifted instances between elapsed frames. It improves instance matching at the cost of using a bit more of memory.' type: bool - default: true + default: false - type: text text: 'Kalman filter-based tracking:
Uses the above tracking options to track instances for an initial diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 1374cd5d6..746ab2881 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -139,7 +139,7 @@ class FlowCandidateMaker: img_scale: float = 1.0 of_window_size: int = 21 of_max_levels: int = 3 - save_shifted_instances: bool = True + save_shifted_instances: bool = False track_window: int = 5 shifted_instances: Dict[ @@ -1235,7 +1235,7 @@ def get_by_name_factory_options(cls): option["help"] = "For optical-flow: Number of pyramid scale levels to consider" options.append(option) - option = dict(name="save_shifted_instances", default=1) + option = dict(name="save_shifted_instances", default=0) option["type"] = int option["help"] = ( "If non-zero and tracking.tracker is set to flow, save the shifted " From 182c6bf4b58f1cda79dbf94a1353add051e11814 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 11 Sep 2024 12:58:26 +0100 Subject: [PATCH 11/15] use pathlib.Path --- tests/nn/test_tracking_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index c7c25476d..caebe49ff 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -2,6 +2,7 @@ import operator import os import time +from pathlib import Path import sleap from sleap.nn.inference import main as inference_cli @@ -130,9 +131,8 @@ def make_tracker( return tracker def make_filename(tracker_name, matcher_name, sim_name, scale=0): - return os.path.join( - dir, - f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5", + return Path(dir).joinpath( + f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5" ) def make_tracker_and_filename(*args, **kwargs): From 3b7a75b13f79bd4eb3f33346fe11a9c54ad942e2 Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 30 Sep 2024 12:57:27 +0100 Subject: [PATCH 12/15] add img_hw arg to Tracker.track --- sleap/nn/tracking.py | 2 ++ tests/nn/test_tracker_components.py | 1 - tests/nn/test_tracking_integration.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 746ab2881..d79c67d75 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -667,6 +667,7 @@ def run_tracker( def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ): @@ -1561,6 +1562,7 @@ def cull_function(inst_list): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 9d3b65b38..fa0cc5f51 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,7 +9,6 @@ FrameMatches, greedy_matching, ) -from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index caebe49ff..c479462f8 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -1,4 +1,3 @@ -import inspect import operator import os import time @@ -7,7 +6,7 @@ import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components -from sleap.io.dataset import Labels, LabeledFrame +from sleap.io.dataset import Labels 
def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): From 5ef39f8abc4dc327f30d9fe56ec023d0a92bb110 Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 25 Oct 2024 15:58:43 +0100 Subject: [PATCH 13/15] remove unused imports --- sleap/util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sleap/util.py b/sleap/util.py index b92606868..e4d0c1eb7 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -5,7 +5,6 @@ from __future__ import annotations -import base64 import json import os import re @@ -28,7 +27,6 @@ import rapidjson import rich.progress import yaml -from PIL import Image import sleap.version as sleap_version From 347578df8de27d6a3a5b662ba6eb46f95441c9b6 Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 25 Oct 2024 16:25:46 +0100 Subject: [PATCH 14/15] coderabbit suggestion for division by zero error, cache_property and attr default --- sleap/nn/inference.py | 10 ++++++++-- sleap/nn/tracking.py | 31 +++++++++++++++---------------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 1c3b2d619..14d14e7d8 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -48,8 +48,12 @@ if sys.version_info >= (3, 8): from functools import cached_property -else: # cached_property is define only for python >=3.8 - cached_property = property + +else: # cached_property is defined only for python >=3.8 + from functools import lru_cache + + def cached_property(func): + return property(lru_cache()(func)) import tensorflow as tf import numpy as np @@ -164,6 +168,8 @@ class Predictor(ABC): @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" + if self.report_rate <= 0: + raise ValueError("report_rate must be positive") return 1.0 / self.report_rate @classmethod diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index d79c67d75..b65bb2f90 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,6 +1,7 @@ """Tracking tools for linking grouped instances over time.""" import abc +import functools import json import sys from collections import deque @@ -35,8 +36,12 @@ if sys.version_info >= (3, 8): from functools import cached_property -else: # cached_property is define only for python >=3.8 - cached_property = property + +else: # cached_property is defined only for python >=3.8 + from functools import lru_cache + + def cached_property(func): + return property(lru_cache()(func)) @attr.s(eq=False, slots=True, auto_attribs=True) @@ -519,8 +524,12 @@ def get_candidates( class BaseTracker(abc.ABC): """Abstract base class for tracker.""" - verbosity: str - report_rate: float + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + kw_only=True, + ) + report_rate: float = attr.ib(default=2.0, kw_only=True) @property def is_valid(self): @@ -529,6 +538,8 @@ def is_valid(self): @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" + if self.report_rate <= 0: + raise ValueError("report_rate must be positive") return 1.0 / self.report_rate def run_step(self, lf: LabeledFrame) -> LabeledFrame: @@ -751,12 +762,6 @@ class Tracker(BaseTracker): last_matches: Optional[FrameMatches] = None - verbosity: str = attr.ib( - validator=attr.validators.in_(["none", "rich", "json"]), - default="none", - ) - report_rate: float = 2.0 - @property def is_valid(self): return self.similarity_function is not None @@ -1493,12 +1498,6 @@ class KalmanTracker(BaseTracker): last_t: int = 0 last_init_t: int = 0 
- verbosity: str = attr.ib( - validator=attr.validators.in_(["none", "rich", "json"]), - default="none", - ) - report_rate: float = 2.0 - @property def is_valid(self): """Do we have everything we need to run tracking?""" From b61af69eefc0bc0e6b7a710cd5243c46855afa8b Mon Sep 17 00:00:00 2001 From: getzze Date: Fri, 25 Oct 2024 16:29:13 +0100 Subject: [PATCH 15/15] undo coderabbit suggestion for cached_property --- sleap/nn/inference.py | 92 +++++++++++++++++++++---------------------- sleap/nn/tracking.py | 9 +++-- 2 files changed, 49 insertions(+), 52 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14d14e7d8..0923b6979 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -21,69 +21,67 @@ function which provides a simplified interface for creating `Predictor`s. """ -import attr import argparse +import atexit +import json import logging -import warnings import os -import sys -import tempfile import platform import shutil -import atexit import subprocess -import rich.progress -import pandas as pd -from rich.pretty import pprint +import sys +import tempfile +import warnings +from abc import ABC, abstractmethod from collections import deque -import json -from time import time from datetime import datetime from pathlib import Path -import tensorflow_hub as hub -from abc import ABC, abstractmethod -from typing import Text, Optional, List, Dict, Union, Iterator, Tuple -from threading import Thread from queue import Queue +from threading import Thread +from time import time +from typing import Dict, Iterator, List, Optional, Text, Tuple, Union if sys.version_info >= (3, 8): from functools import cached_property else: # cached_property is defined only for python >=3.8 - from functools import lru_cache - - def cached_property(func): - return property(lru_cache()(func)) + cached_property = property -import tensorflow as tf +import attr import numpy as np +import pandas as pd +import rich.progress +import tensorflow as tf +import tensorflow_hub as hub +from rich.pretty import pprint +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2, +) import sleap - -from sleap.nn.config import TrainingJobConfig, DataConfig -from sleap.nn.data.resizing import SizeMatcher -from sleap.nn.model import Model -from sleap.nn.tracking import Tracker -from sleap.nn.paf_grouping import PAFScorer +from sleap.instance import LabeledFrame, PredictedInstance +from sleap.io.dataset import Labels +from sleap.nn.config import DataConfig, TrainingJobConfig from sleap.nn.data.pipelines import ( - Provider, - Pipeline, + Batcher, + InstanceCentroidFinder, + KerasModelPredictor, LabelsReader, - VideoReader, Normalizer, - Resizer, + Pipeline, Prefetcher, - InstanceCentroidFinder, - KerasModelPredictor, + Provider, + Resizer, + VideoReader, ) +from sleap.nn.data.resizing import SizeMatcher +from sleap.nn.model import Model +from sleap.nn.paf_grouping import PAFScorer +from sleap.nn.tracking import Tracker from sleap.nn.utils import reset_input_layer -from sleap.io.dataset import Labels -from sleap.util import frame_list, make_scoped_dictionary, RateColumn -from sleap.instance import PredictedInstance, LabeledFrame +from sleap.util import RateColumn, frame_list, make_scoped_dictionary -from tensorflow.python.framework.convert_to_constants import ( - convert_variables_to_constants_v2, -) +logger = logging.getLogger(__name__) MOVENET_MODELS = { "lightning": { @@ -135,8 +133,6 @@ def cached_property(func): ], ) -logger = logging.getLogger(__name__) - 
def get_keras_model_path(path: Text) -> str: """Utility method for finding the path to a saved Keras model. @@ -169,7 +165,8 @@ class Predictor(ABC): def report_period(self) -> float: """Time between progress reports in seconds.""" if self.report_rate <= 0: - raise ValueError("report_rate must be positive") + logger.warning("report_rate must be positive, fallback to 1") + return 1.0 return 1.0 / self.report_rate @classmethod @@ -360,7 +357,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: ensure_rgb=(not self.is_grayscale), ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -617,7 +614,7 @@ def export_model( ) + (keras_model_shape[3],) tracing_batch = np.zeros((1,) + sample_shape, dtype="uint8") - outputs = self.inference_model.predict(tracing_batch) + _ = self.inference_model.predict(tracing_batch) self.inference_model.export_model( save_path, signatures, save_traces, model_name, tensors, unrag_outputs @@ -2570,7 +2567,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: skeletons=self.confmap_config.data.labels.skeletons, ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -4422,13 +4419,13 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: if self.centroid_model is None: anchor_part = self.confmap_config.data.instance_cropping.center_on_part - pipeline += sleap.nn.data.pipelines.InstanceCentroidFinder( + pipeline += InstanceCentroidFinder( center_on_anchor_part=anchor_part is not None, anchor_part_names=anchor_part, skeletons=self.confmap_config.data.labels.skeletons, ) - pipeline += sleap.nn.data.pipelines.Batcher( + pipeline += Batcher( batch_size=self.batch_size, drop_remainder=False, unrag=False ) @@ -4650,7 +4647,7 @@ def __init__(self, model_name="lightning"): ) def call(self, ex): - if type(ex) == dict: + if isinstance(ex, dict): img = ex["image"] else: @@ -5496,7 +5493,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: max_instances=args.max_instances, ) - if type(predictor) == BottomUpPredictor: + if isinstance(predictor, BottomUpPredictor): predictor.inference_model.bottomup_layer.paf_scorer.max_edge_length_ratio = ( args.max_edge_length_ratio ) @@ -5608,7 +5605,6 @@ def main(args: Optional[list] = None): # Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run) if args.models is not None: - # Run inference on all files inputed for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)): # Setup models. 
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index b65bb2f90..55170ea36 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -3,6 +3,7 @@ import abc import functools import json +import logging import sys from collections import deque from time import time @@ -38,10 +39,9 @@ from functools import cached_property else: # cached_property is defined only for python >=3.8 - from functools import lru_cache + cached_property = property - def cached_property(func): - return property(lru_cache()(func)) +logger = logging.getLogger(__name__) @attr.s(eq=False, slots=True, auto_attribs=True) @@ -539,7 +539,8 @@ def is_valid(self): def report_period(self) -> float: """Time between progress reports in seconds.""" if self.report_rate <= 0: - raise ValueError("report_rate must be positive") + logger.warning("report_rate must be positive, fallback to 1") + return 1.0 return 1.0 / self.report_rate def run_step(self, lf: LabeledFrame) -> LabeledFrame:
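
Both Predictor._run_batch_rich and BaseTracker._run_tracker_rich above assemble the same rich progress bar around the shared RateColumn that this series moves into sleap.util. The sketch below shows that configuration in isolation; the loop body is a dummy workload standing in for run_step() / batch inference, and report_period stands in for the 1 / report_rate value computed by the report_period property.

    import time

    import rich.progress

    from sleap.util import RateColumn  # relocated to sleap.util by this series

    report_period = 0.5          # seconds between manual refreshes
    work_items = range(200)      # stand-in for a list of labeled frames or batches

    with rich.progress.Progress(
        "{task.description}",
        rich.progress.BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        "ETA:",
        rich.progress.TimeRemainingColumn(),
        RateColumn(),            # renders task.speed as "<n.n> FPS"
        auto_refresh=False,      # refreshed manually so notebooks stay in sync
        refresh_per_second=2,
        speed_estimate_period=5,
    ) as progress:
        task = progress.add_task("Tracking...", total=len(work_items))
        last_report = time.time()
        for _ in work_items:
            time.sleep(0.01)     # placeholder for the real per-frame work
            progress.update(task, advance=1)
            if time.time() > last_report + report_period:
                progress.refresh()
                last_report = time.time()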
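
When verbosity is set to "json", _run_batch_json and _run_tracker_json print one JSON object per report period with the keys n_processed, n_total, elapsed, rate, and eta. A minimal consumer that turns that stream into its own progress display could look like the following; the exact sleap-track arguments (video path, model directory) are illustrative placeholders.

    import json
    import subprocess

    cmd = [
        "sleap-track", "video.mp4",          # placeholder input
        "-m", "models/my_model",             # placeholder model directory
        "--verbosity", "json",
        "--tracking.tracker", "simple",
    ]

    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) as proc:
        for line in proc.stdout:
            line = line.strip()
            if not line.startswith("{"):
                continue                      # skip non-progress output
            try:
                msg = json.loads(line)
            except json.JSONDecodeError:
                continue
            print(
                f"{msg['n_processed']}/{msg['n_total']} frames "
                f"({msg['rate']:.1f} FPS, ETA {msg['eta']:.0f} s)"
            )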
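
With the module-level run_tracker() helper removed, retracking an existing predictions file programmatically goes through the tracker object itself, as retrack() and the tracking-only path of sleap-track now do. A minimal sketch, assuming placeholder file names and an arbitrary choice of tracker options:

    import sleap
    from sleap.nn.tracking import Tracker
    from sleap.util import frame_list

    labels = sleap.load_file("predictions.slp")   # placeholder path
    frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)

    # Optionally restrict to a frame range, mirroring the --frames handling in
    # patch 07 (a set keeps the membership test O(1) per frame).
    requested = set(frame_list("0-99,150-299") or [])
    if requested:
        frames = [lf for lf in frames if lf.frame_idx in requested]

    tracker = Tracker.make_tracker_by_name(
        tracker="simple",
        similarity="iou",
        match="greedy",
        progress_reporting="rich",   # new kwarg; "json" or "none" also accepted
    )
    tracked = tracker.run_tracker(frames=frames)  # runs final_pass by default
    sleap.Labels(labeled_frames=tracked).save("retracked.slp")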