-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Top-down Centered-instance Pipeline (#16)
* added make_centered_bboxes & normalize_bboxes * added make_centered_bboxes & normalize_bboxes * created test_instance_cropping.py * added test normalize bboxes; added find_global_peaks_rough * black formatted * black formatted peak_finding * added make_grid_vectors, normalize_bboxes, integral_regression, added docstring to make_centered_bboxes, fixed find_global_peaks_rough; added crop_bboxes * finished find_global_peaks with integral regression over centroid crops! * reformatted with pydocstyle & black * moved make_grid_vectors to data/utils * removed normalize_bboxes * added tests docstrings * sorted imports with isort * remove unused imports * updated test cases for instance cropping * added minimal_cms.pt fixture + unit tests * added minimal_bboxes fixture; added unit tests for crop_bboxes & integral_regression * added find_global_peaks unit tests * finished find_local_peaks_rough! * finished find_local_peaks! * added unit tests for find_local_peaks and find_local_peaks_rough * updated test cases * added more test cases for find_local_peaks * updated test cases * added architectures folder * added maxpool2d same padding, get_act_fn; added simpleconvblock, simpleupsamplingblock, encoder, decoder; added unet * added test_unet_reference * black formatted common.py & test_unet.py * deleted tmp nb * _calc_same_pad returns int * fixed test case * added simpleconvblock tests * added tests * added tests for simple upsampling block * updated test_unet * removed unnecessary variables * updated augmentation random erase default values * created data/pipelines.py * added base config in config/data; temporary till config system settled * updated variable defaults to 0 and edited variable names in augmentation * updated parameter names in data/instance_cropping * added data/pipelines topdown pipeline make_base_pipeline * added test_pipelines * removed configs * updated augmentation class * modified test * updated pipelines docstring * removed make_base_pipeline and updated tests * removed empty_cache in SleapDataset * updated test_pipelines * updated sleapdataset to return a dict * added key filter transformer block, removed sleap dataset, added type hinting * updated type hints * added coderabbit suggestions * fixed small squeeze issue
- Loading branch information
Showing
13 changed files
with
368 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""General purpose transformers for common pipeline processing tasks.""" | ||
from typing import Callable, Dict, Iterator, List, Text | ||
|
||
import torch | ||
from torch.utils.data.datapipes.datapipe import IterDataPipe | ||
|
||
|
||
class KeyFilter(IterDataPipe): | ||
"""Transformer for filtering example keys.""" | ||
|
||
def __init__(self, source_dp: IterDataPipe, keep_keys: List[Text] = None) -> None: | ||
"""Initialize KeyFilter with the source `DataPipe.""" | ||
self.dp = source_dp | ||
self.keep_keys = set(keep_keys) if keep_keys else None | ||
|
||
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: | ||
"""Return a dictionary filtered for the relevant outputs. | ||
The input dictionary includes: | ||
- image: the full frame image | ||
- instances: all keypoints of all instances in the frame image | ||
- centroids: all centroids of all instances in the frame image | ||
- instance: the individual instance's keypoints | ||
- instance_bbox: the individual instance's bbox | ||
- instance_image: the individual instance's cropped image | ||
- confidence_maps: the individual instance's heatmap | ||
""" | ||
for example in self.dp: | ||
if self.keep_keys is None: | ||
# If keep_keys is not provided, yield the entire example. | ||
yield example | ||
else: | ||
# Filter the example dictionary based on keep_keys. | ||
filtered_example = { | ||
key: value | ||
for key, value in example.items() | ||
if key in self.keep_keys | ||
} | ||
yield filtered_example |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.