From a8707a307c3656569dc6ee22a812946bda0022bc Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sat, 9 Nov 2024 18:20:24 +0100 Subject: [PATCH] Added SAM2 model and post-processing (#19) * added SAM2 * updates SAM2 model adapter to use all hierarchical features; added SAM2 Large & Base * updates post-processing using SAM2; updated setup * updated toml file to include sam2 dep from git * updated setup & envs yaml * fixed toml sam-2 dependency; clean-up after post-processing * commented sam2 dependency in pyproject.toml due to sam2 installation error via pip * Updated README * updated numpy to 1.24.4 --- README.md | 33 +++---- env_cpu.yml | 11 ++- env_gpu.yml | 11 ++- pyproject.toml | 22 +++-- src/featureforest/models/SAM/model.py | 2 - src/featureforest/models/SAM2/__init__.py | 9 ++ src/featureforest/models/SAM2/adapter.py | 79 ++++++++++++++++ src/featureforest/models/SAM2/model.py | 75 +++++++++++++++ src/featureforest/models/__init__.py | 4 + .../postprocess/postprocess_with_sam.py | 93 +++++++++---------- 10 files changed, 249 insertions(+), 90 deletions(-) create mode 100644 src/featureforest/models/SAM2/__init__.py create mode 100644 src/featureforest/models/SAM2/adapter.py create mode 100644 src/featureforest/models/SAM2/model.py diff --git a/README.md b/README.md index 46baa62..a003e41 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ [![codecov](https://codecov.io/gh/juglab/featureforest/branch/main/graph/badge.svg)](https://codecov.io/gh/juglab/featureforest) [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/featureforest)](https://napari-hub.org/plugins/featureforest) -A napari plugin for segmentation using vision transformers' features. -We developed a *napari* plugin to train a *Random Forest* model using extracted embeddings of ViT models for input and just a few scribble labels provided by the user. This approach can do the segmentation of desired objects almost as well as manual segmentations but in a much shorter time with less manual effort. +**A napari plugin for making image annotation using feature space of vision transformers and random forest classifier.** +We developed a *napari* plugin to train a *Random Forest* model using extracted features of vision foundation models and just a few scribble labels provided by the user as input. This approach can do the segmentation of desired objects almost as well as manual segmentations but in a much shorter time with less manual effort. ---------------------------------- @@ -16,8 +16,7 @@ We developed a *napari* plugin to train a *Random Forest* model using extracted The plugin documentation is [here](docs/index.md). ## Installation -It is highly recommended to use a python environment manager like [conda] to create a clean environment for installation. -You can install all the requirements using provided environment config files: +To install this plugin you need to use [conda] or [mamba] to create a environment and install the requirements. Use the commands below to create the environment and install the plugin: ```bash # for GPU conda env create -f ./env_gpu.yml @@ -27,9 +26,11 @@ conda env create -f ./env_gpu.yml conda env create -f ./env_cpu.yml ``` +#### Note: You need to install `sam-2` which can be install easily using conda. To install `sam-2` using `pip` please refer to the official [sam-2](https://github.com/facebookresearch/sam2) repository. + ### Requirements -- `python >= 3.9` -- `numpy` +- `python >= 3.10` +- `numpy==1.24.4` - `opencv-python` - `scikit-learn` - `scikit-image` @@ -39,16 +40,18 @@ conda env create -f ./env_cpu.yml - `qtpy` - `napari` - `h5py` -- `pytorch=2.1.2` -- `torchvision=0.16.2` +- `pytorch=2.3.1` +- `torchvision=0.18.1` - `timm=1.0.9` - `pynrrd` +- `segment-anything` +- `sam-2` If you want to install the plugin manually using GPU, please follow the pytorch installation instruction [here](https://pytorch.org/get-started/locally/). For detailed napari installation see [here](https://napari.org/stable/tutorials/fundamentals/installation). ### Installing The Plugin -If you use the conda `env.yml` file, the plugin will be installed automatically. But in case you already have the environment setup, +If you use the provided conda environment yaml files, the plugin will be installed automatically. But in case you already have the environment setup, you can just install the plugin. First clone the repository: ```bash git clone https://github.com/juglab/featureforest @@ -59,17 +62,6 @@ cd ./featureforest pip install . ``` - - - - - - ## License @@ -96,3 +88,4 @@ If you encounter any problems, please [file an issue] along with a detailed desc [pip]: https://pypi.org/project/pip/ [PyPI]: https://pypi.org/ [conda]: https://conda.io/projects/conda/en/latest/index.html +[mamba]: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html \ No newline at end of file diff --git a/env_cpu.yml b/env_cpu.yml index 48845e4..08bb62c 100644 --- a/env_cpu.yml +++ b/env_cpu.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pyqt=5.15.10 - qtpy - magicgui @@ -16,16 +16,17 @@ dependencies: - napari-plugin-engine - napari-svg - h5py - - pytorch=2.1.2 - - torchvision=0.16.2 + - pytorch=2.3.1 + - torchvision=0.18.1 + - sam-2 - pooch - pip - pip: - - numpy==1.23.5 # timm drops deprecation warnings with newer versions + - numpy==1.24.4 # timm drops deprecation warnings with newer versions - matplotlib - opencv-python - timm==1.0.9 - pynrrd + - iopath>=0.1.10 - git+https://github.com/facebookresearch/segment-anything.git - - segment-anything-hq - git+https://github.com/juglab/featureforest.git diff --git a/env_gpu.yml b/env_gpu.yml index f61e67e..1071596 100644 --- a/env_gpu.yml +++ b/env_gpu.yml @@ -5,7 +5,7 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pyqt=5.15.10 - qtpy - magicgui @@ -18,16 +18,17 @@ dependencies: - napari-svg - h5py - pytorch-cuda=11.8 - - pytorch=2.1.2 - - torchvision=0.16.2 + - pytorch=2.3.1 + - torchvision=0.18.1 + - sam-2 - pooch - pip - pip: - - numpy==1.23.5 # timm drops deprecation warnings with newer versions + - numpy==1.24.4 # timm drops deprecation warnings with newer versions - matplotlib - opencv-python - timm==1.0.9 - pynrrd + - iopath>=0.1.10 - git+https://github.com/facebookresearch/segment-anything.git - - segment-anything-hq - git+https://github.com/juglab/featureforest.git diff --git a/pyproject.toml b/pyproject.toml index 97782b2..3684007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,13 +17,16 @@ allow-direct-references = true only-include = ["src"] sources = ["src"] +[tool.hatch.envs.default.env-vars] +SAM2_BUILD_CUDA="0" + # https://peps.python.org/pep-0621/ [project] name = "featureforest" dynamic = ["version"] -description = "A napari plugin for segmentation using vision transformer models' features" +description = "A napari plugin for segmentation using vision transformer features" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" license = { text = "BSD-3-Clause" } authors = [ { name = "Mehdi Seifi", email = "mehdi.seifi@fht.org" }, @@ -33,8 +36,6 @@ classifiers = [ "Development Status :: 3 - Alpha", "Framework :: napari", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -42,9 +43,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Image Processing", ] dependencies = [ - "numpy==1.23.5", - "torch==2.1.2", - "torchvision==0.16.2", + "numpy==1.24.4", "opencv-python", "scikit-learn", "scikit-image", @@ -56,9 +55,12 @@ dependencies = [ "napari", "h5py", "pooch", + "iopath>=0.1.10", + "torch==2.3.1", + "torchvision==0.18.1", "timm==1.0.9", "segment-anything-py", - "segment-anything-hq", + # "sam-2 @ git+https://github.com/facebookresearch/sam2.git" ] [project.optional-dependencies] # development dependencies and tooling @@ -86,12 +88,12 @@ markers = [ [tool.black] line-length = 90 -target-version = ['py38', 'py39', 'py310'] +target-version = ["py310", "py311", "py312"] [tool.ruff] line-length = 90 -target-version = "py38" +target-version = "py310" src = ["src"] select = [ "E", "F", "W", #flake8 diff --git a/src/featureforest/models/SAM/model.py b/src/featureforest/models/SAM/model.py index bc02edd..b0843b7 100644 --- a/src/featureforest/models/SAM/model.py +++ b/src/featureforest/models/SAM/model.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from segment_anything.modeling import Sam diff --git a/src/featureforest/models/SAM2/__init__.py b/src/featureforest/models/SAM2/__init__.py new file mode 100644 index 0000000..200635e --- /dev/null +++ b/src/featureforest/models/SAM2/__init__.py @@ -0,0 +1,9 @@ +from .model import get_large_model, get_base_model +from .adapter import SAM2Adapter + + +__all__ = [ + "get_large_model", + "get_base_model", + "SAM2Adapter", +] diff --git a/src/featureforest/models/SAM2/adapter.py b/src/featureforest/models/SAM2/adapter.py new file mode 100644 index 0000000..ed14e00 --- /dev/null +++ b/src/featureforest/models/SAM2/adapter.py @@ -0,0 +1,79 @@ +from typing import Tuple + +import torch +import torch.nn as nn +from torch import Tensor +from torchvision.transforms import v2 as tv_transforms2 + +from featureforest.models.base import BaseModelAdapter +from featureforest.utils.data import ( + get_patch_size, + get_nonoverlapped_patches, +) + + +class SAM2Adapter(BaseModelAdapter): + """SAM2 model adapter + """ + def __init__( + self, + image_encoder: nn.Module, + img_height: float, + img_width: float, + device: torch.device, + name: str = "SAM2_Large" + ) -> None: + super().__init__(image_encoder, img_height, img_width, device) + # for different flavors of SAM2 only the name is different. + self.name = name + # we need sam2 image encoder part + self.encoder = image_encoder + self.encoder_num_channels = 256 + self._set_patch_size() + self.device = device + + # input transform for sam + self.sam_input_dim = 1024 + self.input_transforms = tv_transforms2.Compose([ + tv_transforms2.Resize( + (self.sam_input_dim, self.sam_input_dim), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True + ), + ]) + # to transform feature patches back to the original patch size + self.embedding_transform = tv_transforms2.Compose([ + tv_transforms2.Resize( + (self.patch_size, self.patch_size), + interpolation=tv_transforms2.InterpolationMode.BICUBIC, + antialias=True + ), + ]) + + def _set_patch_size(self) -> None: + self.patch_size = get_patch_size(self.img_height, self.img_width) + self.overlap = self.patch_size // 2 + + def get_features_patches( + self, in_patches: Tensor + ) -> Tuple[Tensor, Tensor]: + # get the image encoder outputs + with torch.no_grad(): + output = self.encoder( + self.input_transforms(in_patches) + ) + # backbone_fpn contains 3 levels of features from hight to low resolution. + # [b, 256, 256, 256] + # [b, 256, 128, 128] + # [b, 256, 64, 64] + features = [ + self.embedding_transform(feat.cpu()) + for feat in output["backbone_fpn"] + ] + features = torch.cat(features, dim=1) + out_feature_patches = get_nonoverlapped_patches(features, self.patch_size, self.overlap) + + return out_feature_patches + + def get_total_output_channels(self) -> int: + return 256 * 3 diff --git a/src/featureforest/models/SAM2/model.py b/src/featureforest/models/SAM2/model.py new file mode 100644 index 0000000..679b6c2 --- /dev/null +++ b/src/featureforest/models/SAM2/model.py @@ -0,0 +1,75 @@ +from pathlib import Path + +import torch + +from sam2.modeling.sam2_base import SAM2Base +from sam2.build_sam import build_sam2 + +from featureforest.utils.downloader import download_model +from featureforest.models.SAM2.adapter import SAM2Adapter + + +def get_large_model( + img_height: float, img_width: float, *args, **kwargs +) -> SAM2Adapter: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"running on {device}") + # download model's weights + model_url = \ + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt" + model_file = download_model( + model_url=model_url, + model_name="sam2.1_hiera_large.pt" + ) + if model_file is None: + raise ValueError(f"Could not download the model from {model_url}.") + + # init the model + model: SAM2Base = build_sam2( + config_file= "configs/sam2.1/sam2.1_hiera_l.yaml", + ckpt_path=model_file, + device="cpu" + ) + # to save some GPU memory, only put the encoder part on GPU + sam_image_encoder = model.image_encoder.to(device) + sam_image_encoder.eval() + + # create the model adapter + sam2_model_adapter = SAM2Adapter( + sam_image_encoder, img_height, img_width, device, "SAM2_Large" + ) + + return sam2_model_adapter + + +def get_base_model( + img_height: float, img_width: float, *args, **kwargs +) -> SAM2Adapter: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"running on {device}") + # download model's weights + model_url = \ + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt" + model_file = download_model( + model_url=model_url, + model_name="sam2.1_hiera_base_plus.pt" + ) + if model_file is None: + raise ValueError(f"Could not download the model from {model_url}.") + + # init the model + model: SAM2Base = build_sam2( + config_file= "configs/sam2.1/sam2.1_hiera_b+.yaml", + ckpt_path=model_file, + device="cpu" + ) + # to save some GPU memory, only put the encoder part on GPU + sam_image_encoder = model.image_encoder.to(device) + sam_image_encoder.eval() + + # create the model adapter + sam2_model_adapter = SAM2Adapter( + sam_image_encoder, img_height, img_width, device, "SAM2_Base" + ) + + return sam2_model_adapter diff --git a/src/featureforest/models/__init__.py b/src/featureforest/models/__init__.py index 06f0c4b..50c2541 100644 --- a/src/featureforest/models/__init__.py +++ b/src/featureforest/models/__init__.py @@ -6,9 +6,13 @@ from .MobileSAM import get_model as get_mobile_sam_model from .SAM import get_model as get_sam_model from .DinoV2 import get_model as get_dino_v2_model +from .SAM2 import get_large_model as get_sam2_large_model +from .SAM2 import get_base_model as get_sam2_base_model _MODELS_DICT = { + "SAM2_Large": get_sam2_large_model, + "SAM2_Base": get_sam2_base_model, "MobileSAM": get_mobile_sam_model, "SAM": get_sam_model, "DinoV2": get_dino_v2_model, diff --git a/src/featureforest/postprocess/postprocess_with_sam.py b/src/featureforest/postprocess/postprocess_with_sam.py index c9ba6d7..f290195 100644 --- a/src/featureforest/postprocess/postprocess_with_sam.py +++ b/src/featureforest/postprocess/postprocess_with_sam.py @@ -8,44 +8,37 @@ import torch from cv2.typing import Rect -from segment_anything_hq.modeling import Sam -from segment_anything_hq.build_sam import build_sam_vit_t -from segment_anything_hq import SamPredictor, SamAutomaticMaskGenerator +from sam2.modeling.sam2_base import SAM2Base +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from featureforest.utils.downloader import download_model from featureforest.utils.data import is_image_rgb, image_to_uint8 from .postprocess import postprocess_label_mask -def get_light_hq_sam() -> Sam: - """Load the Light HQ SAM model instance. This model produces better masks. - - Raises: - ValueError: if model's weights could not be downloaded. - - Returns: - Sam: a Light HQ SAM model instance - """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"running on {device}") - # download model's weights +def get_sam2() -> SAM2Base: model_url = \ - "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth" + "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt" model_file = download_model( model_url=model_url, - model_name="sam_hq_vit_tiny.pth" + model_name="sam2.1_hiera_base_plus.pt" ) if model_file is None: raise ValueError(f"Could not download the model from {model_url}.") - # init & load the light hq sam model - lhq_sam = build_sam_vit_t().to(device) - lhq_sam.load_state_dict( - torch.load(model_file, map_location=device) + # init the model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"running on {device}") + sam2_model: SAM2Base = build_sam2( + config_file= "configs/sam2.1/sam2.1_hiera_b+.yaml", + ckpt_path=model_file, + device=device ) - lhq_sam.eval() + sam2_model.eval() - return lhq_sam + return sam2_model def get_watershed_bboxes(image: np.ndarray) -> List[Rect]: @@ -113,12 +106,12 @@ def get_bounding_boxes(image: np.ndarray) -> List[Rect]: def get_sam_mask( - predictor: SamPredictor, image: np.ndarray, bboxes: List[Rect] + predictor: SAM2ImagePredictor, image: np.ndarray, bboxes: List[Rect] ) -> np.ndarray: """Returns a mask aggregated by sam predictor masks for each given bounding box. Args: - predictor (SamPredictor): the sam predictor instance + predictor (SAM2ImagePredictor): the sam predictor instance image (np.ndarray): input binary image bboxes (List[Rect]): bounding boxes @@ -128,33 +121,29 @@ def get_sam_mask( # sam needs an RGB image image = np.repeat(image[:, :, np.newaxis], 3, axis=2) predictor.set_image(image) - # get sam-ready bounding boxes: x,y,w,h - input_boxes = torch.tensor([ + # get sam-ready bounding boxes: x,y,w,h -> x1,y1,x2,y2 + input_boxes = np.array([ (box[0], box[1], box[0] + box[2], box[1] + box[3]) for box in bboxes - ]).to(predictor.device) - transformed_boxes = predictor.transform.apply_boxes_torch( - input_boxes, image.shape[:2] - ) + ]) # get sam predictor masks bs = 16 - num_batches = np.ceil(len(transformed_boxes) / bs).astype(int) + num_batches = np.ceil(len(bboxes) / bs).astype(int) final_mask = np.zeros((image.shape[0], image.shape[1]), dtype=bool) for i in np_progress( range(num_batches), desc="Generating masks using SAM predictor" ): start = i * bs end = start + bs - masks, _, _ = predictor.predict_torch( + masks, _, _ = predictor.predict( point_coords=None, point_labels=None, - boxes=transformed_boxes[start:end], + box=input_boxes[start:end], multimask_output=True, ) - masks_np = masks.squeeze(1).cpu().numpy() final_mask = np.bitwise_or( final_mask, - np.bitwise_or.reduce(masks_np, axis=0) + np.bitwise_or.reduce(masks.astype(bool), axis=(0, 1)) ) return final_mask @@ -180,8 +169,8 @@ def postprocess_with_sam( Returns: np.ndarray: post-processed segmentation image """ - # init a sam predictor using light hq sam - predictor = SamPredictor(get_light_hq_sam()) + # init a sam predictor using SAM2 Base Plus + predictor = SAM2ImagePredictor(get_sam2()) final_mask = np.zeros_like(segmentation_image, dtype=np.uint8) # postprocessing gets done for each label's mask separately. @@ -204,6 +193,10 @@ def postprocess_with_sam( # put the final label mask into final result mask final_mask[sam_label_mask] = label + # clean-up + del predictor + torch.cuda.empty_cache() + return final_mask @@ -225,24 +218,28 @@ def get_sam_auto_masks(input_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray] # normalize the image in [0, 255] as uint8 image = image_to_uint8(input_image.copy()) # init a sam auto-segmentation mask generator - mask_generator = SamAutomaticMaskGenerator( - model=get_light_hq_sam(), - points_per_side=64, + mask_generator = SAM2AutomaticMaskGenerator( + model=get_sam2(), + points_per_side=50, pred_iou_thresh=0.8, - stability_score_thresh=0.85, + stability_score_thresh=0.88, stability_score_offset=0.9, crop_n_layers=1, - crop_n_points_downscale_factor=2, - # crop_nms_thresh=0.7, - min_mask_region_area=20 + crop_n_points_downscale_factor=8, + min_mask_region_area=20, + use_m2m=True ) - # generate SAM masks - print("generating masks using SamAutomaticMaskGenerator...") - with np_progress(range(1), desc="Generating masks using SamAutomaticMaskGenerator"): + # generate SAM2 masks + print("generating masks using SAM2AutomaticMaskGenerator...") + with np_progress(range(1), desc="Generating masks using SAM2AutomaticMaskGenerator"): sam_generated_masks = mask_generator.generate(image) sam_masks = np.array([mask["segmentation"] for mask in sam_generated_masks]) sam_areas = np.array([mask["area"] for mask in sam_generated_masks]) + # clean-up + del mask_generator + torch.cuda.empty_cache() + return sam_masks, sam_areas