Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(3->2) Add method to update Instances across views in RecordingSession #1279

Open
wants to merge 45 commits into
base: liezl/ars-add-sessions-to-cache
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
190b82c
Update `Instance`s across views in `RecordingSession`
roomrys Apr 18, 2023
4aa285b
Fix-up implementation while manually testing
roomrys Apr 19, 2023
7754a66
Only update incomplete points
roomrys Apr 19, 2023
4dca8be
Lint
roomrys Apr 19, 2023
f5d41b3
Merge branch 'liezl/ars-add-sessions-to-cache' into liezl/asc-initial…
roomrys Apr 27, 2023
619d584
Merge branch 'liezl/ars-add-sessions-to-cache' into liezl/asc-initial…
roomrys Jul 20, 2023
96df7f4
Merge branch 'liezl/ars-add-sessions-to-cache' of https://github.com/…
roomrys Sep 29, 2023
280d48e
Modularize the `RecordingSessions.update_views` function
roomrys Sep 29, 2023
9559da7
Add more informative message to `RecordingSession.update_views`
roomrys Sep 29, 2023
45b8475
Add `require_multiple_views` parameter to `RecordingSession.get_insta…
roomrys Sep 29, 2023
ba2df88
(4->3) Fix (de)serialization of `RecordingSession` and add Multiview …
roomrys Sep 29, 2023
35c1521
Add test for
Oct 4, 2023
8098ae1
Fix spelling
Oct 4, 2023
addc8be
Add debugging messages and raise ValueError
Oct 4, 2023
74a8f83
Add test for `get_instances_accross_views`
Oct 4, 2023
4f65679
Lint
Oct 4, 2023
243ad3c
Modularize
Oct 4, 2023
bb0fda1
Modularize RecordingSession.update_views
Oct 4, 2023
f203953
Add comments on expected array size
Oct 4, 2023
2bc5c47
Breakout recording session tests and test reprojection
Oct 4, 2023
e0a57f2
Add test for RecordingSession.update_instances
Oct 4, 2023
a1965e9
Add keyword arguments to `update_views` call
roomrys Oct 9, 2023
62f4f7f
Add docstring to update_instances
roomrys Oct 9, 2023
9b67128
Fix update_instances test
roomrys Oct 9, 2023
99ef17e
Update debug/warning messages
roomrys Oct 9, 2023
38599f4
Add test for remove_video and labels cache
roomrys Oct 9, 2023
df1bc72
Add test for update_views
roomrys Oct 10, 2023
e85e4fb
Pass in excluded views to triangulate and reproject
roomrys Oct 10, 2023
7105656
Small fix in test_recording_session
roomrys Oct 10, 2023
bffeca8
Ensure instances are always passed in with correct order
roomrys Oct 10, 2023
c293d73
Add `TriangulateSession` command
roomrys Oct 13, 2023
001c849
Add tests for `TriangulateSession`
roomrys Oct 13, 2023
06dff9f
Handover call to triangulate to `TriangulateSession`
roomrys Oct 13, 2023
4706fb2
Remove triangulate functionality from `RecordingSession`
roomrys Oct 13, 2023
4994f04
Merge branch 'liezl/ars-add-sessions-to-cache' of https://github.com/…
roomrys Oct 19, 2023
7cc7a5a
Fix verify instances bug via typo
roomrys Dec 5, 2023
a2326d3
Better warning messages (w/o dialog pop-ups!)
roomrys Dec 6, 2023
2f56a1c
(3a -> 3) Add method to match instances across views (#1579)
roomrys Apr 18, 2024
712fc70
Remove unused imports
Apr 18, 2024
b1d2372
(3a -> 3) Serialize `FrameGroup` and `InstanceGroup` (update `Recordi…
roomrys Apr 23, 2024
0d26728
(3a -> 3) Add `FrameGroup` fixture (#1753)
roomrys Apr 23, 2024
8d3e232
(3a -> 3) Debug `TriangulateSession` command (#1755)
roomrys Apr 25, 2024
507eefa
(3a -> 3) Debug missing `PredictedInstance`s (not displaying in GUI) …
roomrys Apr 25, 2024
55c6fd4
Test `InstanceGroup` and `FrameGroup` (#1759)
roomrys Apr 30, 2024
07ea17b
Handle case when upserting_instances with no instances in at least on…
May 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,8 +1231,16 @@ def plotFrame(self, *args, **kwargs):

def _after_plot_update(self, frame_idx):
"""Run after plot is updated, but stay on same frame."""

video = self.state["video"]
roomrys marked this conversation as resolved.
Show resolved Hide resolved

# Redraw trails
overlay: TrackTrailOverlay = self.overlays["trails"]
overlay.redraw(self.state["video"], frame_idx)
overlay.redraw(video, frame_idx)

# Replot connected views for multi-camera projects
cams_to_include = None # TODO: make this configurable via GUI
self.commands.triangulateSession(cams_to_include=cams_to_include)

def _after_plot_change(self, player, frame_idx, selected_inst):
"""Called each time a new frame is drawn."""
Expand Down
157 changes: 152 additions & 5 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ class which inherits from `AppCommand` (or a more specialized class such as
import traceback
from enum import Enum
from glob import glob
from itertools import permutations, product
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused imports to clean up the code.

- from itertools import permutations, product

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from itertools import permutations, product

from pathlib import Path, PurePath
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import to clean up the code.

- from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
+ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union


import attr
import cv2
import numpy as np
from qtpy import QtCore, QtWidgets, QtGui
from qtpy.QtWidgets import QMessageBox, QProgressDialog
from sleap_anipose import triangulate, reproject
from qtpy import QtCore, QtGui, QtWidgets

from sleap.gui.dialogs.delete import DeleteDialog
from sleap.gui.dialogs.filedialog import FileDialog
Expand All @@ -53,7 +54,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.gui.state import GuiState
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track
from sleap.io.cameras import RecordingSession
from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import to clean up the code.

- from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession
+ from sleap.io.cameras import InstanceGroup, FrameGroup, RecordingSession

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession
from sleap.io.cameras import InstanceGroup, FrameGroup, RecordingSession

from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand Down Expand Up @@ -613,6 +614,24 @@ def generateSuggestions(self, params: Dict):
"""Generates suggestions using given params dictionary."""
self.execute(GenerateSuggestions, **params)

def triangulateSession(
self,
frame_idx: Optional[int] = None,
video: Optional[Video] = None,
instance: Optional[Instance] = None,
session: Optional[RecordingSession] = None,
cams_to_include: Optional[List[Camcorder]] = None,
):
"""Triangulates `Instance`s for selected views in a `RecordingSession`."""
self.execute(
TriangulateSession,
frame_idx=frame_idx,
video=video,
instance=instance,
session=session,
cams_to_include=cams_to_include,
)

def openWebsite(self, url):
"""Open a website from URL using the native system browser."""
self.execute(OpenWebsite, url=url)
Expand Down Expand Up @@ -1929,7 +1948,6 @@ class AddSession(EditCommand):

@staticmethod
def do_action(context: CommandContext, params: dict):

camera_calibration = params["camera_calibration"]
session = RecordingSession.load(filename=camera_calibration)

Expand Down Expand Up @@ -3369,6 +3387,135 @@ def do_action(cls, context: CommandContext, params: dict):
context.labels.append(current_frame)


class TriangulateSession(EditCommand):
topics = [UpdateTopic.frame, UpdateTopic.project_instances]

@classmethod
def do_action(cls, context: CommandContext, params: dict):
"""Triangulate, reproject, and update instances in a session at a frame index.

Args:
context: The command context.
params: The command parameters.
video: The `Video` object to use. Default is current video.
session: The `RecordingSession` object to use. Default is current
video's session.
frame_idx: The frame index to use. Default is current frame index.
instance: The `Instance` object to use. Default is current instance.
show_dialog: If True, then show a warning dialog. Default is True.
ask_again: If True, then ask for views/instances again. Default is False.
"""

# Get `FrameGroup` for the current frame index
video = params.get("video", None) or context.state["video"]
session = params.get("session", None) or context.labels.get_session(video)
frame_idx: int = params["frame_idx"]
frame_group: FrameGroup = (
params.get("frame_group", None) or session.frame_groups[frame_idx]
)

# Get the `InstanceGroup` from `Instance` if any
instance = params.get("instance", None) or context.state["instance"]
instance_group = frame_group.get_instance_group(instance)

# If instance_group is None, then we will try to triangulate entire frame_group
instance_groups = (
[instance_group]
if instance_group is not None
else frame_group.instance_groups
)

# Retain instance groups that have enough views/instances for triangulation
instance_groups = TriangulateSession.has_enough_instances(
frame_group=frame_group,
instance_groups=instance_groups,
frame_idx=frame_idx,
instance=instance,
)
if instance_groups is None or len(instance_groups) == 0:
return # Not enough instances for triangulation

# Get the `FrameGroup` of shape M=include x T x N x 2
fg_tensor = frame_group.numpy(instance_groups=instance_groups, pred_as_nan=True)

# Add extra dimension for number of frames
frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2

# Triangulate to one 3D pose per instance
points_3d = triangulate(
p2d=frame_group_tensor,
calib=session.camera_cluster,
excluded_views=frame_group.excluded_views,
) # F x T x N x 3

# Reproject onto all views
pts_reprojected = reproject(
points_3d,
calib=session.camera_cluster,
excluded_views=frame_group.excluded_views,
) # M=include x F=1 x T x N x 2

# Sqeeze back to the original shape
points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2

# Update or create/insert ("upsert") instance points
frame_group.upsert_points(
points=points_reprojected,
instance_groups=instance_groups,
exclude_complete=True,
)

@classmethod
def has_enough_instances(
cls,
frame_group: FrameGroup,
instance_groups: Optional[List[InstanceGroup]],
frame_idx: Optional[int] = None,
instance: Optional[Instance] = None,
) -> Optional[List[InstanceGroup]]:
"""Filters out instance groups without enough instances for triangulation.

Args:
frame_group: The `FrameGroup` object to use.
instance_groups: A list of `InstanceGroup` objects to use.
frame_idx: The frame index to use.
instance: The `Instance` object to use (only used in logging).

Returns:
A list of `InstanceGroup` objects with enough instances for triangulation.
"""

if instance is None:
instance = "" # Just used for logging

if frame_idx is None:
frame_idx = "" # Just used for logging

if len(instance_groups) < 1:
logger.warning(
f"Require at least 1 instance group, but found "
f"{len(frame_group.instance_groups)} for frame group {frame_group} at "
f"frame {frame_idx}."
f"\nSkipping triangulation."
)
return None # No instance groups found

# Assert that there are enough views and instances
instance_groups_to_tri = []
for instance_group in instance_groups:
instances = instance_group.get_instances(frame_group.cams_to_include)
if len(instances) < 2:
# Not enough instances
logger.warning(
f"Not enough instances in {instance_group} for triangulation."
f"\nSkipping instance group."
)
continue
instance_groups_to_tri.append(instance_group)

return instance_groups_to_tri # `InstanceGroup`s with enough instances


def open_website(url: str):
"""Open website in default browser.

Expand Down
2 changes: 1 addition & 1 deletion sleap/gui/dialogs/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _delete(self, lf_inst_list: List[Tuple[LabeledFrame, Instance]]):
for lf, inst in lf_inst_list:
self.context.labels.remove_instance(lf, inst, in_transaction=True)
if not lf.instances:
self.context.labels.remove(lf)
self.context.labels.remove_frame(lf=lf, update_cache=False)

# Update caches since we skipped doing this after each deletion
self.context.labels.update_cache()
Expand Down
69 changes: 53 additions & 16 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
out of sync if the skeleton is manipulated.
"""

import logging
import math

import numpy as np
import cattr

from copy import copy
from typing import Dict, List, Optional, Union, Tuple, ForwardRef
from typing import Dict, ForwardRef, List, Optional, Tuple, Union

import attr
import cattr
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured

import sleap
from sleap.skeleton import Skeleton, Node
from sleap.io.video import Video
from sleap.skeleton import Node, Skeleton

import attr
logger = logging.getLogger(__name__)


class Point(np.record):
Expand All @@ -57,7 +57,6 @@ def __new__(
visible: bool = True,
complete: bool = False,
) -> "Point":

# HACK: This is a crazy way to instantiate at new Point but I can't figure
# out how recarray does it. So I just use it to make matrix of size 1 and
# index in to get the np.record/Point
Expand Down Expand Up @@ -124,7 +123,6 @@ def __new__(
complete: bool = False,
score: float = 0.0,
) -> "PredictedPoint":

# HACK: This is a crazy way to instantiate at new Point but I can't figure
# out how recarray does it. So I just use it to make matrix of size 1 and
# index in to get the np.record/Point
Expand Down Expand Up @@ -184,7 +182,6 @@ def __new__(
aligned=False,
order="C",
) -> "PointArray":

dtype = subtype._record_type.dtype

if dtype is not None:
Expand Down Expand Up @@ -445,12 +442,10 @@ def __attrs_post_init__(self):
# If the user did not pass a points list initialize a point array for future
# points.
if self._points is None or len(self._points) == 0:

# Initialize an empty point array that is the size of the skeleton.
self._points = self._point_array_type.make_default(len(self.skeleton.nodes))

else:

if type(self._points) is dict:
roomrys marked this conversation as resolved.
Show resolved Hide resolved
parray = self._point_array_type.make_default(len(self.skeleton.nodes))
Instance._points_dict_to_array(self._points, parray, self.skeleton)
Expand Down Expand Up @@ -505,9 +500,11 @@ def _points_dict_to_array(
)
try:
parray[skeleton.node_to_index(node)] = point
# parray[skeleton.node_to_index(node.name)] = point
except:
roomrys marked this conversation as resolved.
Show resolved Hide resolved
pass
logger.debug(
f"Could not set point for node {node} in {skeleton} "
f"with point {point}"
)

def _node_to_index(self, node: Union[str, Node]) -> int:
"""Helper method to get the index of a node from its name.
Expand Down Expand Up @@ -720,6 +717,45 @@ def points(self) -> Tuple[Point, ...]:
self._fix_array()
return tuple(point for point in self._points if not point.isnan())

def update_points(self, points: np.ndarray, exclude_complete: bool = False):
"""Update the points in this instance from a numpy.

Args:
points: The new points to update to.
exclude_complete: Whether to update points where Point.complete is True
"""
points_dict = dict()
for point_new, points_old, node_name in zip(
points, self._points, self.skeleton.node_names
):

# Skip if new point is nan or old point is complete
if np.isnan(point_new).any() or (exclude_complete and points_old.complete):
continue

# Grab the x, y from the new point and visible, complete from the old point
x, y = point_new
visible = points_old.visible
complete = points_old.complete

# Create a new point and add to the dict
if type(self._points) == PredictedPointArray:
# TODO(LM): The point score is meant to rate the confidence of the
# prediction, but this method updates from triangulation.
score = points_old.score
point_obj = PredictedPoint(
x=x, y=y, visible=visible, complete=complete, score=score
)
else:
point_obj = Point(x=x, y=y, visible=visible, complete=complete)

# Update the points dict
points_dict[node_name] = point_obj

# Update the points
if len(points_dict) > 0:
Instance._points_dict_to_array(points_dict, self._points, self.skeleton)

def _fix_array(self):
"""Fix PointArray after nodes have been added or removed.

Expand Down Expand Up @@ -1199,7 +1235,6 @@ def make_instance_cattr() -> cattr.Converter:
converter.register_unstructure_hook(PredictedPointArray, lambda x: None)

def unstructure_instance(x: Instance):

# Unstructure everything but the points array, nodes, and frame attribute
d = {
field.name: converter.unstructure(x.__getattribute__(field.name))
Expand Down Expand Up @@ -1380,7 +1415,9 @@ def find(
Returns:
List of instances.
"""
instances = self.instances
instances = sorted(
self.instances, key=lambda inst: isinstance(inst, PredictedInstance)
) # Sort with PredictedInstances last
if user:
instances = list(filter(lambda inst: type(inst) == Instance, instances))
if track != -1: # use -1 since we want to accept None as possible value
Expand Down
Loading
Loading