Add encoder flexibility in UNETR (#10)
* Update distance transform (#11 - Getting instance segmentations from the distance transform)

* Make the encoder argument flexible (it now accepts either a PyTorch module or a model name as a string)
anwai98 authored Dec 20, 2023
1 parent 41bab4a commit 7f06c97
Showing 3 changed files with 80 additions and 12 deletions.
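
As a quick illustration of the flexible encoder argument, a minimal sketch (not part of this commit) of the string form, with a placeholder checkpoint path; the new script below passes a PyTorch module instead:

import torch

from torch_em.model import UNETR

# Placeholder path; point this at a real SAM ViT-B checkpoint if available.
checkpoint = "/path/to/sam_vit_b_01ec64.pth"

model = UNETR(
    backbone="sam",
    encoder="vit_b",                # model name as a string
    encoder_checkpoint=checkpoint,  # optional: initialize the encoder from SAM weights
    out_channels=3,
    use_sam_stats=True,
    final_activation="Sigmoid",
    use_skip_connection=False
)

x = torch.ones((1, 1, 512, 512))
y = model(x)
print(y.shape)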
35 changes: 35 additions & 0 deletions scripts/vision_transformer/load_sam_encoder_in_unetr.py
@@ -0,0 +1,35 @@
import torch

from torch_em.model import UNETR

from micro_sam.util import get_sam_model


def main():
    checkpoint = "/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    predictor = get_sam_model(
        model_type=model_type,
        checkpoint_path=checkpoint
    )

    model = UNETR(
        backbone="sam",
        encoder=predictor.model.image_encoder,
        out_channels=3,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False
    )
    model.to(device)

    x = torch.ones((1, 1, 512, 512)).to(device)
    y = model(x)

    print("UNETR Model successfully created and encoder initialized from", checkpoint)


if __name__ == "__main__":
    main()
52 changes: 40 additions & 12 deletions torch_em/model/unetr.py
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 
 from .unet import Decoder, ConvBlock2d, Upsampler2d
-from .vit import get_vision_transformer
+from .vit import get_vision_transformer, ViT_MAE, ViT_Sam
 
 try:
     from micro_sam.util import get_sam_model
@@ -24,7 +24,7 @@ class UNETR(nn.Module):
     def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
 
         if isinstance(checkpoint, str):
-            if backbone == "sam":
+            if backbone == "sam" and isinstance(encoder, str):
                 # If we have a SAM encoder, then we first try to load the full SAM Model
                 # (using micro_sam) and otherwise fall back on directly loading the encoder state
                 # from the checkpoint
@@ -63,25 +63,47 @@ def __init__(
         self,
         img_size: int = 1024,
         backbone: str = "sam",
-        encoder: str = "vit_b",
+        encoder: Optional[Union[nn.Module, str]] = "vit_b",
         decoder: Optional[nn.Module] = None,
         out_channels: int = 1,
         use_sam_stats: bool = False,
         use_mae_stats: bool = False,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
-        use_skip_connection: bool = True
+        use_skip_connection: bool = True,
+        embed_dim: Optional[int] = None
     ) -> None:
         super().__init__()
 
         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
         self.use_skip_connection = use_skip_connection
 
-        print(f"Using {encoder} from {backbone.upper()}")
-        self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
-        if encoder_checkpoint is not None:
-            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
+        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
+            print(f"Using {encoder} from {backbone.upper()}")
+            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
+            if encoder_checkpoint is not None:
+                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
+
+            in_chans = self.encoder.in_chans
+            if embed_dim is None:
+                embed_dim = self.encoder.embed_dim
+
+        else:  # `nn.Module` ViT backbone
+            self.encoder = encoder
+
+            have_neck = False
+            for name, _ in self.encoder.named_parameters():
+                if name.startswith("neck"):
+                    have_neck = True
+
+            if embed_dim is None:
+                if have_neck:
+                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
+                else:
+                    embed_dim = self.encoder.patch_embed.proj.out_channels
+
+            in_chans = self.encoder.patch_embed.proj.in_channels
 
         # parameters for the decoder network
         depth = 3
@@ -101,12 +123,13 @@ def __init__(
         else:
             self.decoder = decoder
 
-        self.z_inputs = ConvBlock2d(self.encoder.in_chans, features_decoder[-1])
+        self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])
+
+        self.base = ConvBlock2d(embed_dim, features_decoder[0])
 
-        self.base = ConvBlock2d(self.encoder.embed_dim, features_decoder[0])
         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
 
-        self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
+        self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
         self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
         self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
         self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
@@ -173,7 +196,12 @@ def forward(self, x):
 
         use_skip_connection = getattr(self, "use_skip_connection", True)
 
-        z12, from_encoder = self.encoder(x)
+        encoder_outputs = self.encoder(x)
+
+        if isinstance(self.encoder, ViT_Sam) or isinstance(self.encoder, ViT_MAE):
+            z12, from_encoder = encoder_outputs
+        else:
+            z12 = encoder_outputs
 
         if use_skip_connection:
             # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
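To illustrate the new module branch above, a small standalone sketch (not part of the commit) of the embed_dim / in_chans inference it performs for an nn.Module encoder; it assumes micro_sam is installed and can provide (or download) the SAM ViT-B weights:

import torch.nn as nn

from micro_sam.util import get_sam_model

encoder: nn.Module = get_sam_model(model_type="vit_b").model.image_encoder

# SAM image encoders end in a "neck" that projects the features to 256 channels;
# plain ViT / MAE encoders do not have one.
have_neck = any(name.startswith("neck") for name, _ in encoder.named_parameters())

if have_neck:
    embed_dim = encoder.neck[2].out_channels           # 256 for the SAM ViTs
else:
    embed_dim = encoder.patch_embed.proj.out_channels  # e.g. 768 for a plain ViT-B

in_chans = encoder.patch_embed.proj.in_channels        # 3 for RGB inputs
print(embed_dim, in_chans)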
5 changes: 5 additions & 0 deletions torch_em/transform/label.py
@@ -302,6 +302,7 @@ def __init__(
         boundary_distances=True,
         directed_distances=False,
         foreground=True,
+        instances=False,
         apply_label=True,
         correct_centers=True,
         min_size=0,
@@ -313,6 +314,7 @@ def __init__(
         self.boundary_distances = boundary_distances
         self.directed_distances = directed_distances
         self.foreground = foreground
+        self.instances = instances
 
         self.apply_label = apply_label
         self.correct_centers = correct_centers
@@ -441,4 +443,7 @@ def __call__(self, labels):
             binary_labels = (labels > 0).astype("float32")
             distances = np.concatenate([binary_labels[None], distances], axis=0)
 
+        if self.instances:
+            distances = np.concatenate([labels[None], distances], axis=0)
+
         return distances
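
To illustrate the new instances option, a minimal sketch (not from this commit) of how it could be used, assuming the transform modified here is PerObjectDistanceTransform from torch_em.transform.label:

import numpy as np

from torch_em.transform.label import PerObjectDistanceTransform

# Toy instance segmentation with two objects.
labels = np.zeros((64, 64), dtype="int64")
labels[8:24, 8:24] = 1
labels[36:56, 36:56] = 2

trafo = PerObjectDistanceTransform(instances=True)
target = trafo(labels)

# With instances=True the original instance labels are prepended as an extra (first)
# channel, in front of the foreground and distance channels, so the targets can be
# used to derive instance segmentations from the predicted distances.
print(target.shape)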
