Add FaCT Finetuning for SAM #682

Merged
18 commits, merged on Sep 17, 2024
20 changes: 10 additions & 10 deletions finetuning/evaluation/evaluate_amg.py
@@ -3,11 +3,13 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_amg

-from util import get_paths  # comment this and create a custom function with the same name to run amg on your data
-from util import get_pred_paths, get_default_arguments, VANILLA_MODELS
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run amg on your data
+    get_pred_paths, get_default_arguments
+)


-def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank):
+def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_amg(
@@ -17,7 +19,7 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, l
         val_image_paths=val_image_paths,
         val_gt_paths=val_gt_paths,
         test_image_paths=test_image_paths,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -33,12 +35,10 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-    if args.checkpoint is None:
-        ckpt = VANILLA_MODELS[args.model]
-    else:
-        ckpt = args.checkpoint
-
-    prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, args.lora_rank)
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
+    prediction_folder = run_amg_inference(
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs
+    )
     eval_amg(args.dataset, prediction_folder, args.experiment_folder)

14 changes: 8 additions & 6 deletions finetuning/evaluation/evaluate_instance_segmentation.py
@@ -3,12 +3,14 @@
 from micro_sam.evaluation.evaluation import run_evaluation
 from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder

-from util import get_paths  # comment this and create a custom function with the same name to run ais on your data
-from util import get_pred_paths, get_default_arguments
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run ais on your data
+    get_pred_paths, get_default_arguments
+)


 def run_instance_segmentation_with_decoder_inference(
-    dataset_name, model_type, checkpoint, experiment_folder, lora_rank
+    dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs,
 ):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
@@ -19,7 +21,7 @@ def run_instance_segmentation_with_decoder_inference(
         val_image_paths=val_image_paths,
         val_gt_paths=val_gt_paths,
         test_image_paths=test_image_paths,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )
     return prediction_folder

@@ -35,9 +37,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, exp

 def main():
     args = get_default_arguments()
-
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
     prediction_folder = run_instance_segmentation_with_decoder_inference(
-        args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank,
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs,
     )
     eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)

12 changes: 9 additions & 3 deletions finetuning/evaluation/iterative_prompting.py
@@ -1,10 +1,13 @@
 import os

+from micro_sam.util import get_sam_model
 from micro_sam.evaluation import inference
 from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting

-from util import get_paths  # comment this and create a custom function with the same name to run int. seg. on your data
-from util import get_model, get_default_arguments
+from util import (
+    get_paths,  # comment this line out and create a custom function with the same name to run int. seg. on your data
+    get_default_arguments
+)


 def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks):
@@ -42,7 +45,10 @@ def main():
     start_with_box_prompt = args.box  # overwrite to start first iters' prompt with box instead of single point

     # get the predictor to perform inference
-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank)
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
+    predictor = get_sam_model(
+        model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs,
+    )

     prediction_root = _run_iterative_prompting(
         args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks
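Note: the evaluation scripts now call `micro_sam.util.get_sam_model` directly, replacing the `get_model` helper that is removed from `util.py` further down. A sketch of the shared loading pattern, with illustrative values (the checkpoint path and the PEFT settings are placeholders, not taken from this diff):

```python
from micro_sam.util import get_sam_model

# rank/module mirror the new --peft_rank/--peft_module flags; the checkpoint path is hypothetical.
peft_kwargs = {"rank": 4, "module": "FacT"}
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="/path/to/fact_finetuned_sam.pt",
    peft_kwargs=peft_kwargs,
)
```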
12 changes: 9 additions & 3 deletions finetuning/evaluation/precompute_embeddings.py
@@ -1,15 +1,21 @@
 import os

+from micro_sam.util import get_sam_model
 from micro_sam.evaluation import precompute_all_embeddings

-from util import get_paths  # comment this and create a custom function with the same name to execute on your data
-from util import get_model, get_default_arguments
+from util import (
+    get_paths,  # comment this and create a custom function with the same name to execute on your data
+    get_default_arguments
+)


 def main():
     args = get_default_arguments()

-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank)
+    peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
+    predictor = get_sam_model(
+        model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs,
+    )
     embedding_dir = os.path.join(args.experiment_folder, "embeddings")
     os.makedirs(embedding_dir, exist_ok=True)

2 changes: 1 addition & 1 deletion finetuning/evaluation/submit_all_evaluation.py
@@ -221,4 +221,4 @@ def main(args):
     parser.add_argument("-s", "--specific_experiment", type=str, default=None)

     args = parser.parse_args()
-    main(args)
+    main(args)
16 changes: 3 additions & 13 deletions finetuning/evaluation/util.py
@@ -5,7 +5,6 @@

 from torch_em.data import datasets

-from micro_sam.util import get_sam_model
 from micro_sam.evaluation.livecell import _get_livecell_paths


@@ -80,16 +79,6 @@ def get_dataset_paths(dataset_name, split_choice):
     return raw_dir, labels_dir


-def get_model(model_type, ckpt, lora_rank):
-    if ckpt is None:
-        ckpt = VANILLA_MODELS[model_type]
-
-    predictor = get_sam_model(
-        model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank,
-    )
-    return predictor
-
-
 def get_paths(dataset_name, split):
     assert dataset_name in DATASETS, dataset_name

@@ -222,14 +211,15 @@ def get_default_arguments():
     parser.add_argument(
         "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
     )
-    parser.add_argument("-c", "--checkpoint", type=none_or_str, required=True, default=None)
+    parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None)
     parser.add_argument("-e", "--experiment_folder", type=str, required=True)
     parser.add_argument("-d", "--dataset", type=str, default=None)
     parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box")
     parser.add_argument(
         "--use_masks", action="store_true", help="To use logits masks for iterative prompting."
     )
-    parser.add_argument("--lora_rank", default=None, type=int, help="The rank for low rank adaptation method.")
+    parser.add_argument("--peft_rank", default=None, type=int, help="The rank for peft method.")
+    parser.add_argument("--peft_module", default=None, type=str, help="The module for peft method. (e.g. LoRA or FacT)")
     args = parser.parse_args()
     return args

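The shared parser now exposes `--peft_rank` and `--peft_module` in place of `--lora_rank`, and `-c/--checkpoint` is no longer required. A minimal sketch of how the new flags end up in the `peft_kwargs` dict that every evaluation script forwards (only flags visible in this diff; help texts omitted, argv values illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, required=True)
parser.add_argument("-c", "--checkpoint", default=None)
parser.add_argument("-e", "--experiment_folder", type=str, required=True)
parser.add_argument("-d", "--dataset", type=str, default=None)
parser.add_argument("--peft_rank", default=None, type=int)
parser.add_argument("--peft_module", default=None, type=str)

args = parser.parse_args([
    "-m", "vit_b", "-e", "./experiments/fact", "-d", "livecell",
    "--peft_rank", "4", "--peft_module", "FacT",
])
peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module}
print(peft_kwargs)  # {'rank': 4, 'module': 'FacT'}
```

When neither PEFT flag is passed, both values stay `None`; the dict is still built and handed on to `get_sam_model`, which presumably treats a `None` rank as "no PEFT".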
(file name not captured)
@@ -158,4 +158,4 @@ def main():
 if __name__ == "__main__":
     import warnings
     warnings.filterwarnings("ignore")
-    main()
+    main()
8 changes: 4 additions & 4 deletions micro_sam/evaluation/inference.py
@@ -547,12 +547,12 @@ def run_amg(
     test_image_paths: List[Union[str, os.PathLike]],
     iou_thresh_values: Optional[List[float]] = None,
     stability_score_values: Optional[List[float]] = None,
-    lora_rank: Optional[int] = None,
+    peft_kwargs: Optional[Dict] = None,
 ) -> str:
     embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
     os.makedirs(embedding_folder, exist_ok=True)

-    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank)
+    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
     amg = AutomaticMaskGenerator(predictor)
     amg_prefix = "amg"

@@ -589,13 +589,13 @@ def run_instance_segmentation_with_decoder(
     val_image_paths: List[Union[str, os.PathLike]],
     val_gt_paths: List[Union[str, os.PathLike]],
     test_image_paths: List[Union[str, os.PathLike]],
-    lora_rank: Optional[int] = None,
+    peft_kwargs: Optional[Dict] = None,
 ) -> str:
     embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
     os.makedirs(embedding_folder, exist_ok=True)

     predictor, decoder = get_predictor_and_decoder(
-        model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank,
+        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
     )
     segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
     seg_prefix = "instance_segmentation_with_decoder"
4 changes: 2 additions & 2 deletions micro_sam/instance_segmentation.py
@@ -800,7 +800,7 @@ def get_predictor_and_decoder(
     model_type: str,
     checkpoint_path: Union[str, os.PathLike],
     device: Optional[Union[str, torch.device]] = None,
-    lora_rank: Optional[int] = None,
+    peft_kwargs: Optional[Dict] = None,
 ) -> Tuple[SamPredictor, DecoderAdapter]:
     """Load the SAM model (predictor) and instance segmentation decoder.

@@ -823,7 +823,7 @@ def get_predictor_and_decoder(
         checkpoint_path=checkpoint_path,
         device=device,
         return_state=True,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )
     if "decoder_state" not in state:
         raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state")
59 changes: 56 additions & 3 deletions micro_sam/models/peft_sam.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Union
+from typing import List, Union, Optional

 import torch.nn as nn

@@ -56,11 +56,64 @@ def forward(self, x):
         return qkv


+class FacTSurgery(nn.Module):
+    """Operates on the attention layers for performing factorized attention.
+
+    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
+
+    Args:
+        rank: The rank of the decomposition matrices for updating weights in each attention layer.
+        block: The chosen attention blocks for implementing fact.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        block: nn.Module,
+        dropout: Optional[float] = None,
+    ):
+        super().__init__()
+        self.qkv_proj = block.attn.qkv
+        self.dim = self.qkv_proj.in_features
+
+        self.q_FacTs = nn.Linear(rank, rank, bias=False)
+        self.v_FacTs = nn.Linear(rank, rank, bias=False)
+
+        self.dropout = dropout
+        if self.dropout is not None:
+            # NOTE : Dropout is not included in the original implementation
+            self.dp_q = nn.Dropout(self.dropout)
+            self.dp_v = nn.Dropout(self.dropout)
+
+        self.FacTu = nn.Linear(self.dim, rank, bias=False)
+        self.FacTv = nn.Linear(rank, self.dim, bias=False)
+
+        block.attn.qkv = self
+
+    def forward(self, x):
+        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
+
+        new_q = self.q_FacTs(self.FacTu(x))
+        new_v = self.v_FacTs(self.FacTu(x))
+
+        if self.dropout is not None:
+            new_q = self.dp_q(new_q)
+            new_v = self.dp_v(new_v)
+
+        new_q = self.FacTv(new_q)
+        new_v = self.FacTv(new_v)
+
+        # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate
+        # Does it make sense to include it, in order to have similar learning rate as the original model?
+        qkv[:, :, :, : self.dim] += new_q
+        qkv[:, :, :, -self.dim:] += new_v
+
+        return qkv
+
+
 class PEFT_Sam(nn.Module):
     """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

-    Inspired by https://github.com/JamesQFreeman/Sam_LoRA/
-
     Args:
         model: The Segment Anything model.
         rank: The rank for low-rank adaptation.
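The `FacTSurgery` class above is the core of the PR: instead of LoRA's two independent rank-r adapters per projection, FacT uses a shared down-projection (`FacTu`: dim -> rank) and up-projection (`FacTv`: rank -> dim) plus one small rank x rank factor per projection (`q_FacTs`, `v_FacTs`), and adds the result to the q and v slices of the fused qkv output. A minimal, self-contained sketch of that update on a toy qkv projection (toy shapes, no SAM dependency):

```python
import torch
import torch.nn as nn

dim, rank = 768, 4                           # ViT-B embedding dim; the rank is illustrative
qkv_proj = nn.Linear(dim, 3 * dim)           # stands in for block.attn.qkv

FacTu = nn.Linear(dim, rank, bias=False)     # shared down-projection
FacTv = nn.Linear(rank, dim, bias=False)     # shared up-projection
q_FacTs = nn.Linear(rank, rank, bias=False)  # rank x rank factor for the q update
v_FacTs = nn.Linear(rank, rank, bias=False)  # rank x rank factor for the v update

x = torch.randn(2, 14, 14, dim)              # B, H, W, C token grid as in the ViT blocks
qkv = qkv_proj(x)                            # B, H, W, 3 * C

new_q = FacTv(q_FacTs(FacTu(x)))             # low-rank additive update for q
new_v = FacTv(v_FacTs(FacTu(x)))             # low-rank additive update for v

qkv[..., :dim] += new_q                      # the first C channels of qkv hold q
qkv[..., -dim:] += new_v                     # the last C channels hold v
```

Per wrapped block this trains 2 * dim * rank + 2 * rank^2 parameters (FacTu, FacTv and the two rank x rank factors), roughly half of the 4 * dim * rank that a LoRA update of q and v needs at the same rank, because FacTu/FacTv are shared between q and v. As written here, each `FacTSurgery` instance creates its own `FacTu`/`FacTv`, so the factors are shared within a block but not across blocks (the FacT paper additionally shares them across layers).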
10 changes: 7 additions & 3 deletions micro_sam/models/sam_3d_wrapper.py
@@ -18,8 +18,10 @@ def get_sam_3d_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
-    # Make sure not to freeze the encoder when using LoRA.
-    freeze_encoder_ = freeze_encoder if lora_rank is None else False
+    peft_kwargs = {}
+    if lora_rank is not None:
+        peft_kwargs["rank"] = lora_rank
+
     _, sam = get_sam_model(
         model_type=model_type,
         device=device,
@@ -28,9 +30,11 @@ def get_sam_3d_model(
         flexible_load_checkpoint=True,
         num_multimask_outputs=n_classes,
         image_size=image_size,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )

+    # Make sure not to freeze the encoder when using LoRA.
+    freeze_encoder_ = freeze_encoder if lora_rank is None else False
     sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_)
     sam_3d.to(device)
     return sam_3d
6 changes: 5 additions & 1 deletion micro_sam/models/simple_sam_3d_wrapper.py
@@ -16,14 +16,18 @@ def get_simple_sam_3d_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
+    peft_kwargs = {}
+    if lora_rank is not None:
+        peft_kwargs["rank"] = lora_rank
+
     _, sam = get_sam_model(
         model_type=model_type,
         device=device,
         checkpoint_path=checkpoint_path,
         return_sam=True,
         image_size=image_size,
         flexible_load_checkpoint=True,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )

     # Make sure not to freeze the encoder when using LoRA.
4 changes: 2 additions & 2 deletions micro_sam/training/training.py
@@ -154,6 +154,7 @@ def train_sam(
     save_every_kth_epoch: Optional[int] = None,
     pbar_signals: Optional[QObject] = None,
     optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
+    peft_kwargs: Optional[Dict] = None,
     **model_kwargs,
 ) -> None:
     """Run training for a SAM model.
@@ -196,17 +197,16 @@ def train_sam(
     _check_loader(val_loader, with_segmentation_decoder)

     device = get_device(device)
-
     # Get the trainable segment anything model.
     model, state = get_trainable_sam_model(
         model_type=model_type,
         device=device,
         freeze=freeze,
         checkpoint_path=checkpoint_path,
         return_state=True,
+        peft_kwargs=peft_kwargs,
         **model_kwargs
     )
-
     # This class creates all the training data for a batch (inputs, prompts and labels).
     convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

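`train_sam` now accepts `peft_kwargs` and passes it straight to `get_trainable_sam_model`. A hypothetical end-to-end call, assuming the same `{"rank": ..., "module": ...}` layout used by the evaluation scripts in this PR; the data loaders are placeholders and need to be replaced by real torch_em loaders before this runs:

```python
from micro_sam.training import train_sam

train_loader = ...  # placeholder: a torch_em dataloader over your training images/labels
val_loader = ...    # placeholder: the corresponding validation loader

train_sam(
    name="vit_b_fact_rank4",                    # hypothetical run name
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    peft_kwargs={"rank": 4, "module": "FacT"},  # assumed layout, mirroring the eval scripts
)
```

Leaving `peft_kwargs` at its default of `None` should keep the previous behavior, i.e. regular finetuning without any PEFT wrapping.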