Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement base Model and Head classes #17

Merged
merged 100 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 92 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
2dfcd90
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
1088e7f
added make_centered_bboxes & normalize_bboxes
alckasoc Aug 3, 2023
2d0a009
created test_instance_cropping.py
alckasoc Aug 3, 2023
02ea629
added test normalize bboxes; added find_global_peaks_rough
alckasoc Aug 6, 2023
711b3aa
black formatted
alckasoc Aug 6, 2023
3a0cfb7
fixed merges
alckasoc Aug 6, 2023
9a728aa
black formatted peak_finding
alckasoc Aug 6, 2023
e84535f
added make_grid_vectors, normalize_bboxes, integral_regression, added…
alckasoc Aug 10, 2023
36f6573
finished find_global_peaks with integral regression over centroid crops!
alckasoc Aug 10, 2023
b17af28
reformatted with pydocstyle & black
alckasoc Aug 10, 2023
3ea75ae
Merge remote-tracking branch 'origin/main' into vincent/find_peaks
alckasoc Aug 10, 2023
a506579
moved make_grid_vectors to data/utils
alckasoc Aug 10, 2023
02babb1
removed normalize_bboxes
alckasoc Aug 10, 2023
373f4b1
added tests docstrings
alckasoc Aug 10, 2023
6351314
sorted imports with isort
alckasoc Aug 10, 2023
008a994
remove unused imports
alckasoc Aug 10, 2023
b45619c
updated test cases for instance cropping
alckasoc Aug 10, 2023
381a49f
added minimal_cms.pt fixture + unit tests
alckasoc Aug 11, 2023
0ad336c
added minimal_bboxes fixture; added unit tests for crop_bboxes & inte…
alckasoc Aug 11, 2023
da1ba7e
added find_global_peaks unit tests
alckasoc Aug 11, 2023
7778512
finished find_local_peaks_rough!
alckasoc Aug 17, 2023
9f7ac3f
finished find_local_peaks!
alckasoc Aug 17, 2023
b9869d6
added unit tests for find_local_peaks and find_local_peaks_rough
alckasoc Aug 17, 2023
bfd1cac
updated test cases
alckasoc Aug 17, 2023
a8b3c31
added more test cases for find_local_peaks
alckasoc Aug 17, 2023
125625d
updated test cases
alckasoc Aug 17, 2023
a25d920
added architectures folder
alckasoc Aug 17, 2023
3ba92b6
added maxpool2d same padding, get_act_fn; added simpleconvblock, simp…
alckasoc Aug 17, 2023
f9558f2
added test_unet_reference
alckasoc Aug 18, 2023
28d57ca
black formatted common.py & test_unet.py
alckasoc Aug 18, 2023
8ca4538
fixed merge conflicts
alckasoc Aug 18, 2023
6df3c20
Merge branch 'main' into vincent/unet
alckasoc Aug 18, 2023
c4792a6
Merge branch 'vincent/unet' of https://github.com/talmolab/sleap-nn i…
alckasoc Aug 18, 2023
87cd034
deleted tmp nb
alckasoc Aug 18, 2023
7004869
_calc_same_pad returns int
alckasoc Aug 19, 2023
680778d
fixed test case
alckasoc Aug 19, 2023
7cd75dc
added simpleconvblock tests
alckasoc Aug 19, 2023
79b535d
added tests
alckasoc Aug 19, 2023
691af45
added tests for simple upsampling block
alckasoc Aug 19, 2023
2520fa2
updated test_unet
alckasoc Aug 28, 2023
bcf4069
removed unnecessary variables
alckasoc Aug 30, 2023
dbccdcf
updated augmentation random erase default values
alckasoc Aug 30, 2023
029a545
created data/pipelines.py
alckasoc Aug 30, 2023
3e5ae68
added base config in config/data; temporary till config system settled
alckasoc Aug 31, 2023
1b8002b
updated variable defaults to 0 and edited variable names in augmentation
alckasoc Aug 31, 2023
f1c64f4
updated parameter names in data/instance_cropping
alckasoc Aug 31, 2023
2a22674
added data/pipelines topdown pipeline make_base_pipeline
alckasoc Aug 31, 2023
f3ddf2f
added test_pipelines
alckasoc Aug 31, 2023
c861c72
removed configs
alckasoc Sep 5, 2023
31aadc1
updated augmentation class
alckasoc Sep 6, 2023
6630155
modified test
alckasoc Sep 6, 2023
c7fc015
removed cuda cache
alckasoc Sep 6, 2023
4b439e9
added Model builder class and heads
alckasoc Sep 6, 2023
199e4d5
added type hinting for init
alckasoc Sep 6, 2023
4134473
black reformatted heads.py
alckasoc Sep 6, 2023
86f960e
updated model.py
alckasoc Sep 6, 2023
d576e63
updated test_model.py
alckasoc Sep 6, 2023
40089b7
updated test_model.py
alckasoc Sep 6, 2023
dddb6ac
updated pipelines docstring
alckasoc Sep 6, 2023
b75cacd
added from_config for Model
alckasoc Sep 7, 2023
979de3f
added more act fn to get_act_fn
alckasoc Sep 12, 2023
ff144a9
black reformatted & updated model.py & test_model.py
alckasoc Sep 12, 2023
4dacac9
updated config, typehints, black formatted & added doc strings
alckasoc Sep 12, 2023
a4fffd4
added test_heads.py
alckasoc Sep 12, 2023
ddc70c4
updated module docstring
alckasoc Sep 12, 2023
3f58407
updated Model docstring
alckasoc Sep 12, 2023
980f107
added coderabbit suggestions
alckasoc Sep 12, 2023
cf9d18f
black reformat
alckasoc Sep 13, 2023
0def8ae
added 2 helper methods for getting backbone/head; separated common an…
alckasoc Sep 13, 2023
53065a3
removed comments
alckasoc Sep 13, 2023
1a66ace
updated test_get_act_fn
alckasoc Sep 13, 2023
f762317
added multi-head feature to Model
alckasoc Sep 14, 2023
a70c5c4
black reformatted model.py
alckasoc Sep 14, 2023
f2543a2
added all test cases for heads.py
alckasoc Sep 14, 2023
e6fa8d4
reformatted test_heads.py
alckasoc Sep 14, 2023
ce86884
updated L44 in confidence_maps.py
alckasoc Sep 14, 2023
361f979
added output channels to unet
alckasoc Sep 19, 2023
81b5af0
resolved merge conflicts
alckasoc Sep 19, 2023
ca4e87d
Merge branch 'main' into vincent/models
alckasoc Sep 19, 2023
a768c4f
resolved merge conflicts + small bugs
alckasoc Sep 19, 2023
0be083f
black reformatted
alckasoc Sep 20, 2023
88bc306
added coderabbit suggestions
alckasoc Sep 20, 2023
27990c3
not sure how intermediate features + multi head would work
alckasoc Sep 20, 2023
212e94f
Separate Augmentations into Intensity and Geometric (#18)
alckasoc Sep 21, 2023
d44f8a6
pseudo code in model.py
alckasoc Sep 21, 2023
2930057
Merge branch 'vincent/models' of https://github.com/talmolab/sleap-nn…
alckasoc Sep 21, 2023
537c14b
small fix
alckasoc Sep 21, 2023
b759e7a
name property in heads.py
alckasoc Sep 21, 2023
e21e74c
name of head docstring added
alckasoc Sep 21, 2023
dce7294
Merge remote-tracking branch 'origin/main' into vincent/models
alckasoc Sep 26, 2023
99413a1
added ruff cache to gitignore; added head selection in Model class
alckasoc Sep 27, 2023
2880849
updated return value for decoder
alckasoc Oct 7, 2023
9afdebb
small change to model.py
alckasoc Oct 7, 2023
5cd8801
made model.py forward more efficient
alckasoc Oct 7, 2023
4eaf268
small comments updated in instance_cropping for clarity
alckasoc Oct 7, 2023
702009a
updated output structure of unet to dict; updated model.py attribute …
alckasoc Oct 12, 2023
521fc9c
fixed minor changes
alckasoc Oct 17, 2023
4eff9be
fixed minor changes
alckasoc Oct 17, 2023
b0d530b
updated ruff output format
alckasoc Oct 17, 2023
0325596
added anchor_ind to topdown pipeline config
alckasoc Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ venv.bak/
.dmypy.json
dmypy.json

# ruff
.ruff_cache/

# Pyre type checker
.pyre/

Expand Down
55 changes: 0 additions & 55 deletions sleap_nn/architectures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,58 +101,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
ceil_mode=self.ceil_mode,
return_indices=self.return_indices,
)


