Implement base Model and Head classes (#17)
* added make_centered_bboxes & normalize_bboxes

* added make_centered_bboxes & normalize_bboxes

* created test_instance_cropping.py

* added test normalize bboxes; added find_global_peaks_rough

* black formatted

* black formatted peak_finding

* added make_grid_vectors, normalize_bboxes, integral_regression, added docstring to make_centered_bboxes, fixed find_global_peaks_rough; added crop_bboxes

* finished find_global_peaks with integral regression over centroid crops! (see the sketch after this list)

* reformatted with pydocstyle & black

* moved make_grid_vectors to data/utils

* removed normalize_bboxes

* added tests docstrings

* sorted imports with isort

* remove unused imports

* updated test cases for instance cropping

* added minimal_cms.pt fixture + unit tests

* added minimal_bboxes fixture; added unit tests for crop_bboxes & integral_regression

* added find_global_peaks unit tests

* finished find_local_peaks_rough!

* finished find_local_peaks!

* added unit tests for find_local_peaks and find_local_peaks_rough

* updated test cases

* added more test cases for find_local_peaks

* updated test cases

* added architectures folder

* added maxpool2d same padding, get_act_fn; added simpleconvblock, simpleupsamplingblock, encoder, decoder; added unet

* added test_unet_reference

* black formatted common.py & test_unet.py

* deleted tmp nb

* _calc_same_pad returns int

* fixed test case

* added simpleconvblock tests

* added tests

* added tests for simple upsampling block

* updated test_unet

* removed unnecessary variables

* updated augmentation random erase default values

* created data/pipelines.py

* added base config in config/data; temporary till config system settled

* updated variable defaults to 0 and edited variable names in augmentation

* updated parameter names in data/instance_cropping

* added data/pipelines topdown pipeline make_base_pipeline

* added test_pipelines

* removed configs

* updated augmentation class

* modified test

* removed cuda cache

* added Model builder class and heads

* added type hinting for init

* black reformatted heads.py

* updated model.py

* updated test_model.py

* updated test_model.py

* updated pipelines docstring

* added from_config for Model

* added more act fn to get_act_fn

* black reformatted & updated model.py & test_model.py

* updated config, typehints, black formatted & added doc strings

* added test_heads.py

* updated module docstring

* updated Model docstring

* added coderabbit suggestions

* black reformat

* added 2 helper methods for getting backbone/head; separated common and utils in architectures

* removed comments

* updated test_get_act_fn

* added multi-head feature to Model

* black reformatted model.py

* added all test cases for heads.py

* reformatted test_heads.py

* updated L44 in confidence_maps.py

* added output channels to unet

* resolved merge conflicts + small bugs

* black reformatted

* added coderabbit suggestions

* not sure how intermediate features + multi head would work

* Separate Augmentations into Intensity and Geometric (#18)

* initial commit

* separated intensity and geometric augmentations

* test

* pseudo code in model.py

* small fix

* name property in heads.py

* name of head docstring added

* added ruff cache to gitignore; added head selection in Model class

* updated return value for decoder

* small change to model.py

* made model.py forward more efficient

* small comments updated in instance_cropping for clarity

* updated output structure of unet to dict; updated model.py attribute head

* fixed minor changes

* fixed minor changes

* updated ruff output format

* added anchor_ind to topdown pipeline config
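
The peak-finding commits above center on integral regression (soft-argmax) over confidence-map crops. As a rough illustration of the technique, not sleap-nn's actual `find_global_peaks`/`integral_regression` implementation (whose signatures live in the peak-finding module), a minimal sketch in plain PyTorch:

```python
import torch


def soft_argmax_2d(cms: torch.Tensor) -> torch.Tensor:
    """Estimate sub-pixel peak locations via integral regression.

    Args:
        cms: Confidence maps of shape (n, height, width); assumed non-negative
            with positive total mass per map.

    Returns:
        Peak coordinates of shape (n, 2) in (x, y) order.
    """
    n, height, width = cms.shape
    # Normalize each map into a probability distribution over pixels.
    flat = cms.reshape(n, -1)
    probs = (flat / flat.sum(dim=-1, keepdim=True)).reshape(n, height, width)
    # Grid vectors, in the spirit of make_grid_vectors from the commit log.
    xv = torch.arange(width, dtype=cms.dtype)
    yv = torch.arange(height, dtype=cms.dtype)
    # Expected coordinate = probability-weighted sum over the grid.
    x_hat = (probs.sum(dim=1) * xv).sum(dim=-1)  # marginalize rows, weight by x
    y_hat = (probs.sum(dim=2) * yv).sum(dim=-1)  # marginalize columns, weight by y
    return torch.stack([x_hat, y_hat], dim=-1)
```

Applied to a crop centered on a coarse peak (e.g. one returned by `find_global_peaks_rough`), this refines the integer argmax to a sub-pixel estimate.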
alckasoc authored Oct 19, 2023
1 parent 11a1a33 commit 0dff84f
Showing 20 changed files with 1,439 additions and 141 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -143,6 +143,9 @@ venv.bak/
 .dmypy.json
 dmypy.json
 
+# ruff
+.ruff_cache/
+
 # Pyre type checker
 .pyre/
 
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -58,7 +58,7 @@ Repository = "https://github.com/talmolab/sleap-nn"
 line-length = 88
 
 [tool.ruff]
-format = "github"
+output-format = "github"
 select = [
 "D", # pydocstyle
 ]
55 changes: 0 additions & 55 deletions sleap_nn/architectures/common.py
@@ -101,58 +101,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             ceil_mode=self.ceil_mode,
             return_indices=self.return_indices,
         )
-
-
-def get_act_fn(activation: str) -> nn.Module:
-    """Get an instance of an activation function module based on the provided name.
-
-    This function returns an instance of a PyTorch activation function module
-    corresponding to the given activation function name.
-
-    Args:
-        activation (str): Name of the activation function. Supported values are 'relu', 'sigmoid', and 'tanh'.
-
-    Returns:
-        nn.Module: An instance of the requested activation function module.
-
-    Raises:
-        KeyError: If the provided activation function name is not one of the supported values.
-
-    Example:
-        # Get an instance of the ReLU activation function
-        relu_fn = get_act_fn('relu')
-
-        # Apply the activation function to an input tensor
-        input_tensor = torch.randn(1, 64, 64)
-        output = relu_fn(input_tensor)
-    """
-    activations = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh()}
-
-    if activation not in activations:
-        raise KeyError(
-            f"Unsupported activation function: {activation}. Supported activations are: {', '.join(activations.keys())}"
-        )
-
-    return activations[activation]
-
-
-def get_children_layers(model: torch.nn.Module) -> List[nn.Module]:
-    """Recursively retrieves a flattened list of all children modules and submodules within the given model.
-
-    Args:
-        model: The PyTorch model to extract children from.
-
-    Returns:
-        list of nn.Module: A flattened list containing all children modules and submodules.
-    """
-    children = list(model.children())
-    flattened_children = []
-    if children == []:
-        return model
-    else:
-        for child in children:
-            try:
-                flattened_children.extend(get_children_layers(child))
-            except TypeError:
-                flattened_children.append(get_children_layers(child))
-    return flattened_children
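
Note that these helpers were moved, not dropped: per the "separated common and utils in architectures" commit and the import change in encoder_decoder.py below, `get_act_fn` now lives in `sleap_nn.architectures.utils` (presumably alongside `get_children_layers`). Usage is unchanged; a small sketch, with the activation set taken from the deleted docstring (the commit log notes more activations were added later):

```python
import torch

from sleap_nn.architectures.utils import get_act_fn

act = get_act_fn("relu")  # returns an nn.ReLU() instance; "sigmoid" and "tanh" also supported
out = act(torch.randn(1, 64, 64))  # applied elementwise, so the input shape is preserved
```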
21 changes: 17 additions & 4 deletions sleap_nn/architectures/encoder_decoder.py
@@ -28,9 +28,11 @@
 from typing import List, Text, Tuple, Union
 
 import torch
-from sleap_nn.architectures.common import MaxPool2dWithSamePadding, get_act_fn
 from torch import nn
 
+from sleap_nn.architectures.common import MaxPool2dWithSamePadding
+from sleap_nn.architectures.utils import get_act_fn
+
 
 class SimpleConvBlock(nn.Module):
     """A simple convolutional block module.
@@ -428,6 +430,8 @@ def __init__(
         self.convs_per_block = convs_per_block
         self.kernel_size = kernel_size
 
+        self.current_strides = []
+
         self.decoder_stack = nn.ModuleList([])
         for block in range(up_blocks):
             prev_block_filters_in = -1 if block == 0 else block_filters_in
@@ -454,19 +458,28 @@
                 )
             )
 
+            self.current_strides.append(current_stride)
             current_stride = next_stride
 
-    def forward(self, x: torch.Tensor, features: List[torch.Tensor]) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, features: List[torch.Tensor]
+    ) -> Tuple[List[torch.Tensor], List]:
         """Forward pass through the Decoder module.
 
         Args:
             x: Input tensor for the decoder.
             features: List of feature tensors from different encoder levels.
 
         Returns:
-            torch.Tensor: Output tensor after applying the decoder operations.
+            outputs: List of output tensors after applying the decoder operations.
+            current_strides: the current strides from the decoder blocks.
         """
+        outputs = {
+            "outputs": [],
+        }
         for i in range(len(self.decoder_stack)):
             x = self.decoder_stack[i](x, features[i])
+            outputs["outputs"].append(x)
+        outputs["strides"] = self.current_strides
 
-        return x
+        return outputs
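
Callers of `Decoder.forward` now unpack a dict rather than a single tensor. A stand-in sketch of the new contract (the stub below mimics only the return structure shown in this diff; the real constructor arguments and block behavior are omitted):

```python
import torch


def fake_decoder(x, features):
    """Stand-in with the same return contract as the updated Decoder.forward."""
    # One list entry per up-sampling block, plus the stride each output was computed at.
    return {"outputs": [x, x], "strides": [2, 1]}


out = fake_decoder(torch.randn(1, 64, 32, 32), features=[])
final = out["outputs"][-1]  # finest-resolution feature map, previously the sole return value
strides = out["strides"]  # output stride recorded for each decoder block
```

Exposing every intermediate output together with its stride is presumably what lets the multi-head Model attach heads at different resolutions, per the "added multi-head feature to Model" commit above.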
