Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically update Predict On combo based on pipeline type #1300

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 139 additions & 83 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ def __init__(
self.labels = labels
self.skeleton = skeleton

# Attributes for selecting which frames to run inference on
self._frame_selection = None
self._frame_selection_descriptions = None
self._frame_count_dict = None

# Attributes for selecting the model type
self.current_pipeline = ""
self.previous_pipeline = None # Also used as a flag for first initialization

self.tabs: Dict[str, TrainingEditorWidget] = dict()
self.shown_tab_names = []
Expand Down Expand Up @@ -133,9 +138,12 @@ def __init__(
# Default to most recently trained pipeline (if there is one)
self.set_default_pipeline_tab()

# Connect functions to update pipeline tabs when pipeline changes
# Connect functions to update pipeline tabs and Predict On when pipeline changes
self.pipeline_form_widget.updatePipeline.connect(self.set_pipeline)
self.pipeline_form_widget.emitPipeline()
self.pipeline_form_widget.updatePipeline.connect(
self.populate_predict_frames_options
) # Connect this after setting and emmitting pipeline (to avoid double init)

self.connect_signals()

Expand Down Expand Up @@ -175,97 +183,143 @@ def count_total_frames_for_selection_option(

@property
def frame_selection(self) -> Dict[str, Dict[Video, List[int]]]:
"""
Returns dictionary with frames that user has selected for learning.
"""
"""Returns dictionary with frames that user has selected for learning."""
return self._frame_selection

@property
def frame_selection_descriptions(self):
return self._frame_selection_descriptions

@property
def default_frame_selection_ranking(self):
"""Returns list of options for frame selection, in order of preference."""

return ["clip", "suggestions", "current_frame", "nothing"]

@property
def non_sequential_frame_selections(self) -> List[str]:
"""Returns list of options for frame selection that are not sequential."""

return ["frame", "suggestions", "random", "random_video", "user"]

@property
def sequential_frame_selections(self) -> List[str]:
"""Returns list of options for frame selection that are sequential."""

return ["clip", "video", "all_videos"]

@frame_selection.setter
def frame_selection(self, frame_selection: Dict[str, Dict[Video, List[int]]]):
"""Sets options of frames on which to run learning."""
self._frame_selection = frame_selection

if "_predict_frames" in self.pipeline_form_widget.fields.keys():
prediction_options = []

total_random = 0
total_suggestions = 0
total_user = 0
random_video = 0
clip_length = 0
video_length = 0
all_videos_length = 0

# Determine which options are available given _frame_selection
if "random" in self._frame_selection:
total_random = self.count_total_frames_for_selection_option(
self._frame_selection["random"]
)
if "random_video" in self._frame_selection:
random_video = self.count_total_frames_for_selection_option(
self._frame_selection["random_video"]
)
if "suggestions" in self._frame_selection:
total_suggestions = self.count_total_frames_for_selection_option(
self._frame_selection["suggestions"]
)
if "user" in self._frame_selection:
total_user = self.count_total_frames_for_selection_option(
self._frame_selection["user"]
)
if "clip" in self._frame_selection:
clip_length = self.count_total_frames_for_selection_option(
self._frame_selection["clip"]
)
if "video" in self._frame_selection:
video_length = self.count_total_frames_for_selection_option(
self._frame_selection["video"]
)
if "all_videos" in self._frame_selection:
all_videos_length = self.count_total_frames_for_selection_option(
self._frame_selection["all_videos"]
)
# Return if predict frames field is not in pipeline form
if "_predict_frames" not in self.pipeline_form_widget.fields.keys():
return

# Build list of options
# Priority for default (lowest to highest):
# "nothing" (if training)
# "current frame" (if inference)
# "suggested frames" (if available)
# "selected clip" (if available)
if self.mode != "inference":
prediction_options.append("nothing")
prediction_options.append("current frame")
default_option = "nothing" if self.mode != "inference" else "current frame"

option = f"random frames ({total_random} total frames)"
prediction_options.append(option)

if random_video > 0:
option = f"random frames in current video ({random_video} frames)"
prediction_options.append(option)

if total_suggestions > 0:
option = f"suggested frames ({total_suggestions} total frames)"
prediction_options.append(option)
default_option = option

if total_user > 0:
option = f"user labeled frames ({total_user} total frames)"
prediction_options.append(option)

if clip_length > 0:
option = f"selected clip ({clip_length} frames)"
prediction_options.append(option)
default_option = option

prediction_options.append(f"entire current video ({video_length} frames)")

if len(self.labels.videos) > 1:
prediction_options.append(f"all videos ({all_videos_length} frames)")

self.pipeline_form_widget.fields["_predict_frames"].set_options(
prediction_options, default_option
# Count total frames for each option
self.count_frames_for_prediction_options()

# Populate list of options
self.populate_predict_frames_options(self.current_pipeline)

@property
def frame_count_dict(self):
"""Returns dictionary with number of frames for each option."""

return self._frame_count_dict

@frame_count_dict.setter
def frame_count_dict(self, frame_count_dict: Dict[str, int]):
"""Sets dictionary with number of frames for each option."""

self._frame_count_dict = frame_count_dict
self._frame_selection_descriptions = {
"frame": "current frame",
"clip": f"selected clip ({frame_count_dict.get('clip', None)} frames)",
"video": f"entire current video ({frame_count_dict.get('video', None)} frames)",
"all_videos": f"all videos ({frame_count_dict.get('all_videos', None)} frames)",
"suggestions": f"suggested frames ({frame_count_dict.get('suggestions', None)} total frames)",
"random": f"random frames ({frame_count_dict.get('random', None)} total frames)",
"random_video": f"random frames in current video ({frame_count_dict.get('random_video', None)} frames",
"user": f"user labeled frames ({frame_count_dict.get('user', None)} total frames)",
}

def count_frames_for_prediction_options(self):
frame_count_dict = {}
for option_name, video_and_frame_range in self.frame_selection.items():
f_count = self.count_total_frames_for_selection_option(
video_and_frame_range
)
if f_count > 0:
frame_count_dict[option_name] = f_count

# Custom logic to remove options that don't make sense for the current dialog
if (len(self.labels.videos) < 2) and ("all_videos" in frame_count_dict):
frame_count_dict.pop("all_videos")

# Use setter method to map option names to option descriptions using counts
self.frame_count_dict = frame_count_dict

return frame_count_dict

def populate_predict_frames_options(self, current_pipeline: str):
"""Populates the options for the predict frames field.
Args:
frame_counts: Dictionary of frame counts for each option.
"""

# Return if predict frames field is not in pipeline form
if "_predict_frames" not in self.pipeline_form_widget.fields.keys():
return

# Update options if pipeline has changed between training and retracking
if (
self.previous_pipeline is not None
and self.previous_pipeline != "none"
and current_pipeline != "none"
):
self.previous_pipeline = current_pipeline
return

prediction_options: List[str] = []

# Add training-only option(s)
if self.mode != "inference":
prediction_options.append("nothing")

# Add non-sequential options
if current_pipeline != "none":
for option in self.non_sequential_frame_selections:
if option in self.frame_count_dict:
prediction_options.append(self.frame_selection_descriptions[option])

# Add sequential options (for retracking, i.e. current pipeline is "none")
for option in self.sequential_frame_selections:
if option in self.frame_count_dict:
prediction_options.append(self.frame_selection_descriptions[option])

# Choose default option
default_option = None
if self.previous_pipeline is None:
# If no previous pipeline, choose first option that is in the default list
for option in self.default_frame_selection_ranking:
description = self.frame_selection_descriptions.get(option, None)
if description in prediction_options:
default_option = description
break
else:
# If previous pipeline, choose option that is closest to previous option
current_option = self.pipeline_form_widget.fields["_predict_frames"].value()
if current_option in prediction_options:
default_option = current_option
default_option = default_option or prediction_options[0]

# Set options
self.pipeline_form_widget.fields["_predict_frames"].set_options(
prediction_options, default_option
)
self.previous_pipeline = current_pipeline

def connect_signals(self):
self.pipeline_form_widget.valueChanged.connect(self.on_tab_data_change)
Expand Down Expand Up @@ -908,6 +962,8 @@ def current_pipeline(self):
return "bottom-up-id"
if "single" in pipeline_selected_label:
return "single"
if pipeline_selected_label == "none":
return "none"
return ""

@current_pipeline.setter
Expand Down