Skip to content

Commit

Permalink
Merge pull request #201 from catalystneuro/subframe_segmentation
Browse files Browse the repository at this point in the history
[Merge after #200]: Subframe segmentation
  • Loading branch information
CodyCBakerPhD authored Aug 31, 2022
2 parents b71f17d + 5cff4ac commit f8c51f0
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 12 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
* Implemented a more efficient case of the base `ImagingExtractor.get_frames` through `get_video` when the indices are contiguous. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Removed `max_frame` check on `MultiImagingExtractor.get_video()` to adhere to upper-bound slicing semantics. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Improved the `MultiImagingExtractor.get_video()` to no longer rely on `get_frames`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Adding `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added the `FrameSliceSegmentationExtractor` class and corresponding `Segmentation.frame_slice(...)` method. [PR #201](https://github.com/catalystneuro/neuroconv/pull/201)

### Fixes
* Fixed the reference to the proper `mov_field` in `Hdf5ImagingExtractor`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ...extraction_tools import PathType, FloatType, ArrayType
from ...extraction_tools import check_get_frames_args, get_video_shape
from ...extraction_tools import get_video_shape
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import SegmentationExtractor

Expand Down
108 changes: 104 additions & 4 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Union, Optional, Tuple, Iterable

import numpy as np

Expand Down Expand Up @@ -159,6 +159,10 @@ def get_image_size(self) -> ArrayType:
"""
pass

def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None):
"""Return a new SegmentationExtractor ranging from the start_frame to the end_frame."""
return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame)

def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"):
"""
Return RoiResponseSeries
Expand Down Expand Up @@ -285,7 +289,7 @@ def set_times(self, times: ArrayType):
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times, dtype=np.float64)

def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
"""Returns the timing of frames in unit of seconds.
Parameters
Expand All @@ -299,9 +303,9 @@ def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[Float
The corresponding times in seconds
"""
if self._times is None:
return np.round(frame_indices / self.get_sampling_frequency(), 6)
return np.round(frames / self.get_sampling_frequency(), 6)
else:
return self._times[frame_indices]
return self._times[frames]

@staticmethod
def write_segmentation(segmentation_extractor, save_path, overwrite=False):
Expand All @@ -319,3 +323,99 @@ def write_segmentation(segmentation_extractor, save_path, overwrite=False):
If True, the file is overwritten if existing (default False)
"""
raise NotImplementedError


class FrameSliceSegmentationExtractor(SegmentationExtractor):
"""
Class to get a lazy frame slice.
Do not use this class directly but use `.frame_slice(...)`
"""

extractor_name = "FrameSliceSegmentationExtractor"
is_writable = True

def __init__(
self,
parent_segmentation: SegmentationExtractor,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
):
self._parent_segmentation = parent_segmentation
self._start_frame = start_frame or 0
self._end_frame = end_frame or self._parent_segmentation.get_num_frames()
self._num_frames = self._end_frame - self._start_frame

self._image_masks = self._parent_segmentation._image_masks

parent_size = self._parent_segmentation.get_num_frames()
if start_frame is None:
start_frame = 0
else:
assert 0 <= start_frame < parent_size
if end_frame is None:
end_frame = parent_size
else:
assert 0 < end_frame <= parent_size
assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!"

super().__init__()
if getattr(self._parent_segmentation, "_times") is not None:
self._times = self._parent_segmentation._times[start_frame:end_frame]

def get_accepted_list(self) -> list:
return self._parent_segmentation.get_accepted_list()

def get_rejected_list(self) -> list:
return self._parent_segmentation.get_rejected_list()

def get_traces(
self,
roi_ids: Optional[Iterable[int]] = None,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
name: str = "raw",
) -> np.ndarray:
start_frame = min(start_frame or 0, self._num_frames)
end_frame = min(end_frame or self._num_frames, self._num_frames)
return self._parent_segmentation.get_traces(
roi_ids=roi_ids,
start_frame=start_frame + self._start_frame,
end_frame=end_frame + self._start_frame,
name=name,
)

def get_traces_dict(self):
return {
trace_name: self._parent_segmentation.get_traces(
start_frame=self._start_frame, end_frame=self._end_frame, name=trace_name
)
for trace_name, trace in self._parent_segmentation.get_traces_dict().items()
}

def get_image_size(self) -> Tuple[int, int]:
return tuple(self._parent_segmentation.get_image_size())

def get_num_frames(self) -> int:
return self._num_frames

def get_num_rois(self):
return self._parent_segmentation.get_num_rois()

def get_images_dict(self) -> dict:
return self._parent_segmentation.get_images_dict()

def get_image(self, name="correlation"):
return self._parent_segmentation.get_image(name=name)

def get_sampling_frequency(self) -> float:
return self._parent_segmentation.get_sampling_frequency()

def get_channel_names(self) -> list:
return self._parent_segmentation.get_channel_names()

def get_num_channels(self) -> int:
return self._parent_segmentation.get_num_channels()

def get_num_planes(self):
return self._parent_segmentation.get_num_planes()
1 change: 1 addition & 0 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def generate_dummy_segmentation_extractor(
accepted_lst=accepeted_list,
rejected_list=rejected_list,
movie_dims=movie_dims,
channel_names=["channel_num_0"],
)

return dummy_segmentation_extractor
Expand Down
107 changes: 107 additions & 0 deletions tests/test_internals/test_frame_slice_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import unittest

import numpy as np
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal
from parameterized import parameterized, param

from roiextractors.testing import generate_dummy_segmentation_extractor


def test_frame_slicing_segmentation_times():
num_frames = 10
timestamp_shift = 7.1
times = np.array(range(num_frames)) + timestamp_shift
start_frame, end_frame = 2, 7

toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=num_frames, num_rows=5, num_columns=4)
toy_segmentation_example.set_times(times=times)

frame_sliced_segmentation = toy_segmentation_example.frame_slice(start_frame=start_frame, end_frame=end_frame)
assert_array_equal(
frame_sliced_segmentation.frame_to_time(
frames=np.array([idx for idx in range(frame_sliced_segmentation.get_num_frames())])
),
times[start_frame:end_frame],
)


def segmentation_name_function(testcase_function, param_number, param):
return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'])}"


class BaseTestFrameSlicesegmentation(TestCase):
@classmethod
def setUpClass(cls):
cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=15, num_rows=5, num_columns=4)
cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7)

def test_get_image_size(self):
assert self.frame_sliced_segmentation.get_image_size() == (5, 4)

def test_get_num_planes(self):
return self.frame_sliced_segmentation.get_num_planes() == 1

def test_get_num_frames(self):
assert self.frame_sliced_segmentation.get_num_frames() == 5

def test_get_sampling_frequency(self):
assert self.frame_sliced_segmentation.get_sampling_frequency() == 30.0

def test_get_channel_names(self):
assert self.frame_sliced_segmentation.get_channel_names() == ["channel_num_0"]

def test_get_num_channels(self):
assert self.frame_sliced_segmentation.get_num_channels() == 1

def test_get_num_rois(self):
assert self.frame_sliced_segmentation.get_num_rois() == 10

def test_get_accepted_list(self):
return assert_array_equal(
x=self.frame_sliced_segmentation.get_accepted_list(), y=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
)

def test_get_rejected_list(self):
return assert_array_equal(x=self.frame_sliced_segmentation.get_rejected_list(), y=[])

@parameterized.expand(
[param(name="raw"), param(name="dff"), param(name="neuropil"), param(name="deconvolved")],
name_func=segmentation_name_function,
)
def test_get_traces(self, name: str):
assert_array_equal(
x=self.frame_sliced_segmentation.get_traces(name=name),
y=self.toy_segmentation_example.get_traces(start_frame=2, end_frame=7, name=name),
)

def test_get_traces_dict(self):
true_dict = self.toy_segmentation_example.get_traces_dict()
for key in true_dict:
true_dict[key] = true_dict[key][2:7, :] if true_dict[key] is not None else true_dict[key]
self.assertCountEqual(first=self.frame_sliced_segmentation.get_traces_dict(), second=true_dict)

def test_get_images_dict(self):
self.assertCountEqual(
first=self.frame_sliced_segmentation.get_images_dict(),
second=self.toy_segmentation_example.get_images_dict(),
)

@parameterized.expand([param(name="mean"), param(name="correlation")], name_func=segmentation_name_function)
def test_get_image(self, name: str):
assert_array_equal(
x=self.frame_sliced_segmentation.get_image(name=name), y=self.toy_segmentation_example.get_image(name=name)
)


class TestMissingTraceFrameSlicesegmentation(BaseTestFrameSlicesegmentation):
@classmethod
def setUpClass(cls):
cls.toy_segmentation_example = generate_dummy_segmentation_extractor(
num_frames=15, num_rows=5, num_columns=4, has_dff_signal=False
)
cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7)


if __name__ == "__main__":
unittest.main()
10 changes: 4 additions & 6 deletions tests/test_internals/test_testing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def test_default_values(self):

# Test frame_to_time
times = np.round(np.arange(self.num_frames) / self.sampling_frequency, 6)
assert_array_equal(segmentation_extractor.frame_to_time(frame_indices=np.arange(self.num_frames)), times)
self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=8), times[8])
assert_array_equal(segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)), times)
self.assertEqual(segmentation_extractor.frame_to_time(frames=8), times[8])

# Test image masks
assert segmentation_extractor.get_roi_image_masks().shape == (self.num_rows, self.num_columns, self.num_rois)
Expand Down Expand Up @@ -118,10 +118,8 @@ def test_frame_to_time_no_sampling_frequency(self):
times = np.arange(self.num_frames) / self.sampling_frequency
segmentation_extractor._times = times

self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=2), times[2])
self.assertEqual(segmentation_extractor.frame_to_time(frames=2), times[2])
assert_array_equal(
segmentation_extractor.frame_to_time(
frame_indices=np.arange(self.num_frames),
),
segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)),
times,
)

0 comments on commit f8c51f0

Please sign in to comment.