Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(3->2) Add method to update Instances across views in RecordingSession #1279

Open
wants to merge 45 commits into
base: liezl/ars-add-sessions-to-cache
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
190b82c
Update `Instance`s across views in `RecordingSession`
roomrys Apr 18, 2023
4aa285b
Fix-up implementation while manually testing
roomrys Apr 19, 2023
7754a66
Only update incomplete points
roomrys Apr 19, 2023
4dca8be
Lint
roomrys Apr 19, 2023
f5d41b3
Merge branch 'liezl/ars-add-sessions-to-cache' into liezl/asc-initial…
roomrys Apr 27, 2023
619d584
Merge branch 'liezl/ars-add-sessions-to-cache' into liezl/asc-initial…
roomrys Jul 20, 2023
96df7f4
Merge branch 'liezl/ars-add-sessions-to-cache' of https://github.com/…
roomrys Sep 29, 2023
280d48e
Modularize the `RecordingSessions.update_views` function
roomrys Sep 29, 2023
9559da7
Add more informative message to `RecordingSession.update_views`
roomrys Sep 29, 2023
45b8475
Add `require_multiple_views` parameter to `RecordingSession.get_insta…
roomrys Sep 29, 2023
ba2df88
(4->3) Fix (de)serialization of `RecordingSession` and add Multiview …
roomrys Sep 29, 2023
35c1521
Add test for
Oct 4, 2023
8098ae1
Fix spelling
Oct 4, 2023
addc8be
Add debugging messages and raise ValueError
Oct 4, 2023
74a8f83
Add test for `get_instances_accross_views`
Oct 4, 2023
4f65679
Lint
Oct 4, 2023
243ad3c
Modularize
Oct 4, 2023
bb0fda1
Modularize RecordingSession.update_views
Oct 4, 2023
f203953
Add comments on expected array size
Oct 4, 2023
2bc5c47
Breakout recording session tests and test reprojection
Oct 4, 2023
e0a57f2
Add test for RecordingSession.update_instances
Oct 4, 2023
a1965e9
Add keyword arguments to `update_views` call
roomrys Oct 9, 2023
62f4f7f
Add docstring to update_instances
roomrys Oct 9, 2023
9b67128
Fix update_instances test
roomrys Oct 9, 2023
99ef17e
Update debug/warning messages
roomrys Oct 9, 2023
38599f4
Add test for remove_video and labels cache
roomrys Oct 9, 2023
df1bc72
Add test for update_views
roomrys Oct 10, 2023
e85e4fb
Pass in excluded views to triangulate and reproject
roomrys Oct 10, 2023
7105656
Small fix in test_recording_session
roomrys Oct 10, 2023
bffeca8
Ensure instances are always passed in with correct order
roomrys Oct 10, 2023
c293d73
Add `TriangulateSession` command
roomrys Oct 13, 2023
001c849
Add tests for `TriangulateSession`
roomrys Oct 13, 2023
06dff9f
Handover call to triangulate to `TriangulateSession`
roomrys Oct 13, 2023
4706fb2
Remove triangulate functionality from `RecordingSession`
roomrys Oct 13, 2023
4994f04
Merge branch 'liezl/ars-add-sessions-to-cache' of https://github.com/…
roomrys Oct 19, 2023
7cc7a5a
Fix verify instances bug via typo
roomrys Dec 5, 2023
a2326d3
Better warning messages (w/o dialog pop-ups!)
roomrys Dec 6, 2023
2f56a1c
(3a -> 3) Add method to match instances across views (#1579)
roomrys Apr 18, 2024
712fc70
Remove unused imports
Apr 18, 2024
b1d2372
(3a -> 3) Serialize `FrameGroup` and `InstanceGroup` (update `Recordi…
roomrys Apr 23, 2024
0d26728
(3a -> 3) Add `FrameGroup` fixture (#1753)
roomrys Apr 23, 2024
8d3e232
(3a -> 3) Debug `TriangulateSession` command (#1755)
roomrys Apr 25, 2024
507eefa
(3a -> 3) Debug missing `PredictedInstance`s (not displaying in GUI) …
roomrys Apr 25, 2024
55c6fd4
Test `InstanceGroup` and `FrameGroup` (#1759)
roomrys Apr 30, 2024
07ea17b
Handle case when upserting_instances with no instances in at least on…
May 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,8 +1224,27 @@ def plotFrame(self, *args, **kwargs):

def _after_plot_update(self, frame_idx):
"""Run after plot is updated, but stay on same frame."""

video = self.state["video"]
roomrys marked this conversation as resolved.
Show resolved Hide resolved
instance = self.state["instance"]

# Redraw trails
overlay: TrackTrailOverlay = self.overlays["trails"]
overlay.redraw(self.state["video"], frame_idx)
overlay.redraw(video, frame_idx)

# Replot connected views for multi-camera projects
session = self.labels.get_session(video)
if session is not None and instance is not None:
track = instance.track
session.update_views(
frame_idx,
track,
)

# TODO(LM): Move this to a dock for replotting on demand
# Replot instance nodes (but keep same QtInstance)
for inst in self.player.view.instances:
inst.updatePoints(complete=False, user_change=False)
Copy link

@coderabbitai coderabbitai bot Sep 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The instances are updated using the updatePoints method. This could potentially be a performance bottleneck if there are many instances. Consider optimizing this part of the code if performance becomes an issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you propose we optimize this part of the code?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@roomrys, one way to optimize this part of the code could be to only update the instances that have actually changed. Currently, it seems like all instances are updated regardless of whether they've changed or not.

Here's a rough idea of how you might implement this:

# Keep track of the last frame index and track we updated
last_frame_idx = None
last_track = None

def _after_plot_update(self, frame_idx):
    # ...

    if session is not None and instance is not None:
        track = instance.track

        # Only update views if the frame index or track has changed
        if frame_idx != self.last_frame_idx or track != self.last_track:
            session.update_views(frame_idx, track)
            self.last_frame_idx = frame_idx
            self.last_track = track

            # Replot instance nodes (but keep same QtInstance)
            for inst in self.player.view.instances:
                inst.updatePoints(complete=False, user_change=False)

This way, we avoid unnecessary updates when the frame index and track haven't changed. However, please note that this is just a suggestion and may need to be adapted based on your specific use case and codebase.


def _after_plot_change(self, player, frame_idx, selected_inst):
"""Called each time a new frame is drawn."""
Expand Down
21 changes: 20 additions & 1 deletion sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,23 @@ def points(self) -> Tuple[Point, ...]:
self._fix_array()
return tuple(point for point in self._points if not point.isnan())

def update_points(self, points: np.ndarray, exclude_complete: bool = False):
"""Update the points in this instance from a numpy.

Args:
points: The new points to update to.
exclude_complete: Whether to update points where Point.complete is True
"""
points_dict = dict()
for point_new, points_old, node_name in zip(
points, self._points, self.skeleton.node_names
):
if np.isnan(point_new).any() or (exclude_complete and points_old.complete):
continue
points_dict[node_name] = Point(x=point_new[0], y=point_new[1])
if len(points_dict) > 0:
Instance._points_dict_to_array(points_dict, self._points, self.skeleton)

def _fix_array(self):
"""Fix PointArray after nodes have been added or removed.

roomrys marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -1380,7 +1397,9 @@ def find(
Returns:
List of instances.
"""
instances = self.instances
instances = sorted(
self.instances, key=lambda inst: isinstance(inst, PredictedInstance)
) # Sort with PredictedInstances last
if user:
instances = list(filter(lambda inst: type(inst) == Instance, instances))
if track != -1: # use -1 since we want to accept None as possible value
Expand Down
226 changes: 214 additions & 12 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Module for storing information for camera groups."""
import logging
from pathlib import Path
import tempfile
import cattr
import toml

import cattr
import numpy as np

from pathlib import Path
from typing import List, Optional, Union, Iterator, Any, Dict, Tuple

from aniposelib.cameras import Camera, FisheyeCamera, CameraGroup
from attrs import define, field
from attrs.validators import deep_iterable, instance_of
import numpy as np


from sleap.util import deep_iterable_converter
from sleap_anipose import triangulate, reproject

# from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer
from sleap.io.video import Video
from sleap.util import deep_iterable_converter


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -549,6 +550,206 @@ def remove_video(self, video: Video):
if self.labels is not None and self.labels.get_session(video) is not None:
self.labels.remove_session_video(self, video)

def get_videos_from_selected_cameras(
self, cams_to_include: Optional[List[Camcorder]] = None
) -> List[Video]:
"""Get all `Video`s from selected `Camcorder`s.

Args:
cams_to_include: List of `Camcorder`s to include. Defualt is all.

Returns:
List of `Video`s.
"""

# If no `Camcorder`s specified, then return all videos in session
if cams_to_include is None:
return self.videos

# Get all videos from selected `Camcorder`s
videos: List[Video] = []
for cam in cams_to_include:
video = self.get_video(cam)
if video is not None:
videos.append(video)

return videos

def get_all_views_at_frame(
self,
frame_idx,
cams_to_include: Optional[List[Camcorder]] = None,
) -> List["LabeledFrame"]:
"""Get all views at a given frame index.

Args:
frame_idx: Frame index to get views from (0-indexed).
cams_to_include: List of `Camcorder`s to include. Default is all.

Returns:
List of `LabeledFrame` objects.
"""

views: List["LabeledFrame"] = []

videos = self.get_videos_from_selected_cameras(cams_to_include=cams_to_include)
for video in videos:
lfs: List["LabeledFrame"] = self.labels.get((video, [frame_idx]))
if len(lfs) == 0:
logger.debug(
"No LabeledFrames found for video " f"{video} at {frame_idx}."
)
continue

lf = lfs[0]
if len(lf.instances) == 0:
logger.warning(
f"No Instances found for {lf}."
" There should be not empty LabeledFrames."
)
continue

views.append(lf)

return views

def get_instances_across_views(
self,
frame_idx: int,
cams_to_include: Optional[List[Camcorder]] = None,
track: Optional["Track"] = None,
require_multiple_views: bool = False,
) -> List["LabeledFrame"]:
"""Get all `Instances` accross all views at a given frame index.

Args:
frame_idx: Frame index to get instances from (0-indexed).
cams_to_include: List of `Camcorder`s to include. Default is all.
track: `Track` object used to find instances accross views. Default is None.
require_multiple_views: If True, then raise and error if one or less views
or instances are found.

Returns:
List of `Instances` objects.

Raises:
ValueError if require_multiple_view is true and one or less views or
instances are found.
"""

views: List["LabeledFrame"] = []
instances: List["Instances"] = []

# Get all views at this frame index
views = self.get_all_views_at_frame(
frame_idx=frame_idx,
cams_to_include=cams_to_include,
)

# If not enough views, then raise error
if len(views) <= 1 and require_multiple_views:
raise ValueError(
"One or less views found for frame "
f"{frame_idx} in {self.camera_cluster}."
)

# Find all instance accross all views
instances: List["Instances"] = []
for lf in views:
insts = lf.find(track=track)
if len(insts) > 0:
instances.append(insts[0])

# If not enough instances for multiple views, then raise error
if len(instances) <= 1 and require_multiple_views:
raise ValueError(
"One or less instances found for frame "
f"{frame_idx} in {self.camera_cluster}."
)

return instances

def calculate_reprojected_points(self, instances: List["Instances"]):
"""Triangulate and reproject instance coordinates.

Args:
instances: List of `Instances` objects.

Returns:
List of reprojected instance coordinates. Each element in the list is a
numpy array of shape (1, N, 2) where N is the number of nodes.
"""

# Gather instances into M x F x T x N x 2 arrays
# (M = # views, F = # frames = 1, T = # tracks = 1, N = # nodes, 2 = x, y)
inst_coords = np.stack([inst.numpy() for inst in instances], axis=0)
inst_coords = np.expand_dims(inst_coords, axis=1)
inst_coords = np.expand_dims(inst_coords, axis=1)
points_3d = triangulate(p2d=inst_coords, calib=self.camera_cluster)

# Update the views with the new 3D points
inst_coords_reprojected = reproject(points_3d, calib=self.camera_cluster)
insts_coords_list: List[np.ndarray] = np.split(
inst_coords_reprojected.squeeze(), inst_coords_reprojected.shape[0], axis=0
)

return insts_coords_list

def update_views(
self,
frame_idx: int,
cams_to_include: Optional[List[Camcorder]] = None,
track: Optional["Track"] = None,
):
"""Update the views of the `RecordingSession`.

Args:
frame_idx: Frame index to update (0-indexed).
cams_to_include: List of `Camcorder`s to include. Default is all.
track: `Track` object used to find instances accross views for updating.

Returns:
None
"""

# If not enough `Camcorder`s available/specified, then return
if (cams_to_include is not None and len(cams_to_include) <= 1) or (
len(self.videos) <= 1
):
logger.warning(
"One or less cameras available. "
"Multiple cameras needed to triangulate. "
"Skipping triangulation and reprojection."
)
return

# Get all views at this frame index
try:
instances = self.get_instances_across_views(
frame_idx,
cams_to_include=cams_to_include,
track=track,
require_multiple_views=True,
)
except ValueError:
# If not enough views or instances, then return
logger.warning(
"One or less instances found for frame "
f"{frame_idx} in {self.camera_cluster}. "
"Multiple instances accross multiple views needed to triangulate. "
"Skipping triangulation and reprojection."
)
return

# Triangulate, reproject, and update coordinates
insts_coords_list: List[np.ndarray] = self.calculate_reprojected_points(
instances
)
for inst, inst_coord in zip(instances, insts_coords_list):
inst.update_points(
inst_coord[0], exclude_complete=True
) # inst_coord is (1, N, 2)

roomrys marked this conversation as resolved.
Show resolved Hide resolved
def __attrs_post_init__(self):
self.camera_cluster.add_session(self)

Expand All @@ -559,7 +760,6 @@ def __len__(self):
return len(self.videos)

def __getattr__(self, attr: str) -> Any:

"""Try to find the attribute in the camera_cluster next."""
return getattr(self.camera_cluster, attr)

Expand Down Expand Up @@ -603,7 +803,10 @@ def __getitem__(
)

def __repr__(self):
return f"{self.__class__.__name__}(camera_cluster={self.camera_cluster})"
return (
f"{self.__class__.__name__}(videos:{len(self.videos)},"
f"camera_cluster={self.camera_cluster})"
)
roomrys marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def load(
Expand Down Expand Up @@ -658,7 +861,6 @@ def to_session_dict(self, video_to_idx: Dict[Video, int]) -> dict:
# and value is video index from `Labels.videos`
camcorder_to_video_idx_map = {}
for cam_idx, camcorder in enumerate(self.camera_cluster):

# Skip if Camcorder is not linked to any Video
if camcorder not in self._video_by_camcorder:
continue
Expand All @@ -668,7 +870,7 @@ def to_session_dict(self, video_to_idx: Dict[Video, int]) -> dict:
video_idx = video_to_idx.get(video, None)

if video_idx is not None:
camcorder_to_video_idx_map[cam_idx] = video_idx
camcorder_to_video_idx_map[str(cam_idx)] = str(video_idx)
else:
logger.warning(
f"Video {video} not found in `Labels.videos`. "
Expand Down Expand Up @@ -704,8 +906,8 @@ def from_session_dict(
# Retrieve all `Camcorder` and `Video` objects, then add to `RecordingSession`
camcorder_to_video_idx_map = session_dict["camcorder_to_video_idx_map"]
for cam_idx, video_idx in camcorder_to_video_idx_map.items():
camcorder = session.camera_cluster.cameras[cam_idx]
video = videos_list[video_idx]
camcorder = session.camera_cluster.cameras[int(cam_idx)]
video = videos_list[int(video_idx)]
session.add_video(video, camcorder)

return session
Expand Down
1 change: 1 addition & 0 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def __repr__(self) -> str:
"Labels("
f"labeled_frames={len(self.labeled_frames)}, "
f"videos={len(self.videos)}, "
f"sessions={len(self.sessions)}, "
f"skeletons={len(self.skeletons)}, "
f"tracks={len(self.tracks)}"
")"
Expand Down
6 changes: 3 additions & 3 deletions sleap/io/format/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def read_headers(

# These items are stored in separate lists because the metadata group got to be
# too big.
for key in ("videos", "tracks", "suggestions"):
for key in ("videos", "tracks", "suggestions", "sessions"):
hdf5_key = f"{key}_json"
if hdf5_key in f:
items = [json_loads(item_json) for item_json in f[hdf5_key]]
Expand Down Expand Up @@ -325,7 +325,7 @@ def append_unique(old, new):
if not append:
# These items are stored in separate lists because the metadata
# group got to be too big.
for key in ("videos", "tracks", "suggestions"):
for key in ("videos", "tracks", "suggestions", "sessions"):
# Convert for saving in hdf5 dataset
data = [np.string_(json_dumps(item)) for item in d[key]]

Expand All @@ -341,7 +341,7 @@ def append_unique(old, new):
meta_group.attrs["json"] = np.string_(json_dumps(d))

# FIXME: We can probably construct these from attrs fields
# We will store Instances and PredcitedInstances in the same
# We will store Instances and PredictedInstances in the same
# table. instance_type=0 or Instance and instance_type=1 for
# PredictedInstance, score will be ignored for Instances.
instance_dtype = np.dtype(
Expand Down
Loading
Loading