Top-down Centered-instance Pipeline #16

Merged 61 commits on Sep 19, 2023 (diff shown from 51 of the 61 commits)

Commits
2dfcd90  added make_centered_bboxes & normalize_bboxes (alckasoc, Aug 3, 2023)
1088e7f  added make_centered_bboxes & normalize_bboxes (alckasoc, Aug 3, 2023)
2d0a009  created test_instance_cropping.py (alckasoc, Aug 3, 2023)
02ea629  added test normalize bboxes; added find_global_peaks_rough (alckasoc, Aug 6, 2023)
711b3aa  black formatted (alckasoc, Aug 6, 2023)
3a0cfb7  fixed merges (alckasoc, Aug 6, 2023)
9a728aa  black formatted peak_finding (alckasoc, Aug 6, 2023)
e84535f  added make_grid_vectors, normalize_bboxes, integral_regression, added… (alckasoc, Aug 10, 2023)
36f6573  finished find_global_peaks with integral regression over centroid crops! (alckasoc, Aug 10, 2023)
b17af28  reformatted with pydocstyle & black (alckasoc, Aug 10, 2023)
3ea75ae  Merge remote-tracking branch 'origin/main' into vincent/find_peaks (alckasoc, Aug 10, 2023)
a506579  moved make_grid_vectors to data/utils (alckasoc, Aug 10, 2023)
02babb1  removed normalize_bboxes (alckasoc, Aug 10, 2023)
373f4b1  added tests docstrings (alckasoc, Aug 10, 2023)
6351314  sorted imports with isort (alckasoc, Aug 10, 2023)
008a994  remove unused imports (alckasoc, Aug 10, 2023)
b45619c  updated test cases for instance cropping (alckasoc, Aug 10, 2023)
381a49f  added minimal_cms.pt fixture + unit tests (alckasoc, Aug 11, 2023)
0ad336c  added minimal_bboxes fixture; added unit tests for crop_bboxes & inte… (alckasoc, Aug 11, 2023)
da1ba7e  added find_global_peaks unit tests (alckasoc, Aug 11, 2023)
7778512  finished find_local_peaks_rough! (alckasoc, Aug 17, 2023)
9f7ac3f  finished find_local_peaks! (alckasoc, Aug 17, 2023)
b9869d6  added unit tests for find_local_peaks and find_local_peaks_rough (alckasoc, Aug 17, 2023)
bfd1cac  updated test cases (alckasoc, Aug 17, 2023)
a8b3c31  added more test cases for find_local_peaks (alckasoc, Aug 17, 2023)
125625d  updated test cases (alckasoc, Aug 17, 2023)
a25d920  added architectures folder (alckasoc, Aug 17, 2023)
3ba92b6  added maxpool2d same padding, get_act_fn; added simpleconvblock, simp… (alckasoc, Aug 17, 2023)
f9558f2  added test_unet_reference (alckasoc, Aug 18, 2023)
28d57ca  black formatted common.py & test_unet.py (alckasoc, Aug 18, 2023)
8ca4538  fixed merge conflicts (alckasoc, Aug 18, 2023)
6df3c20  Merge branch 'main' into vincent/unet (alckasoc, Aug 18, 2023)
c4792a6  Merge branch 'vincent/unet' of https://github.com/talmolab/sleap-nn i… (alckasoc, Aug 18, 2023)
87cd034  deleted tmp nb (alckasoc, Aug 18, 2023)
7004869  _calc_same_pad returns int (alckasoc, Aug 19, 2023)
680778d  fixed test case (alckasoc, Aug 19, 2023)
7cd75dc  added simpleconvblock tests (alckasoc, Aug 19, 2023)
79b535d  added tests (alckasoc, Aug 19, 2023)
691af45  added tests for simple upsampling block (alckasoc, Aug 19, 2023)
2520fa2  updated test_unet (alckasoc, Aug 28, 2023)
bcf4069  removed unnecessary variables (alckasoc, Aug 30, 2023)
dbccdcf  updated augmentation random erase default values (alckasoc, Aug 30, 2023)
029a545  created data/pipelines.py (alckasoc, Aug 30, 2023)
3e5ae68  added base config in config/data; temporary till config system settled (alckasoc, Aug 31, 2023)
1b8002b  updated variable defaults to 0 and edited variable names in augmentation (alckasoc, Aug 31, 2023)
f1c64f4  updated parameter names in data/instance_cropping (alckasoc, Aug 31, 2023)
2a22674  added data/pipelines topdown pipeline make_base_pipeline (alckasoc, Aug 31, 2023)
f3ddf2f  added test_pipelines (alckasoc, Aug 31, 2023)
c861c72  removed configs (alckasoc, Sep 5, 2023)
31aadc1  updated augmentation class (alckasoc, Sep 6, 2023)
6630155  modified test (alckasoc, Sep 6, 2023)
55cf1a9  updated pipelines docstring (alckasoc, Sep 6, 2023)
9715c01  removed make_base_pipeline and updated tests (alckasoc, Sep 6, 2023)
7deec65  removed empty_cache in SleapDataset (alckasoc, Sep 6, 2023)
e32a0a9  Merge branch 'main' into vincent/topdownpipeline (alckasoc, Sep 7, 2023)
b1ef93c  updated test_pipelines (alckasoc, Sep 7, 2023)
ae523d1  updated sleapdataset to return a dict (alckasoc, Sep 12, 2023)
b421937  added key filter transformer block, removed sleap dataset, added type… (alckasoc, Sep 12, 2023)
fe61f15  updated type hints (alckasoc, Sep 12, 2023)
0214abb  added coderabbit suggestions (alckasoc, Sep 12, 2023)
e3b28da  fixed small squeeze issue (alckasoc, Sep 12, 2023)
1 change: 1 addition & 0 deletions sleap_nn/architectures/__init__.py
@@ -0,0 +1 @@
"""Modules related to model architectures."""
156 changes: 156 additions & 0 deletions sleap_nn/architectures/common.py
@@ -0,0 +1,156 @@
"""Common utilities for architecture and model building."""
import math

import torch
from torch import nn
from torch.nn import functional as F


class MaxPool2dWithSamePadding(nn.MaxPool2d):
"""A MaxPool2d module with support for same padding.

This class extends the torch.nn.MaxPool2d module and adds the ability
to perform 'same' padding, similar to 'same' padding in convolutional
layers. When 'same' padding is specified, the input tensor is padded
with zeros to ensure that the output spatial dimensions match the input
spatial dimensions as closely as possible.

Args:
nn.MaxPool2d arguments: Arguments that are passed to the parent
torch.nn.MaxPool2d class.

Attributes:
Inherits all attributes from torch.nn.MaxPool2d.

Methods:
forward(x: torch.Tensor) -> torch.Tensor:
Forward pass through the MaxPool2dWithSamePadding module.

Note:
The 'same' padding is applied only when self.padding is set to "same".

Example:
# Create an instance of MaxPool2dWithSamePadding
maxpool_layer = MaxPool2dWithSamePadding(kernel_size=3, stride=2, padding="same")

# Perform a forward pass on an input tensor
input_tensor = torch.rand(1, 3, 32, 32) # Example input tensor
output = maxpool_layer(input_tensor) # Apply the MaxPool2d operation with same padding.
"""

    def _calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Calculate the total padding required to achieve 'same' padding.

        Args:
            i (int): Input dimension (height or width).
            k (int): Kernel size.
            s (int): Stride.
            d (int): Dilation.

        Returns:
            int: The total padding along this dimension. For example, for
                i=32, k=3, s=2, d=1: (ceil(32 / 2) - 1) * 2 + (3 - 1) * 1 + 1 - 32 = 1.
        """
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the MaxPool2dWithSamePadding module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the MaxPool2d operation.
        """
        padding = self.padding
        if padding == "same":
            ih, iw = x.size()[-2:]

            kh, kw = (
                (self.kernel_size, self.kernel_size)
                if isinstance(self.kernel_size, int)
                else self.kernel_size
            )
            sh, sw = (
                (self.stride, self.stride)
                if isinstance(self.stride, int)
                else self.stride
            )
            dh, dw = (
                (self.dilation, self.dilation)
                if isinstance(self.dilation, int)
                else self.dilation
            )

            pad_h = self._calc_same_pad(i=ih, k=kh, s=sh, d=dh)
            pad_w = self._calc_same_pad(i=iw, k=kw, s=sw, d=dw)

            if pad_h > 0 or pad_w > 0:
                # Pad asymmetrically (extra pixel on the bottom/right) so the
                # output size is exactly ceil(input / stride).
                x = F.pad(
                    x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
                )

            # Pass zero padding to the pooling op below. Use a local variable
            # rather than mutating self.padding, which would disable 'same'
            # padding on every forward call after the first.
            padding = 0

        return F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
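
# A minimal usage sketch (mirroring the class docstring example above): with
# padding="same", the output spatial size is ceil(input_size / stride), which
# follows from the padding formula in _calc_same_pad.
#
#   pool = MaxPool2dWithSamePadding(kernel_size=3, stride=2, padding="same")
#   x = torch.rand(1, 3, 32, 32)
#   y = pool(x)
#   assert y.shape == (1, 3, 16, 16)  # ceil(32 / 2) == 16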


def get_act_fn(activation: str) -> nn.Module:
"""Get an instance of an activation function module based on the provided name.

This function returns an instance of a PyTorch activation function module
corresponding to the given activation function name.

Args:
activation (str): Name of the activation function. Supported values are 'relu', 'sigmoid', and 'tanh'.

Returns:
nn.Module: An instance of the requested activation function module.

Raises:
KeyError: If the provided activation function name is not one of the supported values.

Example:
# Get an instance of the ReLU activation function
relu_fn = get_act_fn('relu')

# Apply the activation function to an input tensor
input_tensor = torch.randn(1, 64, 64)
output = relu_fn(input_tensor)
"""
activations = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()}

if activation not in activations:
raise KeyError(
f"Unsupported activation function: {activation}. Supported activations are: {', '.join(activations.keys())}"
)

return activations[activation]
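
# A short sketch of the lookup behavior (illustrative values): known names
# return a fresh module instance; unknown names raise KeyError listing the
# supported activations.
#
#   act = get_act_fn("relu")
#   act(torch.tensor([-1.0, 2.0]))  # -> tensor([0., 2.])
#   get_act_fn("gelu")  # -> KeyError: unsupported activation function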


def get_children_layers(model: torch.nn.Module) -> list:
    """Recursively retrieve a flattened list of all leaf modules within the given model.

    Args:
        model: The PyTorch model to extract children from.

    Returns:
        list of nn.Module: A flattened list containing all leaf modules
            (modules that have no children of their own).
    """
    children = list(model.children())
    if not children:
        # Leaf module: wrap in a list so callers can always extend the result.
        return [model]

    flattened_children = []
    for child in children:
        flattened_children.extend(get_children_layers(child))
    return flattened_children
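
# A quick sketch of the traversal (hypothetical model, for illustration only):
# containers are recursed into, and only leaf modules appear in the result.
#
#   model = nn.Sequential(
#       nn.Conv2d(3, 8, kernel_size=3),
#       nn.Sequential(nn.ReLU(), nn.MaxPool2d(2)),
#   )
#   get_children_layers(model)
#   # -> [Conv2d(3, 8, kernel_size=(3, 3), ...), ReLU(), MaxPool2d(...)]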