def get_act_fn(activation: str) -> nn.Module:
"""Get an instance of an activation function module based on the provided name.

This function returns an instance of a PyTorch activation function module
corresponding to the given activation function name.

Args:
activation (str): Name of the activation function. Supported values are 'relu', 'sigmoid', and 'tanh'.

Returns:
nn.Module: An instance of the requested activation function module.

Raises:
KeyError: If the provided activation function name is not one of the supported values.

Example:
# Get an instance of the ReLU activation function
relu_fn = get_act_fn('relu')

# Apply the activation function to an input tensor
input_tensor = torch.randn(1, 64, 64)
output = relu_fn(input_tensor)
"""
activations = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()}

if activation not in activations:
raise KeyError(
f"Unsupported activation function: {activation}. Supported activations are: {', '.join(activations.keys())}"
)

return activations[activation]


def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
"""Recursively retrieves a flattened list of all children modules and submodules within the given model.

Args:
model: The PyTorch model to extract children from.

Returns:
list of nn.Module: A flattened list containing all children modules and submodules.
"""
children = list(model.children())
flattened_children = []
if children == []:
return model
else:
for child in children:
try:
flattened_children.extend(get_children_layers(child))
except TypeError:
flattened_children.append(get_children_layers(child))
return flattened_children
18 changes: 14 additions & 4 deletions sleap_nn/architectures/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from typing import List, Text, Tuple, Union

