Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 23, 2024
1 parent 12081d7 commit d0af4e2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 40 deletions.
27 changes: 9 additions & 18 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,10 @@ def _update_from_labels(self, merge: bool = False):

if len(self.nodes) == 0:
self.nodes = list(
set().union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

if merge:

# remove duplicate skeletons during merge
skeletons = [self.skeletons[0]]
for lf in self.labels:
Expand All @@ -503,13 +500,11 @@ def _update_from_labels(self, merge: bool = False):

# updates nodes after removing duplicate skeletons
self.nodes = list(
set().union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

# Ditto for tracks, a pattern is emerging here
if merge or len(self.tracks) == 0:
if len(self.tracks) == 0:
# Get tracks from any Instances or PredictedInstances
other_tracks = {
instance.track
Expand All @@ -531,13 +526,11 @@ def _update_from_labels(self, merge: bool = False):
)

# Get list of other tracks not already in track list
# new_tracks = list(other_tracks - set(self.tracks))
new_tracks = []
if not self.tracks:
new_tracks = list(other_tracks)
else:
for t in other_tracks:
for track in self.tracks:
new_tracks = list(other_tracks - set(self.tracks))
if self.tracks and merge:
new_tracks = [self.tracks[0]]
for track in other_tracks:
for t in new_tracks:
if not track.matches(t):
new_tracks.append(t)

Expand Down Expand Up @@ -1929,9 +1922,7 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]:
# We shouldn't have to do this here, but for some reason we're missing nodes
# which are in the skeleton but don't have points (in the first instance?).
self.nodes = list(
set().union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

# Register some unstructure hooks since we don't want complete deserialization
Expand Down
2 changes: 1 addition & 1 deletion tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_import_labels_from_dlc_folder():
assert len(labels.videos) == 2
assert len(labels.skeletons) == 1
assert len(labels.nodes) == 3
assert len(labels.tracks) == 3
assert len(labels.tracks) == 2

assert set(
[fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames]
Expand Down
21 changes: 0 additions & 21 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,27 +728,6 @@ def test_unify_skeletons():
labels.to_dict()


def test_dont_unify_skeletons():
vid = Video.from_filename("foo.mp4")

skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
# skeleton_b.add_node("foo")

lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)])
lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)])

labels = Labels(labeled_frames=[lf_a])
labels.extend_from([lf_b], unify=False)
ids = skeleton_ids_from_label_instances(labels)

# Make sure we still have two distinct skeleton objects
assert len(set(ids)) == 2

# Make sure we can serialize this
labels.to_dict()


def test_instance_access():
labels = Labels()

Expand Down

0 comments on commit d0af4e2

Please sign in to comment.