Skip to content

Commit

Permalink
Import DLC with uniquebodyparts, add Tracks (#1562)
Browse files Browse the repository at this point in the history
* Import DLC with uniquebodyparts, add Tracks

* add tests

* correct tests
  • Loading branch information
getzze authored Oct 20, 2023
1 parent 5c3441c commit 46bd21d
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 27 deletions.
52 changes: 38 additions & 14 deletions sleap/io/format/deeplabcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import numpy as np
import pandas as pd

from typing import List, Optional
from typing import List, Optional, Dict

from sleap import Labels, Video, Skeleton
from sleap.instance import Instance, LabeledFrame, Point
from sleap.instance import Instance, LabeledFrame, Point, Track
from sleap.util import find_files_by_suffix

from .adaptor import Adaptor, SleapObjectType
Expand Down Expand Up @@ -119,11 +119,12 @@ def read_frames(

# Pull out animal and node names from the columns.
start_col = 3 if is_new_format else 1
animal_names = []
tracks: Dict[str, Optional[Track]] = {}
node_names = []
for animal_name, node_name, _ in data.columns[start_col:][::2]:
if animal_name not in animal_names:
animal_names.append(animal_name)
# Keep the starting frame index for each individual/track
if animal_name not in tracks.keys():
tracks[animal_name] = None
if node_name not in node_names:
node_names.append(node_name)

Expand Down Expand Up @@ -177,23 +178,33 @@ def read_frames(

instances = []
if is_multianimal:
for animal_name in animal_names:
for animal_name in tracks.keys():
any_not_missing = False
# Get points for each node.
instance_points = dict()
for node in node_names:
x, y = (
data[(animal_name, node, "x")][i],
data[(animal_name, node, "y")][i],
)
if (animal_name, node) in data.columns:
x, y = (
data[(animal_name, node, "x")][i],
data[(animal_name, node, "y")][i],
)
else:
x, y = np.nan, np.nan
instance_points[node] = Point(x, y)
if ~(np.isnan(x) and np.isnan(y)):
any_not_missing = True

if any_not_missing:
# Create track
if tracks[animal_name] is None:
tracks[animal_name] = Track(spawned_on=i, name=animal_name)
# Create instance with points.
instances.append(
Instance(skeleton=skeleton, points=instance_points)
Instance(
skeleton=skeleton,
points=instance_points,
track=tracks[animal_name],
)
)
else:
# Get points for each node.
Expand Down Expand Up @@ -270,6 +281,8 @@ def read(
skeleton = Skeleton()
if project_data.get("multianimalbodyparts", False):
skeleton.add_nodes(project_data["multianimalbodyparts"])
if "uniquebodyparts" in project_data:
skeleton.add_nodes(project_data["uniquebodyparts"])
else:
skeleton.add_nodes(project_data["bodyparts"])

Expand Down Expand Up @@ -298,13 +311,24 @@ def read(
# If subdirectory is foo, we look for foo.mp4 in videos dir.

shortname = os.path.split(data_subdir)[-1]
video_path = os.path.join(videos_dir, f"{shortname}.mp4")

if os.path.exists(video_path):
video_path = None
if os.path.exists(videos_dir):
with os.scandir(videos_dir) as file_iterator:
for file in file_iterator:
if not file.is_file():
continue
if os.path.splitext(file.name)[0] != shortname:
continue
video_path = os.path.join(videos_dir, file.name)
break

if video_path is not None and os.path.exists(video_path):
video = Video.from_filename(video_path)
else:
# When no video is found, the individual frame images
# stored in the labeled data subdir will be used.
if video_path is None:
video_path = os.path.join(videos_dir, f"{shortname}.mp4")
print(
f"Unable to find {video_path} so using individual frame images."
)
Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def summary():
for gpu in all_gpus:
print(f" Device: {gpu.name}")
print(f" Available: {gpu in gpus}")
print(f" Initalized: {is_initialized(gpu)}")
print(f" Initialized: {is_initialized(gpu)}")
print(
f" Memory growth: {tf.config.experimental.get_memory_growth(gpu)}"
)
Expand Down
14 changes: 7 additions & 7 deletions tests/data/dlc/labeled-data/video/CollectedData_LM.csv
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
scorer,,,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM,LM
individuals,,,individual1,individual1,individual1,individual1,individual1,individual1,individual2,individual2,individual2,individual2,individual2,individual2
bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C
coords,,,x,y,x,y,x,y,x,y,x,y,x,y
labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11
labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21
scorer,,,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer
individuals,,,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single
bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E
coords,,,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y
labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,,
labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25
labeled-data,video,img002.png,,,,,,,,,,,,
labeled-data,video,img003.png,22,23,24,25,26,27,,,,,,
labeled-data,video,img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35
8 changes: 8 additions & 0 deletions tests/data/dlc/labeled-data/video/maudlc_testdata.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer
individuals,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single
bodyparts,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E
coords,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y
labeled-data/video/img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,,
labeled-data/video/img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25
labeled-data/video/img002.png,,,,,,,,,,,,
labeled-data/video/img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35
8 changes: 8 additions & 0 deletions tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
scorer,,,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer,Scorer
individuals,,,Animal1,Animal1,Animal1,Animal1,Animal1,Animal1,Animal2,Animal2,Animal2,Animal2,Animal2,Animal2,single,single,single,single
bodyparts,,,A,A,B,B,C,C,A,A,B,B,C,C,D,D,E,E
coords,,,x,y,x,y,x,y,x,y,x,y,x,y,x,y,x,y
labeled-data,video,img000.png,0,1,2,3,4,5,6,7,8,9,10,11,,,,
labeled-data,video,img001.png,12,13,,,15,16,17,18,,,20,21,22,23,24,25
labeled-data,video,img002.png,,,,,,,,,,,,
labeled-data,video,img003.png,26,27,28,29,30,31,,,,,,,32,33,34,35
8 changes: 5 additions & 3 deletions tests/data/dlc/madlc_230_config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Project definitions (do not edit)
Task: madlc_2.3.0
Task: maudlc_2.3.0
scorer: LM
date: Mar1
multianimalproject: true
identity: false

# Project path (change when moving around)
project_path: D:\social-leap-estimates-animal-poses\pull-requests\sleap\tests\data\dlc\madlc_testdata_v3
project_path: D:\social-leap-estimates-animal-poses\pull-requests\sleap\tests\data\dlc\maudlc_testdata_v3

# Annotation data set configuration (and individual video cropping parameters)
video_sets:
Expand All @@ -16,7 +16,9 @@ individuals:
- individual1
- individual2
- individual3
uniquebodyparts: []
uniquebodyparts:
- D
- E
multianimalbodyparts:
- A
- B
Expand Down
2 changes: 1 addition & 1 deletion tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_import_labels_from_dlc_folder():
assert len(labels.videos) == 2
assert len(labels.skeletons) == 1
assert len(labels.nodes) == 3
assert len(labels.tracks) == 0
assert len(labels.tracks) == 3

assert set(
[fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames]
Expand Down
73 changes: 72 additions & 1 deletion tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def test_matching_adaptor(centered_pair_predictions_hdf5_path):
[
"tests/data/dlc/labeled-data/video/madlc_testdata.csv",
"tests/data/dlc/labeled-data/video/madlc_testdata_v2.csv",
"tests/data/dlc/madlc_230_config.yaml",
],
)
def test_madlc(test_data):
Expand Down Expand Up @@ -232,6 +231,78 @@ def test_madlc(test_data):
assert labels[2].frame_idx == 3


@pytest.mark.parametrize(
"test_data",
[
"tests/data/dlc/labeled-data/video/maudlc_testdata.csv",
"tests/data/dlc/labeled-data/video/maudlc_testdata_v2.csv",
"tests/data/dlc/madlc_230_config.yaml",
],
)
def test_maudlc(test_data):
labels = read(
test_data,
for_object="labels",
as_format="deeplabcut",
)

assert labels.skeleton.node_names == ["A", "B", "C", "D", "E"]
assert len(labels.videos) == 1
assert len(labels.video.filenames) == 4
assert labels.videos[0].filenames[0].endswith("img000.png")
assert labels.videos[0].filenames[1].endswith("img001.png")
assert labels.videos[0].filenames[2].endswith("img002.png")
assert labels.videos[0].filenames[3].endswith("img003.png")

# Assert frames without any coor are not labeled
assert len(labels) == 3

# Assert number of instances per frame is correct
assert len(labels[0]) == 2
assert len(labels[1]) == 3
assert len(labels[2]) == 2

assert_array_equal(
labels[0][0].numpy(),
[[0, 1], [2, 3], [4, 5], [np.nan, np.nan], [np.nan, np.nan]],
)
assert_array_equal(
labels[0][1].numpy(),
[[6, 7], [8, 9], [10, 11], [np.nan, np.nan], [np.nan, np.nan]],
)
assert_array_equal(
labels[1][0].numpy(),
[[12, 13], [np.nan, np.nan], [15, 16], [np.nan, np.nan], [np.nan, np.nan]],
)
assert_array_equal(
labels[1][1].numpy(),
[[17, 18], [np.nan, np.nan], [20, 21], [np.nan, np.nan], [np.nan, np.nan]],
)
assert_array_equal(
labels[1][2].numpy(),
[[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan], [22, 23], [24, 25]],
)
assert_array_equal(
labels[2][0].numpy(),
[[26, 27], [28, 29], [30, 31], [np.nan, np.nan], [np.nan, np.nan]],
)
assert_array_equal(
labels[2][1].numpy(),
[[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan], [32, 33], [34, 35]],
)
assert labels[2].frame_idx == 3

# Assert tracks are correct
assert len(labels.tracks) == 3
sorted_animals = sorted(["Animal1", "Animal2", "single"])
assert sorted([t.name for t in labels.tracks]) == sorted_animals
for t in labels.tracks:
if t.name == "single":
assert t.spawned_on == 1
else:
assert t.spawned_on == 0


@pytest.mark.parametrize(
"test_data",
[
Expand Down

0 comments on commit 46bd21d

Please sign in to comment.