Added SAM2 model and post-processing (#19)
* added SAM2

* updates SAM2 model adapter to use all hierarchical features; added SAM2 Large & Base

* updates post-processing using SAM2; updated setup

* updated toml file to include sam2 dep from git

* updated setup & envs yaml

* fixed toml sam-2 dependency; clean-up after post-processing

* commented sam2 dependency in pyproject.toml due to sam2 installation error via pip

* Updated README

* updated numpy to 1.24.4
mese79 authored Nov 9, 2024
1 parent 62b6dc8 commit a8707a3
Showing 10 changed files with 249 additions and 90 deletions.
33 changes: 13 additions & 20 deletions README.md
@@ -7,17 +7,16 @@
[![codecov](https://codecov.io/gh/juglab/featureforest/branch/main/graph/badge.svg)](https://codecov.io/gh/juglab/featureforest)
[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/featureforest)](https://napari-hub.org/plugins/featureforest)

A napari plugin for segmentation using vision transformers' features.
We developed a *napari* plugin to train a *Random Forest* model using extracted embeddings of ViT models for input and just a few scribble labels provided by the user. This approach can do the segmentation of desired objects almost as well as manual segmentations but in a much shorter time with less manual effort.
**A napari plugin for image annotation using the feature space of vision transformers and a random forest classifier.**
We developed a *napari* plugin that trains a *Random Forest* model on features extracted by vision foundation models, using just a few scribble labels provided by the user as input. This approach segments the desired objects almost as well as manual annotation, but in much less time and with far less manual effort.

----------------------------------

## Documentation
The plugin documentation is [here](docs/index.md).

## Installation
It is highly recommended to use a python environment manager like [conda] to create a clean environment for installation.
You can install all the requirements using provided environment config files:
To install this plugin, use [conda] or [mamba] to create an environment and install the requirements. Use the commands below to create the environment and install the plugin:
```bash
# for GPU
conda env create -f ./env_gpu.yml
@@ -27,9 +26,11 @@
conda env create -f ./env_cpu.yml
```

#### Note: You need to install `sam-2`, which can be installed easily using conda. To install `sam-2` using `pip`, please refer to the official [sam-2](https://github.com/facebookresearch/sam2) repository.

### Requirements
- `python >= 3.9`
- `numpy`
- `python >= 3.10`
- `numpy==1.24.4`
- `opencv-python`
- `scikit-learn`
- `scikit-image`
@@ -39,16 +40,18 @@ conda env create -f ./env_cpu.yml
- `qtpy`
- `napari`
- `h5py`
- `pytorch=2.1.2`
- `torchvision=0.16.2`
- `pytorch=2.3.1`
- `torchvision=0.18.1`
- `timm=1.0.9`
- `pynrrd`
- `segment-anything`
- `sam-2`

If you want to install the plugin manually with GPU support, please follow the PyTorch installation instructions [here](https://pytorch.org/get-started/locally/).
For detailed napari installation instructions, see [here](https://napari.org/stable/tutorials/fundamentals/installation).

### Installing The Plugin
If you use the conda `env.yml` file, the plugin will be installed automatically. But in case you already have the environment setup,
If you use the provided conda environment yaml files, the plugin will be installed automatically. But in case you already have the environment set up,
you can just install the plugin. First clone the repository:
```bash
git clone https://github.com/juglab/featureforest
@@ -59,17 +62,6 @@ cd ./featureforest
pip install .
```
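
A quick import check such as the following (a sketch added here for illustration, not part of the README) can confirm that the environment resolved correctly, in particular the `sam-2` dependency noted above:

```python
# Sanity-check sketch: verify the plugin and its SAM2-related dependencies import.
import torch
import sam2                                          # installed via conda (see the note above)
from featureforest.models.SAM2 import SAM2Adapter   # adapter added in this commit

print(torch.__version__)  # the environment files pin pytorch=2.3.1
```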

<!-- You can install `featureforest` via [pip]:
pip install featureforest -->




<!-- ## Contributing
Contributions are very welcome. Tests can be run with [tox], please ensure
the coverage at least stays the same before you submit a pull request. -->

## License

@@ -96,3 +88,4 @@ If you encounter any problems, please [file an issue] along with a detailed desc
[pip]: https://pypi.org/project/pip/
[PyPI]: https://pypi.org/
[conda]: https://conda.io/projects/conda/en/latest/index.html
[mamba]: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html
11 changes: 6 additions & 5 deletions env_cpu.yml
@@ -4,7 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- python=3.9
- python=3.10
- pyqt=5.15.10
- qtpy
- magicgui
@@ -16,16 +16,17 @@ dependencies:
- napari-plugin-engine
- napari-svg
- h5py
- pytorch=2.1.2
- torchvision=0.16.2
- pytorch=2.3.1
- torchvision=0.18.1
- sam-2
- pooch
- pip
- pip:
- numpy==1.23.5 # timm drops deprecation warnings with newer versions
- numpy==1.24.4 # timm drops deprecation warnings with newer versions
- matplotlib
- opencv-python
- timm==1.0.9
- pynrrd
- iopath>=0.1.10
- git+https://github.com/facebookresearch/segment-anything.git
- segment-anything-hq
- git+https://github.com/juglab/featureforest.git
11 changes: 6 additions & 5 deletions env_gpu.yml
@@ -5,7 +5,7 @@ channels:
- conda-forge
- defaults
dependencies:
- python=3.9
- python=3.10
- pyqt=5.15.10
- qtpy
- magicgui
@@ -18,16 +18,17 @@ dependencies:
- napari-svg
- h5py
- pytorch-cuda=11.8
- pytorch=2.1.2
- torchvision=0.16.2
- pytorch=2.3.1
- torchvision=0.18.1
- sam-2
- pooch
- pip
- pip:
- numpy==1.23.5 # timm drops deprecation warnings with newer versions
- numpy==1.24.4 # timm drops deprecation warnings with newer versions
- matplotlib
- opencv-python
- timm==1.0.9
- pynrrd
- iopath>=0.1.10
- git+https://github.com/facebookresearch/segment-anything.git
- segment-anything-hq
- git+https://github.com/juglab/featureforest.git
22 changes: 12 additions & 10 deletions pyproject.toml
@@ -17,13 +17,16 @@ allow-direct-references = true
only-include = ["src"]
sources = ["src"]

[tool.hatch.envs.default.env-vars]
SAM2_BUILD_CUDA="0"

# https://peps.python.org/pep-0621/
[project]
name = "featureforest"
dynamic = ["version"]
description = "A napari plugin for segmentation using vision transformer models' features"
description = "A napari plugin for segmentation using vision transformer features"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = { text = "BSD-3-Clause" }
authors = [
{ name = "Mehdi Seifi", email = "[email protected]" },
@@ -33,18 +36,14 @@ classifiers = [
"Development Status :: 3 - Alpha",
"Framework :: napari",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: BSD License",
"Topic :: Scientific/Engineering :: Image Processing",
]
dependencies = [
"numpy==1.23.5",
"torch==2.1.2",
"torchvision==0.16.2",
"numpy==1.24.4",
"opencv-python",
"scikit-learn",
"scikit-image",
@@ -56,9 +55,12 @@ dependencies = [
"napari",
"h5py",
"pooch",
"iopath>=0.1.10",
"torch==2.3.1",
"torchvision==0.18.1",
"timm==1.0.9",
"segment-anything-py",
"segment-anything-hq",
# "sam-2 @ git+https://github.com/facebookresearch/sam2.git"
]
[project.optional-dependencies]
# development dependencies and tooling
@@ -86,12 +88,12 @@ markers = [

[tool.black]
line-length = 90
target-version = ['py38', 'py39', 'py310']
target-version = ["py310", "py311", "py312"]


[tool.ruff]
line-length = 90
target-version = "py38"
target-version = "py310"
src = ["src"]
select = [
"E", "F", "W", #flake8
2 changes: 0 additions & 2 deletions src/featureforest/models/SAM/model.py
@@ -1,5 +1,3 @@
from typing import Tuple

import torch

from segment_anything.modeling import Sam
9 changes: 9 additions & 0 deletions src/featureforest/models/SAM2/__init__.py
@@ -0,0 +1,9 @@
from .model import get_large_model, get_base_model
from .adapter import SAM2Adapter


__all__ = [
"get_large_model",
"get_base_model",
"SAM2Adapter",
]
79 changes: 79 additions & 0 deletions src/featureforest/models/SAM2/adapter.py
@@ -0,0 +1,79 @@
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.transforms import v2 as tv_transforms2

from featureforest.models.base import BaseModelAdapter
from featureforest.utils.data import (
    get_patch_size,
    get_nonoverlapped_patches,
)


class SAM2Adapter(BaseModelAdapter):
    """SAM2 model adapter
    """
    def __init__(
        self,
        image_encoder: nn.Module,
        img_height: float,
        img_width: float,
        device: torch.device,
        name: str = "SAM2_Large"
    ) -> None:
        super().__init__(image_encoder, img_height, img_width, device)
        # for different flavors of SAM2 only the name is different.
        self.name = name
        # we need sam2 image encoder part
        self.encoder = image_encoder
        self.encoder_num_channels = 256
        self._set_patch_size()
        self.device = device

        # input transform for sam
        self.sam_input_dim = 1024
        self.input_transforms = tv_transforms2.Compose([
            tv_transforms2.Resize(
                (self.sam_input_dim, self.sam_input_dim),
                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
                antialias=True
            ),
        ])
        # to transform feature patches back to the original patch size
        self.embedding_transform = tv_transforms2.Compose([
            tv_transforms2.Resize(
                (self.patch_size, self.patch_size),
                interpolation=tv_transforms2.InterpolationMode.BICUBIC,
                antialias=True
            ),
        ])

    def _set_patch_size(self) -> None:
        self.patch_size = get_patch_size(self.img_height, self.img_width)
        self.overlap = self.patch_size // 2

    def get_features_patches(
        self, in_patches: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # get the image encoder outputs
        with torch.no_grad():
            output = self.encoder(
                self.input_transforms(in_patches)
            )
        # backbone_fpn contains 3 levels of features from high to low resolution.
        # [b, 256, 256, 256]
        # [b, 256, 128, 128]
        # [b, 256, 64, 64]
        features = [
            self.embedding_transform(feat.cpu())
            for feat in output["backbone_fpn"]
        ]
        features = torch.cat(features, dim=1)
        out_feature_patches = get_nonoverlapped_patches(
            features, self.patch_size, self.overlap
        )

        return out_feature_patches

    def get_total_output_channels(self) -> int:
        return 256 * 3
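
To make the channel bookkeeping concrete, the following standalone sketch (with assumed feature-map shapes and a hypothetical patch size, not repository code) mirrors what `get_features_patches` does with the three `backbone_fpn` levels: each 256-channel map is resized to the patch size and the levels are concatenated, giving the 3 x 256 = 768 channels reported by `get_total_output_channels`.

```python
# Standalone illustration (assumed shapes) of the FPN concatenation described above.
import torch
import torch.nn.functional as F

patch_size = 128  # hypothetical patch size
fpn_levels = [
    torch.rand(1, 256, 256, 256),  # high resolution
    torch.rand(1, 256, 128, 128),
    torch.rand(1, 256, 64, 64),    # low resolution
]
resized = [
    F.interpolate(f, size=(patch_size, patch_size), mode="bicubic", antialias=True)
    for f in fpn_levels
]
features = torch.cat(resized, dim=1)
print(features.shape)  # torch.Size([1, 768, 128, 128])
```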
75 changes: 75 additions & 0 deletions src/featureforest/models/SAM2/model.py
@@ -0,0 +1,75 @@
from pathlib import Path

import torch

from sam2.modeling.sam2_base import SAM2Base
from sam2.build_sam import build_sam2

from featureforest.utils.downloader import download_model
from featureforest.models.SAM2.adapter import SAM2Adapter


def get_large_model(
    img_height: float, img_width: float, *args, **kwargs
) -> SAM2Adapter:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"running on {device}")
    # download model's weights
    model_url = \
        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
    model_file = download_model(
        model_url=model_url,
        model_name="sam2.1_hiera_large.pt"
    )
    if model_file is None:
        raise ValueError(f"Could not download the model from {model_url}.")

    # init the model
    model: SAM2Base = build_sam2(
        config_file="configs/sam2.1/sam2.1_hiera_l.yaml",
        ckpt_path=model_file,
        device="cpu"
    )
    # to save some GPU memory, only put the encoder part on GPU
    sam_image_encoder = model.image_encoder.to(device)
    sam_image_encoder.eval()

    # create the model adapter
    sam2_model_adapter = SAM2Adapter(
        sam_image_encoder, img_height, img_width, device, "SAM2_Large"
    )

    return sam2_model_adapter


def get_base_model(
    img_height: float, img_width: float, *args, **kwargs
) -> SAM2Adapter:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"running on {device}")
    # download model's weights
    model_url = \
        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt"
    model_file = download_model(
        model_url=model_url,
        model_name="sam2.1_hiera_base_plus.pt"
    )
    if model_file is None:
        raise ValueError(f"Could not download the model from {model_url}.")

    # init the model
    model: SAM2Base = build_sam2(
        config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
        ckpt_path=model_file,
        device="cpu"
    )
    # to save some GPU memory, only put the encoder part on GPU
    sam_image_encoder = model.image_encoder.to(device)
    sam_image_encoder.eval()

    # create the model adapter
    sam2_model_adapter = SAM2Adapter(
        sam_image_encoder, img_height, img_width, device, "SAM2_Base"
    )

    return sam2_model_adapter
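
As a usage sketch (not part of the commit; it assumes a working internet connection for the checkpoint download, and the image size and batch values are placeholders), the factory above can be called directly and the resulting adapter used to extract features:

```python
# Hedged usage sketch: build the SAM2 Large adapter and run a dummy batch through it.
import torch
from featureforest.models.SAM2 import get_large_model

adapter = get_large_model(img_height=1024, img_width=1024)
print(adapter.name)                         # "SAM2_Large"
print(adapter.get_total_output_channels())  # 768 (3 FPN levels x 256 channels)

# two dummy RGB patches; real patches come from the plugin's tiling utilities
patches = torch.rand(2, 3, adapter.patch_size, adapter.patch_size)
feature_patches = adapter.get_features_patches(patches)
```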
4 changes: 4 additions & 0 deletions src/featureforest/models/__init__.py
@@ -6,9 +6,13 @@
from .MobileSAM import get_model as get_mobile_sam_model
from .SAM import get_model as get_sam_model
from .DinoV2 import get_model as get_dino_v2_model
from .SAM2 import get_large_model as get_sam2_large_model
from .SAM2 import get_base_model as get_sam2_base_model


_MODELS_DICT = {
"SAM2_Large": get_sam2_large_model,
"SAM2_Base": get_sam2_base_model,
"MobileSAM": get_mobile_sam_model,
"SAM": get_sam_model,
"DinoV2": get_dino_v2_model,
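
A minimal sketch of how the registry above can be used to resolve an adapter by name (the image size values are placeholders, and `_MODELS_DICT` is the private dictionary shown in this diff):

```python
# Sketch: look up one of the SAM2 factories registered above.
from featureforest.models import _MODELS_DICT

get_adapter = _MODELS_DICT["SAM2_Base"]               # factory added by this commit
adapter = get_adapter(img_height=512, img_width=512)  # signature from SAM2/model.py
print(adapter.name)                                   # "SAM2_Base"
```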