import torch
from sleap_nn.architectures.common import MaxPool2dWithSamePadding, get_act_fn
from torch import nn

from sleap_nn.architectures.common import MaxPool2dWithSamePadding
from sleap_nn.architectures.utils import get_act_fn


class SimpleConvBlock(nn.Module):
"""A simple convolutional block module.
Expand Down Expand Up @@ -428,6 +430,8 @@ def __init__(
self.convs_per_block = convs_per_block
self.kernel_size = kernel_size

self.current_strides = []

self.decoder_stack = nn.ModuleList([])
for block in range(up_blocks):
prev_block_filters_in = -1 if block == 0 else block_filters_in
Expand All @@ -454,19 +458,25 @@ def __init__(
)
)

self.current_strides.append(current_stride)
current_stride = next_stride

def forward(self, x: torch.Tensor, features: List[torch.Tensor]) -> torch.Tensor:
def forward(
self, x: torch.Tensor, features: List[torch.Tensor]
) -> Tuple[List[torch.Tensor], List]:
"""Forward pass through the Decoder module.

Args:
x: Input tensor for the decoder.
features: List of feature tensors from different encoder levels.

Returns:
torch.Tensor: Output tensor after applying the decoder operations.
outputs: List of output tensors after applying the decoder operations.
current_strides: the current strides from the decoder blocks.
"""
outputs = []
for i in range(len(self.decoder_stack)):
x = self.decoder_stack[i](x, features[i])
outputs.append(x)

return x
return outputs, self.current_strides
Loading