From f2eb4055c55130e532bb00fdca26ceca7bb17ac1 Mon Sep 17 00:00:00 2001 From: Cody Baker Date: Sun, 28 Aug 2022 13:43:30 -0400 Subject: [PATCH 1/6] added frame slicing to segmentation; added tests --- src/roiextractors/segmentationextractor.py | 120 +++++++++++++++++- src/roiextractors/testing.py | 1 + .../test_frame_slice_segmentation.py | 101 +++++++++++++++ 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 tests/test_internals/test_frame_slice_segmentation.py diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 04adf6de..6701a2c6 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import Union +from typing import Union, Optional, Tuple import numpy as np +from numpy.typing import ArrayLike from .extraction_tools import ArrayType, IntType, FloatType from .extraction_tools import _pixel_mask_extractor @@ -159,6 +160,10 @@ def get_image_size(self) -> ArrayType: """ pass + def frame_slice(self, start_frame, end_frame): + """Return a new SegmentationExtractor ranging from the start_frame to the end_frame.""" + return FrameSliceSegmentationExtractor(parent_imaging=self, start_frame=start_frame, end_frame=end_frame) + def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"): """ Return RoiResponseSeries @@ -319,3 +324,116 @@ def write_segmentation(segmentation_extractor, save_path, overwrite=False): If True, the file is overwritten if existing (default False) """ raise NotImplementedError + + +class FrameSliceSegmentationExtractor(SegmentationExtractor): + """ + Class to get a lazy frame slice. + + Do not use this class directly but use `.frame_slice(...)` + """ + + extractor_name = "FrameSliceSegmentationExtractor" + is_writable = True + + def __init__( + self, + parent_segmentation: SegmentationExtractor, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + ): + self._parent_segmentation = parent_segmentation + self._start_frame = start_frame or 0 + self._end_frame = end_frame or self._parent_segmentation.get_num_frames() + self._num_frames = self._end_frame - self._start_frame + + self._image_masks = self._parent_segmentation._image_masks + + parent_size = self._parent_segmentation.get_num_frames() + if start_frame is None: + start_frame = 0 + else: + assert 0 <= start_frame < parent_size + if end_frame is None: + end_frame = parent_size + else: + assert 0 < end_frame <= parent_size + assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!" + + super().__init__() + + def get_accepted_list(self) -> list: + return self._parent_segmentation.get_accepted_list() + + def get_rejected_list(self) -> list: + return self._parent_segmentation.get_rejected_list() + + def get_traces( + self, + roi_ids: Optional[ArrayLike[int]] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + name: str = "raw", + ) -> np.ndarray: + start_frame = min(start_frame, self._num_frames) + end_frame = min(end_frame, self._num_frames) + return self._parent_segmentation.get_trace( + roi_ids=roi_ids, + start_frame=start_frame + self._start_frame, + end_frame=end_frame + self._start_frame, + name=name, + ) + + def get_traces_dict(self): + if not isinstance(self._parent_segmentation._roi_response_raw, np.ndarray): + raise NotImplementedError( + "The `get_traces_dict` method for SubFrameSegementations does not yet support underlying trace types " + "other than numpy arrays (which includes memory maps)." + ) + + def _safe_subtrace(trace: Optional[np.ndarray], start_frame: int, end_frame: int): + if trace is None: + return None + return trace[:, start_frame:end_frame] + + return dict( + raw=_safe_subtrace( + trace=self._parent_segmentation._roi_response_raw, start=self._start_frame, end=self._end_frame + ), + dff=_safe_subtrace( + trace=self._parent_segmentation._roi_response_dff, start=self._start_frame, end=self._end_frame + ), + neuropil=_safe_subtrace( + trace=self._parent_segmentation._roi_response_neuropil, start=self._start_frame, end=self._end_frame + ), + deconvolved=_safe_subtrace( + trace=self._parent_segmentation._roi_response_deconvolved, start=self._start_frame, end=self._end_frame + ), + ) + + def get_image_size(self) -> Tuple[int, int]: + return tuple(self._parent_imaging.get_image_size()) + + def get_num_frames(self) -> int: + return self._num_frames + + def get_num_rois(self): + return self._parent_segmentation.get_num_rois() + + def get_images_dict(self) -> dict: + return self._parent_segmentation.get_images_dict() + + def get_image(self, name="correlation"): + return self._parent_segmentation.get_image(name=name) + + def get_sampling_frequency(self) -> float: + return self._parent_segmentation.get_sampling_frequency() + + def get_channel_names(self) -> list: + return self._parent_segmentation.get_channel_names() + + def get_num_channels(self) -> int: + return self._parent_segmentation.get_num_channels() + + def get_num_planes(self): + return self._parent_segmentation.get_num_planes() diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index cf1b9470..4ddf6247 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -144,6 +144,7 @@ def generate_dummy_segmentation_extractor( accepted_lst=accepeted_list, rejected_list=rejected_list, movie_dims=movie_dims, + channel_names=["channel_num_0"], ) return dummy_segmentation_extractor diff --git a/tests/test_internals/test_frame_slice_segmentation.py b/tests/test_internals/test_frame_slice_segmentation.py new file mode 100644 index 00000000..28608416 --- /dev/null +++ b/tests/test_internals/test_frame_slice_segmentation.py @@ -0,0 +1,101 @@ +import unittest + +import numpy as np +from hdmf.testing import TestCase +from numpy.testing import assert_array_equal +from parameterized import parameterized, param + +from roiextractors.testing import generate_dummy_segmentation_extractor + + +def test_frame_slicing_segmentation_times(): + num_frames = 10 + timestamp_shift = 7.1 + times = np.array(range(num_frames)) + timestamp_shift + start_frame, end_frame = 2, 7 + + toy_segmentation_example = generate_dummy_segmentation_extractor( + num_frames=num_frames, num_rows=5, num_columns=4, num_channels=1 + ) + toy_segmentation_example.set_times(times=times) + + frame_sliced_segmentation = toy_segmentation_example.frame_slice(start_frame=start_frame, end_frame=end_frame) + assert_array_equal( + frame_sliced_segmentation.frame_to_time( + frames=np.array([idx for idx in range(frame_sliced_segmentation.get_num_frames())]) + ), + times[start_frame:end_frame], + ) + + +def segmentation_name_function(testcase_function, param_number, param): + return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'].__name__)}" + + +class TestFrameSlicesegmentation(TestCase): + @classmethod + def setUpClass(cls): + cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=10, num_rows=5, num_columns=4) + cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7) + + def test_get_image_size(self): + assert self.frame_sliced_segmentation.get_image_size() == (5, 4) + + def test_get_num_planes(self): + return self.frame_sliced_segmentation.get_num_planes() == 1 + + def test_get_num_frames(self): + assert self.frame_sliced_segmentation.get_num_frames() == 5 + + def test_get_sampling_frequency(self): + assert self.frame_sliced_segmentation.get_sampling_frequency() == 30.0 + + def test_get_channel_names(self): + assert self.frame_sliced_segmentation.get_channel_names() == ["channel_num_0"] + + def test_get_num_channels(self): + assert self.frame_sliced_segmentation.get_num_channels() == 1 + + def test_get_num_rois(self): + assert self.frame_sliced_segmentation.get_num_rois() == 30 + + def test_get_accepted_list(self): + return assert_array_equal( + x=self.frame_sliced_segmentation.get_accepted_list(), y=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + + def test_get_rejected_list(self): + return assert_array_equal( + x=self.frame_sliced_segmentation.get_rejected_list(), + y=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], + ) + + @parameterized.expand( + [param(name="raw"), param(name="dff"), param(name="neuropil"), param(name="deconvolved")], + name_func=segmentation_name_function, + ) + def test_get_traces(self, name: str): + assert_array_equal( + x=self.frame_sliced_segmentation.get_traces(name=name), + y=self.toy_segmentation_example.get_traces(start_frame=2, end_frame=7, name=name), + ) + + def test_get_traces_dict(self): + true_dict = self.toy_segmentation_example.get_traces_dict() + for key in true_dict: + true_dict[key] = true_dict[key][:, 2:7] + self.assertContainerEqual(container1=self.frame_sliced_segmentation.get_traces_dict(), container2=true_dict) + + def test_get_images_dict(self): + self.assertContainerEqual( + container1=self.frame_sliced_segmentation.get_images_dict(), + container2=self.toy_segmentation_example.get_images_dict(), + ) + + @parameterized.expand([param(name="mean"), param(name="correlation")], name_func=segmentation_name_function) + def test_get_image(self, name: str): + assert self.frame_sliced_segmentation.get_image(name=name) == self.toy_segmentation_example.get_image(name=name) + + +if __name__ == "__main__": + unittest.main() From f1ffaf42646fb6bf5872ac05d1dcde779f52e28d Mon Sep 17 00:00:00 2001 From: Cody Baker Date: Sun, 28 Aug 2022 13:52:54 -0400 Subject: [PATCH 2/6] orientation corrections --- src/roiextractors/segmentationextractor.py | 16 +++++++++------- .../test_frame_slice_segmentation.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index e0b0a0ec..54e37102 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -385,16 +385,18 @@ def get_traces( ) def get_traces_dict(self): - if not isinstance(self._parent_segmentation._roi_response_raw, np.ndarray): - raise NotImplementedError( - "The `get_traces_dict` method for SubFrameSegementations does not yet support underlying trace types " - "other than numpy arrays (which includes memory maps)." - ) + for trace in self._parent_segmentation.get_traces_dict().values(): + if trace is not None and len(trace.shape) > 0: + if not isinstance(trace, np.ndarray): + raise NotImplementedError( + "The `get_traces_dict` method for SubFrameSegementations does not yet support underlying trace " + "types other than numpy arrays (which includes memory maps)." + ) def _safe_subtrace(trace: Optional[np.ndarray], start_frame: int, end_frame: int): - if trace is None: + if trace is None and len(trace.shape) > 0: return None - return trace[:, start_frame:end_frame] + return trace[start_frame:end_frame, :] return dict( raw=_safe_subtrace( diff --git a/tests/test_internals/test_frame_slice_segmentation.py b/tests/test_internals/test_frame_slice_segmentation.py index 28608416..9d0be522 100644 --- a/tests/test_internals/test_frame_slice_segmentation.py +++ b/tests/test_internals/test_frame_slice_segmentation.py @@ -83,7 +83,7 @@ def test_get_traces(self, name: str): def test_get_traces_dict(self): true_dict = self.toy_segmentation_example.get_traces_dict() for key in true_dict: - true_dict[key] = true_dict[key][:, 2:7] + true_dict[key] = true_dict[key][2:7, :] self.assertContainerEqual(container1=self.frame_sliced_segmentation.get_traces_dict(), container2=true_dict) def test_get_images_dict(self): From 9e955248d2532c74522cf896b7c27367304041d9 Mon Sep 17 00:00:00 2001 From: CodyCBakerPhD Date: Sun, 28 Aug 2022 19:04:00 +0000 Subject: [PATCH 3/6] debug --- .../numpyextractors/numpyextractors.py | 2 +- src/roiextractors/segmentationextractor.py | 29 ++++++++++--------- .../test_frame_slice_segmentation.py | 27 ++++++++--------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/roiextractors/extractors/numpyextractors/numpyextractors.py b/src/roiextractors/extractors/numpyextractors/numpyextractors.py index 20507d6f..50db8104 100644 --- a/src/roiextractors/extractors/numpyextractors/numpyextractors.py +++ b/src/roiextractors/extractors/numpyextractors/numpyextractors.py @@ -4,7 +4,7 @@ import numpy as np from ...extraction_tools import PathType, FloatType, ArrayType -from ...extraction_tools import check_get_frames_args, get_video_shape +from ...extraction_tools import get_video_shape from ...imagingextractor import ImagingExtractor from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 54e37102..5903f3df 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from typing import Union, Optional, Tuple +from typing import Union, Optional, Tuple, Iterable import numpy as np -from numpy.typing import ArrayLike from .extraction_tools import ArrayType, IntType, FloatType from .extraction_tools import _pixel_mask_extractor @@ -160,9 +159,9 @@ def get_image_size(self) -> ArrayType: """ pass - def frame_slice(self, start_frame, end_frame): + def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): """Return a new SegmentationExtractor ranging from the start_frame to the end_frame.""" - return FrameSliceSegmentationExtractor(parent_imaging=self, start_frame=start_frame, end_frame=end_frame) + return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame) def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"): """ @@ -290,7 +289,7 @@ def set_times(self, times: ArrayType): assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!" self._times = np.array(times, dtype=np.float64) - def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]: + def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]: """Returns the timing of frames in unit of seconds. Parameters @@ -304,9 +303,9 @@ def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[Float The corresponding times in seconds """ if self._times is None: - return np.round(frame_indices / self.get_sampling_frequency(), 6) + return np.round(frames / self.get_sampling_frequency(), 6) else: - return self._times[frame_indices] + return self._times[frames] @staticmethod def write_segmentation(segmentation_extractor, save_path, overwrite=False): @@ -361,6 +360,8 @@ def __init__( assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!" super().__init__() + if getattr(self._parent_segmentation, "_times") is not None: + self._times = self._parent_segmentation._times[start_frame:end_frame] def get_accepted_list(self) -> list: return self._parent_segmentation.get_accepted_list() @@ -370,14 +371,14 @@ def get_rejected_list(self) -> list: def get_traces( self, - roi_ids: Optional[ArrayLike[int]] = None, + roi_ids: Optional[Iterable[int]] = None, start_frame: Optional[int] = None, end_frame: Optional[int] = None, name: str = "raw", ) -> np.ndarray: - start_frame = min(start_frame, self._num_frames) - end_frame = min(end_frame, self._num_frames) - return self._parent_segmentation.get_trace( + start_frame = min(start_frame or 0, self._num_frames) + end_frame = min(end_frame or self._num_frames, self._num_frames) + return self._parent_segmentation.get_traces( roi_ids=roi_ids, start_frame=start_frame + self._start_frame, end_frame=end_frame + self._start_frame, @@ -393,10 +394,10 @@ def get_traces_dict(self): "types other than numpy arrays (which includes memory maps)." ) - def _safe_subtrace(trace: Optional[np.ndarray], start_frame: int, end_frame: int): + def _safe_subtrace(trace: Optional[np.ndarray], start: int, end: int): if trace is None and len(trace.shape) > 0: return None - return trace[start_frame:end_frame, :] + return trace[start:end, :] return dict( raw=_safe_subtrace( @@ -414,7 +415,7 @@ def _safe_subtrace(trace: Optional[np.ndarray], start_frame: int, end_frame: int ) def get_image_size(self) -> Tuple[int, int]: - return tuple(self._parent_imaging.get_image_size()) + return tuple(self._parent_segmentation.get_image_size()) def get_num_frames(self) -> int: return self._num_frames diff --git a/tests/test_internals/test_frame_slice_segmentation.py b/tests/test_internals/test_frame_slice_segmentation.py index 9d0be522..e2d7ca68 100644 --- a/tests/test_internals/test_frame_slice_segmentation.py +++ b/tests/test_internals/test_frame_slice_segmentation.py @@ -14,9 +14,7 @@ def test_frame_slicing_segmentation_times(): times = np.array(range(num_frames)) + timestamp_shift start_frame, end_frame = 2, 7 - toy_segmentation_example = generate_dummy_segmentation_extractor( - num_frames=num_frames, num_rows=5, num_columns=4, num_channels=1 - ) + toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=num_frames, num_rows=5, num_columns=4) toy_segmentation_example.set_times(times=times) frame_sliced_segmentation = toy_segmentation_example.frame_slice(start_frame=start_frame, end_frame=end_frame) @@ -29,13 +27,13 @@ def test_frame_slicing_segmentation_times(): def segmentation_name_function(testcase_function, param_number, param): - return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'].__name__)}" + return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'])}" class TestFrameSlicesegmentation(TestCase): @classmethod def setUpClass(cls): - cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=10, num_rows=5, num_columns=4) + cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=15, num_rows=5, num_columns=4) cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7) def test_get_image_size(self): @@ -57,7 +55,7 @@ def test_get_num_channels(self): assert self.frame_sliced_segmentation.get_num_channels() == 1 def test_get_num_rois(self): - assert self.frame_sliced_segmentation.get_num_rois() == 30 + assert self.frame_sliced_segmentation.get_num_rois() == 10 def test_get_accepted_list(self): return assert_array_equal( @@ -65,10 +63,7 @@ def test_get_accepted_list(self): ) def test_get_rejected_list(self): - return assert_array_equal( - x=self.frame_sliced_segmentation.get_rejected_list(), - y=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], - ) + return assert_array_equal(x=self.frame_sliced_segmentation.get_rejected_list(), y=[]) @parameterized.expand( [param(name="raw"), param(name="dff"), param(name="neuropil"), param(name="deconvolved")], @@ -84,17 +79,19 @@ def test_get_traces_dict(self): true_dict = self.toy_segmentation_example.get_traces_dict() for key in true_dict: true_dict[key] = true_dict[key][2:7, :] - self.assertContainerEqual(container1=self.frame_sliced_segmentation.get_traces_dict(), container2=true_dict) + self.assertCountEqual(first=self.frame_sliced_segmentation.get_traces_dict(), second=true_dict) def test_get_images_dict(self): - self.assertContainerEqual( - container1=self.frame_sliced_segmentation.get_images_dict(), - container2=self.toy_segmentation_example.get_images_dict(), + self.assertCountEqual( + first=self.frame_sliced_segmentation.get_images_dict(), + second=self.toy_segmentation_example.get_images_dict(), ) @parameterized.expand([param(name="mean"), param(name="correlation")], name_func=segmentation_name_function) def test_get_image(self, name: str): - assert self.frame_sliced_segmentation.get_image(name=name) == self.toy_segmentation_example.get_image(name=name) + assert_array_equal( + x=self.frame_sliced_segmentation.get_image(name=name), y=self.toy_segmentation_example.get_image(name=name) + ) if __name__ == "__main__": From 67455ca6f14e83631db9ac42d67662dcca7ebc2d Mon Sep 17 00:00:00 2001 From: CodyCBakerPhD Date: Sun, 28 Aug 2022 19:08:38 +0000 Subject: [PATCH 4/6] debug --- tests/test_internals/test_testing_tools.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_internals/test_testing_tools.py b/tests/test_internals/test_testing_tools.py index f5c7e0af..bb8603d4 100644 --- a/tests/test_internals/test_testing_tools.py +++ b/tests/test_internals/test_testing_tools.py @@ -33,8 +33,8 @@ def test_default_values(self): # Test frame_to_time times = np.round(np.arange(self.num_frames) / self.sampling_frequency, 6) - assert_array_equal(segmentation_extractor.frame_to_time(frame_indices=np.arange(self.num_frames)), times) - self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=8), times[8]) + assert_array_equal(segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)), times) + self.assertEqual(segmentation_extractor.frame_to_time(frames=8), times[8]) # Test image masks assert segmentation_extractor.get_roi_image_masks().shape == (self.num_rows, self.num_columns, self.num_rois) @@ -118,10 +118,8 @@ def test_frame_to_time_no_sampling_frequency(self): times = np.arange(self.num_frames) / self.sampling_frequency segmentation_extractor._times = times - self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=2), times[2]) + self.assertEqual(segmentation_extractor.frame_to_time(frames=2), times[2]) assert_array_equal( - segmentation_extractor.frame_to_time( - frame_indices=np.arange(self.num_frames), - ), + segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)), times, ) From de255d6bcad8dce2feab766db332d4c8da28878d Mon Sep 17 00:00:00 2001 From: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> Date: Sun, 28 Aug 2022 16:12:57 -0400 Subject: [PATCH 5/6] Update CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be0a2965..2c646512 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,8 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat * Implemented a more efficient case of the base `ImagingExtractor.get_frames` through `get_video` when the indices are contiguous. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) * Removed `max_frame` check on `MultiImagingExtractor.get_video()` to adhere to upper-bound slicing semantics. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) * Improved the `MultiImagingExtractor.get_video()` to no longer rely on `get_frames`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) -* Adding `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) +* Added `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) +* Added the `FrameSliceSegmentationExtractor` class and corresponding `Segmentation.frame_slice(...)` method. [PR #201](https://github.com/catalystneuro/neuroconv/pull/201) ### Fixes * Fixed the reference to the proper `mov_field` in `Hdf5ImagingExtractor`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195) From 5cff4ac9942e06f6fb70722c4c7fe47c5c4fc550 Mon Sep 17 00:00:00 2001 From: CodyCBakerPhD Date: Tue, 30 Aug 2022 18:27:20 +0000 Subject: [PATCH 6/6] PR suggestions --- src/roiextractors/segmentationextractor.py | 33 ++++--------------- .../test_frame_slice_segmentation.py | 13 ++++++-- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 5903f3df..9ae49cd5 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -386,33 +386,12 @@ def get_traces( ) def get_traces_dict(self): - for trace in self._parent_segmentation.get_traces_dict().values(): - if trace is not None and len(trace.shape) > 0: - if not isinstance(trace, np.ndarray): - raise NotImplementedError( - "The `get_traces_dict` method for SubFrameSegementations does not yet support underlying trace " - "types other than numpy arrays (which includes memory maps)." - ) - - def _safe_subtrace(trace: Optional[np.ndarray], start: int, end: int): - if trace is None and len(trace.shape) > 0: - return None - return trace[start:end, :] - - return dict( - raw=_safe_subtrace( - trace=self._parent_segmentation._roi_response_raw, start=self._start_frame, end=self._end_frame - ), - dff=_safe_subtrace( - trace=self._parent_segmentation._roi_response_dff, start=self._start_frame, end=self._end_frame - ), - neuropil=_safe_subtrace( - trace=self._parent_segmentation._roi_response_neuropil, start=self._start_frame, end=self._end_frame - ), - deconvolved=_safe_subtrace( - trace=self._parent_segmentation._roi_response_deconvolved, start=self._start_frame, end=self._end_frame - ), - ) + return { + trace_name: self._parent_segmentation.get_traces( + start_frame=self._start_frame, end_frame=self._end_frame, name=trace_name + ) + for trace_name, trace in self._parent_segmentation.get_traces_dict().items() + } def get_image_size(self) -> Tuple[int, int]: return tuple(self._parent_segmentation.get_image_size()) diff --git a/tests/test_internals/test_frame_slice_segmentation.py b/tests/test_internals/test_frame_slice_segmentation.py index e2d7ca68..4a0ac59b 100644 --- a/tests/test_internals/test_frame_slice_segmentation.py +++ b/tests/test_internals/test_frame_slice_segmentation.py @@ -30,7 +30,7 @@ def segmentation_name_function(testcase_function, param_number, param): return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'])}" -class TestFrameSlicesegmentation(TestCase): +class BaseTestFrameSlicesegmentation(TestCase): @classmethod def setUpClass(cls): cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=15, num_rows=5, num_columns=4) @@ -78,7 +78,7 @@ def test_get_traces(self, name: str): def test_get_traces_dict(self): true_dict = self.toy_segmentation_example.get_traces_dict() for key in true_dict: - true_dict[key] = true_dict[key][2:7, :] + true_dict[key] = true_dict[key][2:7, :] if true_dict[key] is not None else true_dict[key] self.assertCountEqual(first=self.frame_sliced_segmentation.get_traces_dict(), second=true_dict) def test_get_images_dict(self): @@ -94,5 +94,14 @@ def test_get_image(self, name: str): ) +class TestMissingTraceFrameSlicesegmentation(BaseTestFrameSlicesegmentation): + @classmethod + def setUpClass(cls): + cls.toy_segmentation_example = generate_dummy_segmentation_extractor( + num_frames=15, num_rows=5, num_columns=4, has_dff_signal=False + ) + cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7) + + if __name__ == "__main__": unittest.main()