Skip to content

Commit

Permalink
Merge pull request #188 from catalystneuro/add_set_times_to_segmentat…
Browse files Browse the repository at this point in the history
…ionextractor

`SegmentationExtractor`: add `set_times`
  • Loading branch information
CodyCBakerPhD authored Aug 10, 2022
2 parents 85f77d0 + eaf79a4 commit 66899bf
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat

### Improvements
* Add `frame_to_time` to `SegmentationExtractor`, `get_roi_ids` is now a class method. [PR #187](https://github.com/catalystneuro/roiextractors/pull/187)
* Add `set_times` to `SegmentationExtractor`. [PR #188](https://github.com/catalystneuro/roiextractors/pull/188)

### Fixes

Expand Down
11 changes: 11 additions & 0 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ def get_num_planes(self):
"""
return self._num_planes

def set_times(self, times: ArrayType):
"""Sets the recording times in seconds for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame
"""
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times, dtype=np.float64)

def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
"""Returns the timing of frames in unit of seconds.
Expand Down
43 changes: 40 additions & 3 deletions tests/test_internals/test_testing_tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import unittest
import numpy as np
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal

import numpy as np
from numpy.testing import assert_array_equal

from roiextractors.testing import generate_dummy_segmentation_extractor
from roiextractors.testing import (
generate_dummy_segmentation_extractor,
_assert_iterable_complete,
)


class TestDummySegmentationExtractor(unittest.TestCase):
class TestDummySegmentationExtractor(TestCase):
def setUp(self) -> None:
self.num_rois = 10
self.num_frames = 30
Expand Down Expand Up @@ -79,6 +84,38 @@ def test_passing_parameters(self):
assert segmentation_extractor.get_traces(name="deconvolved").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="neuropil").shape == (self.num_rois, self.num_frames)

def test_set_times(self):
"""Test that set_times sets the times in the expected way."""

segmentation_extractor = generate_dummy_segmentation_extractor()

num_frames = segmentation_extractor.get_num_frames()
sampling_frequency = segmentation_extractor.get_sampling_frequency()

# Check that times have not been set yet
assert segmentation_extractor._times is None

# Set times with an array that has the same length as the number of frames
times_to_set = np.round(np.arange(num_frames) / sampling_frequency, 6)
segmentation_extractor.set_times(times_to_set)

assert_array_equal(segmentation_extractor._times, times_to_set)

_assert_iterable_complete(
iterable=segmentation_extractor._times,
dtypes=np.ndarray,
element_dtypes=np.float64,
shape=(num_frames,),
)

# Set times with an array that is too short
times_to_set = np.round(np.arange(num_frames - 1) / sampling_frequency, 6)
with self.assertRaisesWith(
exc_type=AssertionError,
exc_msg="'times' should have the same length of the number of frames!",
):
segmentation_extractor.set_times(times_to_set)

def test_frame_to_time_no_sampling_frequency(self):
segmentation_extractor = generate_dummy_segmentation_extractor(
sampling_frequency=None,
Expand Down

0 comments on commit 66899bf

Please sign in to comment.