Added SAM2 model and post-processing (#19)
* added SAM2
* updates SAM2 model adapter to use all hierarchical features; added SAM2 Large & Base
* updates post-processing using SAM2; updated setup
* updated toml file to include sam2 dep from git
* updated setup & envs yaml
* fixed toml sam-2 dependency; clean-up after post-processing
* commented sam2 dependency in pyproject.toml due to sam2 installation error via pip
* Updated README
* updated numpy to 1.24.4
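The new public API surfaces through featureforest.models.SAM2 (see the __init__.py diff below). A minimal usage sketch, not part of the commit: the image size is an assumption, and it needs the sam-2 package installed plus network access for the checkpoint download.

import torch
from featureforest.models.SAM2 import get_large_model

# hypothetical 768x1024 image; the adapter derives a patch size from it
adapter = get_large_model(img_height=768, img_width=1024)
# dummy batch of image patches at the adapter's patch size
patches = torch.rand(2, 3, adapter.patch_size, adapter.patch_size)
features = adapter.get_features_patches(patches)
print(adapter.get_total_output_channels())  # 256 channels x 3 FPN levels = 768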
Showing 10 changed files with 249 additions and 90 deletions.
pyproject.toml
@@ -17,13 +17,16 @@ allow-direct-references = true
 only-include = ["src"]
 sources = ["src"]
 
+[tool.hatch.envs.default.env-vars]
+SAM2_BUILD_CUDA="0"
+
 # https://peps.python.org/pep-0621/
 [project]
 name = "featureforest"
 dynamic = ["version"]
-description = "A napari plugin for segmentation using vision transformer models' features"
+description = "A napari plugin for segmentation using vision transformer features"
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.10"
 license = { text = "BSD-3-Clause" }
 authors = [
     { name = "Mehdi Seifi", email = "[email protected]" },
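The new env-vars entry is presumably there so that installing the sam-2 package inside the hatch environment skips compiling its optional CUDA extension (sam-2's build reads SAM2_BUILD_CUDA); the feature-extraction code added below only uses the image encoder and does not need that extension.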
@@ -33,18 +36,14 @@ classifiers = [
     "Development Status :: 3 - Alpha",
     "Framework :: napari",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
     "License :: OSI Approved :: BSD License",
     "Topic :: Scientific/Engineering :: Image Processing",
 ]
 dependencies = [
-    "numpy==1.23.5",
-    "torch==2.1.2",
-    "torchvision==0.16.2",
+    "numpy==1.24.4",
     "opencv-python",
     "scikit-learn",
     "scikit-image",
@@ -56,9 +55,12 @@ dependencies = [
     "napari",
     "h5py",
     "pooch",
+    "iopath>=0.1.10",
+    "torch==2.3.1",
+    "torchvision==0.18.1",
+    "timm==1.0.9",
     "segment-anything-py",
     "segment-anything-hq",
+    # "sam-2 @ git+https://github.com/facebookresearch/sam2.git"
 ]
 [project.optional-dependencies]
 # development dependencies and tooling
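Since the sam-2 dependency stays commented out (per the commit message, pip installation of sam2 currently fails), the package has to be installed manually. A minimal sketch, not part of the commit, of guarding the import with a helpful error:

# Sketch (hypothetical): fail with a clear message when the optional
# sam-2 package is missing.
try:
    from sam2.build_sam import build_sam2
except ImportError as err:
    raise ImportError(
        "SAM2 models require the 'sam-2' package; install it manually, e.g. "
        "pip install 'git+https://github.com/facebookresearch/sam2.git'"
    ) from err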
@@ -86,12 +88,12 @@ markers = [
 
 [tool.black]
 line-length = 90
-target-version = ['py38', 'py39', 'py310']
+target-version = ["py310", "py311", "py312"]
 
 
 [tool.ruff]
 line-length = 90
-target-version = "py38"
+target-version = "py310"
 src = ["src"]
 select = [
     "E", "F", "W", #flake8
In an existing SAM adapter module, an unused import is dropped:
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 
 from segment_anything.modeling import Sam
src/featureforest/models/SAM2/__init__.py (new file)
@@ -0,0 +1,9 @@
+from .model import get_large_model, get_base_model
+from .adapter import SAM2Adapter
+
+
+__all__ = [
+    "get_large_model",
+    "get_base_model",
+    "SAM2Adapter",
+]
src/featureforest/models/SAM2/adapter.py (new file)
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.transforms import v2 as tv_transforms2
+
+from featureforest.models.base import BaseModelAdapter
+from featureforest.utils.data import (
+    get_patch_size,
+    get_nonoverlapped_patches,
+)
+
+
+class SAM2Adapter(BaseModelAdapter):
+    """SAM2 model adapter
+    """
+    def __init__(
+        self,
+        image_encoder: nn.Module,
+        img_height: float,
+        img_width: float,
+        device: torch.device,
+        name: str = "SAM2_Large"
+    ) -> None:
+        super().__init__(image_encoder, img_height, img_width, device)
+        # for different flavors of SAM2, only the name is different.
+        self.name = name
+        # we only need the sam2 image encoder part
+        self.encoder = image_encoder
+        self.encoder_num_channels = 256
+        self._set_patch_size()
+        self.device = device
+
+        # input transform for sam
+        self.sam_input_dim = 1024
+        self.input_transforms = tv_transforms2.Compose([
+            tv_transforms2.Resize(
+                (self.sam_input_dim, self.sam_input_dim),
+                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                antialias=True
+            ),
+        ])
+        # to transform feature patches back to the original patch size
+        self.embedding_transform = tv_transforms2.Compose([
+            tv_transforms2.Resize(
+                (self.patch_size, self.patch_size),
+                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
+                antialias=True
+            ),
+        ])
+
+    def _set_patch_size(self) -> None:
+        self.patch_size = get_patch_size(self.img_height, self.img_width)
+        self.overlap = self.patch_size // 2
+
+    def get_features_patches(
+        self, in_patches: Tensor
+    ) -> Tensor:
+        # get the image encoder outputs
+        with torch.no_grad():
+            output = self.encoder(
+                self.input_transforms(in_patches)
+            )
+        # backbone_fpn contains 3 levels of features, from high to low resolution:
+        # [b, 256, 256, 256]
+        # [b, 256, 128, 128]
+        # [b, 256, 64, 64]
+        features = [
+            self.embedding_transform(feat.cpu())
+            for feat in output["backbone_fpn"]
+        ]
+        features = torch.cat(features, dim=1)
+        out_feature_patches = get_nonoverlapped_patches(features, self.patch_size, self.overlap)
+
+        return out_feature_patches
+
+    def get_total_output_channels(self) -> int:
+        # three FPN levels with 256 channels each
+        return 256 * 3
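The adapter concatenates the three backbone_fpn levels (256 channels each) along the channel dimension after resizing them to a common patch size, which is where get_total_output_channels()'s 256 * 3 = 768 comes from. A small shape sketch, not part of the commit: it uses plain interpolate instead of the torchvision transform, and the sizes are made up.

import torch
import torch.nn.functional as F

b, patch = 2, 512  # assumed batch size and patch size
levels = [
    torch.rand(b, 256, 256, 256),  # high-resolution level
    torch.rand(b, 256, 128, 128),
    torch.rand(b, 256, 64, 64),    # low-resolution level
]
# resize every level to the patch size, then stack along channels
resized = [F.interpolate(f, size=(patch, patch), mode="bicubic") for f in levels]
stacked = torch.cat(resized, dim=1)
print(stacked.shape)  # torch.Size([2, 768, 512, 512])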
src/featureforest/models/SAM2/model.py (new file)
@@ -0,0 +1,75 @@
+import torch
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.build_sam import build_sam2
+
+from featureforest.utils.downloader import download_model
+from featureforest.models.SAM2.adapter import SAM2Adapter
+
+
+def get_large_model(
+    img_height: float, img_width: float, *args, **kwargs
+) -> SAM2Adapter:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"running on {device}")
+    # download the model's weights
+    model_url = \
+        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
+    model_file = download_model(
+        model_url=model_url,
+        model_name="sam2.1_hiera_large.pt"
+    )
+    if model_file is None:
+        raise ValueError(f"Could not download the model from {model_url}.")
+
+    # init the model on CPU
+    model: SAM2Base = build_sam2(
+        config_file="configs/sam2.1/sam2.1_hiera_l.yaml",
+        ckpt_path=model_file,
+        device="cpu"
+    )
+    # to save some GPU memory, only put the encoder part on the GPU
+    sam_image_encoder = model.image_encoder.to(device)
+    sam_image_encoder.eval()
+
+    # create the model adapter
+    sam2_model_adapter = SAM2Adapter(
+        sam_image_encoder, img_height, img_width, device, "SAM2_Large"
+    )
+
+    return sam2_model_adapter
+
+
+def get_base_model(
+    img_height: float, img_width: float, *args, **kwargs
+) -> SAM2Adapter:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"running on {device}")
+    # download the model's weights
+    model_url = \
+        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt"
+    model_file = download_model(
+        model_url=model_url,
+        model_name="sam2.1_hiera_base_plus.pt"
+    )
+    if model_file is None:
+        raise ValueError(f"Could not download the model from {model_url}.")
+
+    # init the model on CPU
+    model: SAM2Base = build_sam2(
+        config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
+        ckpt_path=model_file,
+        device="cpu"
+    )
+    # to save some GPU memory, only put the encoder part on the GPU
+    sam_image_encoder = model.image_encoder.to(device)
+    sam_image_encoder.eval()
+
+    # create the model adapter
+    sam2_model_adapter = SAM2Adapter(
+        sam_image_encoder, img_height, img_width, device, "SAM2_Base"
+    )
+
+    return sam2_model_adapter
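get_large_model and get_base_model differ only in checkpoint URL, config file, and adapter name. A usage sketch, not part of the commit; the flag and image size are hypothetical:

from featureforest.models.SAM2 import get_base_model, get_large_model

use_large = False  # assumption: prefer the smaller Base+ checkpoint
factory = get_large_model if use_large else get_base_model
adapter = factory(img_height=768, img_width=1024)
print(adapter.name)  # "SAM2_Base" here; "SAM2_Large" with use_large=True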