Skip to content

Commit

Permalink
Add option to export to CSV via sleap-convert and API (#1730)
Browse files Browse the repository at this point in the history
* Add csv as a format option

* Add analysis to format

* Add csv suffix to output path

* Add condition for csv analysis file

* Add export function to Labels class

* delete print statement

* lint

* Add `analysis.csv` as parametrize input for `sleap-convert` tests

* test `export_csv` method added to `Labels` class

* black formatting

* use `Path` to construct filename

* add `analysis.csv` to cli guide for `sleap-convert`

---------

Co-authored-by: Talmo Pereira <[email protected]>
  • Loading branch information
eberrigan and talmo authored Apr 9, 2024
1 parent d4ad3bb commit f0c44c0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ optional arguments:
analysis file for the latter video is given a default name.
--format FORMAT Output format. Default ('slp') is SLEAP dataset;
'analysis' results in analysis.h5 file; 'analysis.nix' results
in an analysis nix file; 'h5' or 'json' results in SLEAP dataset
in an analysis nix file; 'analysis.csv' results
in an analysis csv file; 'h5' or 'json' results in SLEAP dataset
with specified file format.
--video VIDEO Path to video (if needed for conversion).
```
Expand Down
22 changes: 21 additions & 1 deletion sleap/io/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_parser():
help="Output format. Default ('slp') is SLEAP dataset; "
"'analysis' results in analysis.h5 file; "
"'analysis.nix' results in an analysis nix file;"
"'analysis.csv' results in an analysis csv file;"
"'h5' or 'json' results in SLEAP dataset "
"with specified file format.",
)
Expand Down Expand Up @@ -135,7 +136,12 @@ def main(args: list = None):
outnames = [path for path in args.outputs]
if len(outnames) < len(vids):
# if there are less outnames provided than videos to convert...
out_suffix = "nix" if "nix" in args.format else "h5"
if "nix" in args.format:
out_suffix = "nix"
elif "csv" in args.format:
out_suffix = "csv"
else:
out_suffix = "h5"
fn = args.input_path
fn = re.sub("(\.json(\.zip)?|\.h5|\.slp)$", "", fn)
fn = PurePath(fn)
Expand All @@ -158,6 +164,20 @@ def main(args: list = None):
NixAdaptor.write(outname, labels, args.input_path, video)
except ValueError as e:
print(e.args[0])

elif "csv" in args.format:
from sleap.info.write_tracking_h5 import main as write_analysis

for video, output_path in zip(vids, outnames):
write_analysis(
labels,
output_path=output_path,
labels_path=args.input_path,
all_frames=True,
video=video,
csv=True,
)

else:
from sleap.info.write_tracking_h5 import main as write_analysis

Expand Down
13 changes: 13 additions & 0 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,19 @@ def export(self, filename: str):

SleapAnalysisAdaptor.write(filename, self)

def export_csv(self, filename: str):
"""Export labels to CSV format.
Args:
filename: Output path for the CSV format file.
Notes:
This will write the contents of the labels out as a CSV file.
"""
from sleap.io.format.csv import CSVAdaptor

CSVAdaptor.write(filename, self)

def export_nwb(
self,
filename: str,
Expand Down
4 changes: 2 additions & 2 deletions tests/io/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest


@pytest.mark.parametrize("format", ["analysis", "analysis.nix"])
@pytest.mark.parametrize("format", ["analysis", "analysis.nix", "analysis.csv"])
def test_analysis_format(
min_labels_slp: Labels,
min_labels_slp_path: Labels,
Expand All @@ -27,7 +27,7 @@ def generate_filenames(paths, format="analysis"):
labels_path = str(slp_path)
fn = re.sub("(\\.json(\\.zip)?|\\.h5|\\.slp)$", "", labels_path)
fn = PurePath(fn)
out_suffix = "nix" if "nix" in format else "h5"
out_suffix = "nix" if "nix" in format else "csv" if "csv" in format else "h5"
default_names = [
default_analysis_filename(
labels=labels,
Expand Down
44 changes: 44 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pandas as pd
import pytest
import numpy as np
from pathlib import Path, PurePath

import sleap
from sleap.info.write_tracking_h5 import get_nodes_as_np_strings
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
Expand Down Expand Up @@ -1559,3 +1561,45 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir):
# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(centered_pair_predictions, read_labels)


@pytest.mark.parametrize(
"labels_fixture_name",
[
"centered_pair_labels",
"centered_pair_predictions",
"min_labels",
"min_labels_slp",
"min_labels_robot",
],
)
def test_export_csv(labels_fixture_name, tmpdir, request):
# Retrieve Labels fixture by name
labels_fixture = request.getfixturevalue(labels_fixture_name)

# Generate the filename for the CSV file
csv_filename = Path(tmpdir) / (labels_fixture_name + "_export.csv")

# Export to CSV file
labels_fixture.export_csv(str(csv_filename))

# Assert that the CSV file was created
assert csv_filename.is_file(), f"CSV file '{csv_filename}' was not created"


def test_exported_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path):
# Construct the filename for the CSV file
filename_csv = Path(tmpdir) / "minimal_instance_predictions_export.csv"
labels = min_labels_slp
# Export to CSV file
labels.export_csv(filename_csv)
# Read the CSV file
labels_csv = pd.read_csv(filename_csv)

# Read the csv file fixture
csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path)

assert labels_csv.equals(csv_predictions)

# check number of cols
assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3

0 comments on commit f0c44c0

Please sign in to comment.