[Merge after #200]: Subframe segmentation #201

Merged: 9 commits, Aug 31, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -19,7 +19,8 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
* Implemented a more efficient case of the base `ImagingExtractor.get_frames` through `get_video` when the indices are contiguous. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Removed `max_frame` check on `MultiImagingExtractor.get_video()` to adhere to upper-bound slicing semantics. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Improved the `MultiImagingExtractor.get_video()` to no longer rely on `get_frames`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Adding `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added the `FrameSliceSegmentationExtractor` class and corresponding `Segmentation.frame_slice(...)` method. [PR #201](https://github.com/catalystneuro/neuroconv/pull/201)

### Fixes
* Fixed the reference to the proper `mov_field` in `Hdf5ImagingExtractor`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
@@ -4,7 +4,7 @@
import numpy as np

from ...extraction_tools import PathType, FloatType, ArrayType
from ...extraction_tools import check_get_frames_args, get_video_shape
from ...extraction_tools import get_video_shape
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import SegmentationExtractor

108 changes: 104 additions & 4 deletions src/roiextractors/segmentationextractor.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Union, Optional, Tuple, Iterable

import numpy as np

@@ -159,6 +159,10 @@ def get_image_size(self) -> ArrayType:
"""
pass

def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None):
"""Return a new SegmentationExtractor ranging from the start_frame to the end_frame."""
return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame)

def get_traces(self, roi_ids=None, start_frame=None, end_frame=None, name="raw"):
"""
Return RoiResponseSeries
@@ -285,7 +289,7 @@ def set_times(self, times: ArrayType):
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times, dtype=np.float64)

def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, ArrayType]:
"""Returns the timing of frames in unit of seconds.

Parameters
@@ -299,9 +303,9 @@ def frame_to_time(self, frame_indices: Union[IntType, ArrayType]) -> Union[Float
The corresponding times in seconds
"""
if self._times is None:
return np.round(frame_indices / self.get_sampling_frequency(), 6)
return np.round(frames / self.get_sampling_frequency(), 6)
else:
return self._times[frame_indices]
return self._times[frames]

@staticmethod
def write_segmentation(segmentation_extractor, save_path, overwrite=False):
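
As a quick illustration of the renamed frames argument above, a minimal sketch using the dummy extractor from roiextractors.testing (illustrative only; values assume the helper's 30.0 Hz default sampling frequency):

import numpy as np
from roiextractors.testing import generate_dummy_segmentation_extractor

seg = generate_dummy_segmentation_extractor(num_frames=10, sampling_frequency=30.0)

# Without explicit times, frame indices are divided by the sampling frequency.
assert seg.frame_to_time(frames=3) == np.round(3 / 30.0, 6)

# With explicit times set, frame indices look up the stored timestamps directly.
seg.set_times(np.arange(10) / 30.0 + 7.1)
assert seg.frame_to_time(frames=3) == 7.1 + 3 / 30.0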
@@ -319,3 +323,99 @@ def write_segmentation(segmentation_extractor, save_path, overwrite=False):
If True, the file is overwritten if existing (default False)
"""
raise NotImplementedError


class FrameSliceSegmentationExtractor(SegmentationExtractor):
"""
Class to get a lazy frame slice.

Do not use this class directly but use `.frame_slice(...)`
"""

extractor_name = "FrameSliceSegmentationExtractor"
is_writable = True

def __init__(
self,
parent_segmentation: SegmentationExtractor,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
):
self._parent_segmentation = parent_segmentation
self._start_frame = start_frame or 0
self._end_frame = end_frame or self._parent_segmentation.get_num_frames()
self._num_frames = self._end_frame - self._start_frame

self._image_masks = self._parent_segmentation._image_masks

parent_size = self._parent_segmentation.get_num_frames()
if start_frame is None:
start_frame = 0
else:
assert 0 <= start_frame < parent_size
if end_frame is None:
end_frame = parent_size
else:
assert 0 < end_frame <= parent_size
assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!"

super().__init__()
if getattr(self._parent_segmentation, "_times") is not None:
self._times = self._parent_segmentation._times[start_frame:end_frame]

def get_accepted_list(self) -> list:
return self._parent_segmentation.get_accepted_list()

def get_rejected_list(self) -> list:
return self._parent_segmentation.get_rejected_list()

def get_traces(
self,
roi_ids: Optional[Iterable[int]] = None,
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
name: str = "raw",
) -> np.ndarray:
start_frame = min(start_frame or 0, self._num_frames)
end_frame = min(end_frame or self._num_frames, self._num_frames)
return self._parent_segmentation.get_traces(
roi_ids=roi_ids,
start_frame=start_frame + self._start_frame,
end_frame=end_frame + self._start_frame,
name=name,
)

def get_traces_dict(self):
return {
trace_name: self._parent_segmentation.get_traces(
start_frame=self._start_frame, end_frame=self._end_frame, name=trace_name
)
for trace_name, trace in self._parent_segmentation.get_traces_dict().items()
}

def get_image_size(self) -> Tuple[int, int]:
return tuple(self._parent_segmentation.get_image_size())

def get_num_frames(self) -> int:
return self._num_frames

def get_num_rois(self):
return self._parent_segmentation.get_num_rois()

def get_images_dict(self) -> dict:
return self._parent_segmentation.get_images_dict()

def get_image(self, name="correlation"):
return self._parent_segmentation.get_image(name=name)

def get_sampling_frequency(self) -> float:
return self._parent_segmentation.get_sampling_frequency()

def get_channel_names(self) -> list:
return self._parent_segmentation.get_channel_names()

def get_num_channels(self) -> int:
return self._parent_segmentation.get_num_channels()

def get_num_planes(self):
return self._parent_segmentation.get_num_planes()
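
For orientation, a minimal usage sketch of the new frame_slice API introduced above (the dummy extractor helper from roiextractors.testing is used purely for illustration):

from roiextractors.testing import generate_dummy_segmentation_extractor

# Build a dummy parent extractor, then take a lazy view over frames [10, 20).
parent = generate_dummy_segmentation_extractor(num_frames=100, num_rows=5, num_columns=4)
sliced = parent.frame_slice(start_frame=10, end_frame=20)

assert sliced.get_num_frames() == 10
assert sliced.get_image_size() == (5, 4)

# Trace requests are clipped to the slice and forwarded to the parent with the frame offset applied.
raw_traces = sliced.get_traces(start_frame=0, end_frame=5, name="raw")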
1 change: 1 addition & 0 deletions src/roiextractors/testing.py
@@ -144,6 +144,7 @@ def generate_dummy_segmentation_extractor(
accepted_lst=accepeted_list,
rejected_list=rejected_list,
movie_dims=movie_dims,
channel_names=["channel_num_0"],
)

return dummy_segmentation_extractor
107 changes: 107 additions & 0 deletions tests/test_internals/test_frame_slice_segmentation.py
@@ -0,0 +1,107 @@
import unittest

import numpy as np
from hdmf.testing import TestCase
from numpy.testing import assert_array_equal
from parameterized import parameterized, param

from roiextractors.testing import generate_dummy_segmentation_extractor


def test_frame_slicing_segmentation_times():
num_frames = 10
timestamp_shift = 7.1
times = np.array(range(num_frames)) + timestamp_shift
start_frame, end_frame = 2, 7

toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=num_frames, num_rows=5, num_columns=4)
toy_segmentation_example.set_times(times=times)

frame_sliced_segmentation = toy_segmentation_example.frame_slice(start_frame=start_frame, end_frame=end_frame)
assert_array_equal(
frame_sliced_segmentation.frame_to_time(
frames=np.array([idx for idx in range(frame_sliced_segmentation.get_num_frames())])
),
times[start_frame:end_frame],
)


def segmentation_name_function(testcase_function, param_number, param):
return f"{testcase_function.__name__}_{param_number}_{parameterized.to_safe_name(param.kwargs['name'])}"


class BaseTestFrameSlicesegmentation(TestCase):
@classmethod
def setUpClass(cls):
cls.toy_segmentation_example = generate_dummy_segmentation_extractor(num_frames=15, num_rows=5, num_columns=4)
Collaborator:
I see that you have some tests for the rejected and accepted lists. Up to you, but generate_dummy_segmentation_extractor has a rejected_list argument to generate that:

def generate_dummy_segmentation_extractor(
num_rois: int = 10,
num_frames: int = 30,
num_rows: int = 25,
num_columns: int = 25,
sampling_frequency: float = 30.0,
has_summary_images: bool = True,
has_raw_signal: bool = True,
has_dff_signal: bool = True,
has_deconvolved_signal: bool = True,
has_neuropil_signal: bool = True,
rejected_list: Optional[list] = None,

So you can test the non-default case where rejected_list != [].

Member Author:

Yeah, I could. But I have yet to see any format in practice that bothers to save ROI info that has not been accepted (so far, every single ROI written as series data or image mask is 'accepted', so 'rejected' is always empty). Maybe something to think about and revisit in the future.
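
For reference, a minimal sketch of what such a non-default rejected_list test could look like (hypothetical; not part of this PR, and it only checks that the slice delegates the ROI lists to its parent):

def test_frame_slice_with_rejected_rois():
    # Hypothetical sketch (not in this PR): pass a non-default rejected_list to the dummy generator.
    rejected_list = [1, 3]
    parent = generate_dummy_segmentation_extractor(
        num_rois=10, num_frames=15, num_rows=5, num_columns=4, rejected_list=rejected_list
    )
    sliced = parent.frame_slice(start_frame=2, end_frame=7)

    # The frame slice delegates ROI bookkeeping to its parent, so both lists should agree.
    assert_array_equal(sliced.get_rejected_list(), parent.get_rejected_list())
    assert_array_equal(sliced.get_accepted_list(), parent.get_accepted_list())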

cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7)

def test_get_image_size(self):
assert self.frame_sliced_segmentation.get_image_size() == (5, 4)

def test_get_num_planes(self):
assert self.frame_sliced_segmentation.get_num_planes() == 1

def test_get_num_frames(self):
assert self.frame_sliced_segmentation.get_num_frames() == 5

def test_get_sampling_frequency(self):
assert self.frame_sliced_segmentation.get_sampling_frequency() == 30.0

def test_get_channel_names(self):
assert self.frame_sliced_segmentation.get_channel_names() == ["channel_num_0"]

def test_get_num_channels(self):
assert self.frame_sliced_segmentation.get_num_channels() == 1

def test_get_num_rois(self):
assert self.frame_sliced_segmentation.get_num_rois() == 10

def test_get_accepted_list(self):
return assert_array_equal(
x=self.frame_sliced_segmentation.get_accepted_list(), y=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
)

def test_get_rejected_list(self):
return assert_array_equal(x=self.frame_sliced_segmentation.get_rejected_list(), y=[])

@parameterized.expand(
[param(name="raw"), param(name="dff"), param(name="neuropil"), param(name="deconvolved")],
name_func=segmentation_name_function,
)
def test_get_traces(self, name: str):
assert_array_equal(
x=self.frame_sliced_segmentation.get_traces(name=name),
y=self.toy_segmentation_example.get_traces(start_frame=2, end_frame=7, name=name),
)

def test_get_traces_dict(self):
true_dict = self.toy_segmentation_example.get_traces_dict()
for key in true_dict:
true_dict[key] = true_dict[key][2:7, :] if true_dict[key] is not None else true_dict[key]
self.assertCountEqual(first=self.frame_sliced_segmentation.get_traces_dict(), second=true_dict)

def test_get_images_dict(self):
self.assertCountEqual(
first=self.frame_sliced_segmentation.get_images_dict(),
second=self.toy_segmentation_example.get_images_dict(),
)

@parameterized.expand([param(name="mean"), param(name="correlation")], name_func=segmentation_name_function)
def test_get_image(self, name: str):
assert_array_equal(
x=self.frame_sliced_segmentation.get_image(name=name), y=self.toy_segmentation_example.get_image(name=name)
)


class TestMissingTraceFrameSlicesegmentation(BaseTestFrameSlicesegmentation):
Collaborator:
I did not think of this pattern; it is neat.

Member Author:
It multiplies the total number of tests run, but I'm going for maximum safety here.

@classmethod
def setUpClass(cls):
cls.toy_segmentation_example = generate_dummy_segmentation_extractor(
num_frames=15, num_rows=5, num_columns=4, has_dff_signal=False
)
cls.frame_sliced_segmentation = cls.toy_segmentation_example.frame_slice(start_frame=2, end_frame=7)


if __name__ == "__main__":
unittest.main()
10 changes: 4 additions & 6 deletions tests/test_internals/test_testing_tools.py
@@ -33,8 +33,8 @@ def test_default_values(self):

# Test frame_to_time
times = np.round(np.arange(self.num_frames) / self.sampling_frequency, 6)
assert_array_equal(segmentation_extractor.frame_to_time(frame_indices=np.arange(self.num_frames)), times)
self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=8), times[8])
assert_array_equal(segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)), times)
self.assertEqual(segmentation_extractor.frame_to_time(frames=8), times[8])

# Test image masks
assert segmentation_extractor.get_roi_image_masks().shape == (self.num_rows, self.num_columns, self.num_rois)
@@ -118,10 +118,8 @@ def test_frame_to_time_no_sampling_frequency(self):
times = np.arange(self.num_frames) / self.sampling_frequency
segmentation_extractor._times = times

self.assertEqual(segmentation_extractor.frame_to_time(frame_indices=2), times[2])
self.assertEqual(segmentation_extractor.frame_to_time(frames=2), times[2])
assert_array_equal(
segmentation_extractor.frame_to_time(
frame_indices=np.arange(self.num_frames),
),
segmentation_extractor.frame_to_time(frames=np.arange(self.num_frames)),
times,
)