This repo contains a simplified implementation of the awesome 'Segment Anything' models from facebookresearch/segment-anything-2 (and SAMV1), with the intention of removing the magic from the original code base to make it easier to understand. Most of the changes come from separating/simplifying the different components of the model structure.
While the focus of this implementation is on interactivity and readability of the model code, it includes support for arbitrary input resolutions, which can improve performance in some cases. For example, at reduced resolutions, SAMv2 gets a ~4x speed up on video segmentation.
Note
This repo is a (messy) work-in-progress! The end goal is to have something resembling MuggledDPT.
This repo includes two demo scripts, run_image.py and run_video.py (along with a number of simple examples and experiments). To use these scripts, you'll first need to have Python (v3.10+) installed, then set up a virtual environment and install some additional requirements.
First create and activate a virtual environment (do this inside the repo folder after cloning/downloading it):
# For linux or mac:
python3 -m venv .env
source .env/bin/activate
# For windows (cmd):
python -m venv .env
.env\Scripts\activate.bat
Then install the requirements (or you could install them manually from the requirements.txt file):
pip install -r requirements.txt
Additional info for GPU usage
If you're using Windows and want to use an Nvidia GPU or if you're on Linux and don't have a GPU, you'll need to use a slightly different install command to make use of your hardware setup. You can use the Pytorch installer guide to figure out the command to use. For example, for GPU use on Windows it may look something like:
pip3 uninstall torch # <-- Do this first if you already installed from the requirements.txt file
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
Note: With the Windows install as-is, you may get an error about a missing c10.dll
dependency. Downloading and installing this mysterious .exe file seems to fix the problem.
Before you can run a model, you'll need to download it's weights. There are 3 officially supported SAMv1 models (vit-base, vit-large and vit-huge) and four v2 models (tiny, small, base-plus and large). This repo uses the exact same weights as the original implementations (or any fine-tuned variant of the original models), which can be downloaded from the Download Checkpoints section of SAMv2 repo and the Model Checkpoints section of the SAMv1 repo.
After downloading a model file, you can place it in the model_weights
folder of this repo or otherwise just keep note of the file path, since you'll need to provide this when running the demo scripts. If you do place the file in the model_weights folder, then it will auto-load when running the scripts.
Direct download links
The tables below include direct download links to all of the supported models. Note: These are all links to the original repos, none of these files belong to MuggledSAM!
SAMv2 Models | Size (MB) |
---|---|
sam2_hiera_tiny | 160 |
sam2_hiera_small | 185 |
sam2_hiera_base_plus | 325 |
sam2_hiera_large | 900 |
SAMv1 Models | Size (MB) |
---|---|
sam-vit-base | 375 |
sam-vit-large | 1250 |
sam-vit-huge | 2560 |
Here is an example of using the model to generate masks from an image:
import cv2
from lib.make_sam import make_sam_from_state_dict
# Define prompts using 0-to-1 xy coordinates
box_tlbrs = [] # Example [((0.25, 0.25), (0.75, 0.75))]
fg_xys = [(0.5, 0.5)]
bg_xys = []
# Load image & model
image_bgr = cv2.imread("/path/to/image.jpg")
_, model = make_sam_from_state_dict("/path/to/model.pth")
# Process data
encoded_img, _, _ = model.encode_image(image_bgr)
encoded_prompts = model.encode_prompts(box_tlbrs, fg_xys, bg_xys)
mask_preds, iou_preds = model.generate_masks(encoded_img, encoded_prompts)
The run_image.py
script will run the segment-anything model on a single image with an interactive UI running locally. To use the script, make sure you've activated the virtual environment (from the installation step) and then, from the repo folder use:
python run_image.py
You can also add --help
to the end of this command to see a list of additional flags you can set when running this script. For example, two interesting options are the --crop
flag to interactively crop an image prior to processing and the -b
flag, which can change the processing resolution of the model.
If you don't provide an image path (using the -i
flag), then you will be asked to provide one when you run the script, likewise for a path to the model weights. Afterwards, a window will pop-up, with options for how to 'prompt' the model (e.g. bounding boxes or clicking to add points) along the top and various sliders to alter the segmentation results at the bottom. Results can be saved by pressing the s
key.
The run_video.py
script allows for segmentation of videos using an interactive UI running locally. However, it only works with SAMv2 models!
To use the script, make sure you've activated the virtual environment (from the installation step) and then, from the repo folder use:
python run_video.py
As with the image script, you can add --help
to the end of this command to see a list of additional flags. For example, you can add the flag --use_webcam
to run segmentation on a live webcam feed. Using -b 512
to reduce the processing resolution can provide a significant speed up if needed (box prompting works better at reduced resolutions btw!).
This script is a messy work-in-progress for now, more features & stability updates to come! If you'd like a more hackable solution, check out the (much easier to follow) video segmentation example.
The code in this repo is entirely based off the original segment-anything github repos:
facebookresearch/segment-anything
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
facebookresearch/segment-anything-2
@article{ravi2024sam2,
title={SAM 2: Segment Anything in Images and Videos},
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
journal={arXiv preprint},
year={2024}
}
- Clean up code base (especially the image encoder, which is unfinished)
- Add interactive script replicating the original 'automatic mask geneartor'
- Add model structure documentation
- Add various experiment scripts (onnx export, mask prompts, attention vis etc.)
- Inevitable bugfixes