UNETR - Make skip connections optional (#9)
Update the UNETR model to make the decoder's skip connections optional.
anwai98 authored Dec 16, 2023
1 parent 741366f commit 41bab4a
Showing 2 changed files with 38 additions and 16 deletions.
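The gist of the change: UNETR gains a `use_skip_connection` flag (default `True`), and the LIVECell training script exposes it through a new `--joint_training` argument. A minimal usage sketch, not part of the commit itself (the import path and the remaining constructor defaults are assumptions based on torch_em):

```python
# Hedged sketch: building UNETR with and without decoder skip connections.
from torch_em.model import UNETR

model_with_skips = UNETR(out_channels=3)  # default behaviour: skip connections on
model_no_skips = UNETR(out_channels=3, use_skip_connection=False)  # new: sequential decoder path
```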
13 changes: 10 additions & 3 deletions experiments/vision-transformer/unetr/livecell/train_by_parts.py
@@ -22,7 +22,9 @@ def prune_prefix(checkpoint_path):
     return updated_model_state


-def get_custom_unetr_model(device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder):
+def get_custom_unetr_model(
+    device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
+):
     if checkpoint_path is not None:
         if checkpoint_path.endswith("pt"):  # for finetuned models
             model_state = prune_prefix(checkpoint_path)
@@ -37,8 +39,10 @@ def get_custom_unetr_model(device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder):
         out_channels=output_channels,
         use_sam_stats=sam_initialization,
         final_activation="Sigmoid",
-        encoder_checkpoint=model_state
+        encoder_checkpoint=model_state,
+        use_skip_connection=not joint_training  # if joint_training: no skip connections; else: use skip connections (the default)
     )
+
     model.to(device)

     # if expected, let's freeze the image encoder
@@ -66,7 +70,7 @@ def main(args):
     # get the custom model for the training and inference on livecell dataset
     model = get_custom_unetr_model(
         device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
-        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder
+        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
     )

     # determining where to save the checkpoints and tensorboard logs
@@ -123,5 +127,8 @@ def main(args):
     parser.add_argument(
         "--freeze_encoder", action="store_true", help="Experiments to freeze the encoder."
     )
+    parser.add_argument(
+        "--joint_training", action="store_true", help="Uses UNETR without skip connections (for joint training)."
+    )
     args = parser.parse_args()
     main(args)
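As an aside, the wiring between the new CLI flag and the model option is a simple negation: `--joint_training` is a store-true argument and the model receives `use_skip_connection=not joint_training`. A tiny self-contained check of that logic (the parser below is a stand-in, not the full script):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--joint_training", action="store_true")

# simulate calling the training script with the new flag
args = parser.parse_args(["--joint_training"])
use_skip_connection = not args.joint_training
assert use_skip_connection is False  # joint training turns the skip connections off
```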
41 changes: 28 additions & 13 deletions torch_em/model/unetr.py
@@ -70,11 +70,13 @@ def __init__(
         use_mae_stats: bool = False,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
+        use_skip_connection: bool = True
     ) -> None:
         super().__init__()

         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
+        self.use_skip_connection = use_skip_connection

         print(f"Using {encoder} from {backbone.upper()}")
         self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
@@ -107,10 +109,12 @@ def __init__(
         self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
         self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
         self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
+        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

-        self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])

-        self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
+        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
+
         self.final_activation = self._get_activation(final_activation)

     def _get_activation(self, activation):
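The renaming above is easy to miss: `deconv4` used to be the final upsampler and is now a fourth `Deconv2DBlock` in the decoder chain, while its old role moves to `deconv_out`. Assuming each `Deconv2DBlock` doubles the spatial resolution and the channel sizes follow `features_decoder` (the concrete numbers below are illustrative assumptions, not values from this diff), the sequential chain traces out as:

```python
# Illustrative shape walk for the no-skip deconv chain deconv1 -> deconv4.
embed_dim = 768                         # e.g. ViT-B token dimension (assumption)
features_decoder = (512, 256, 128, 64)  # assumed decoder feature sizes

channels = [embed_dim, *features_decoder]
side = 14  # e.g. 224 px input with 16 px patches (assumption)
for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
    side *= 2  # each Deconv2DBlock is assumed to upsample 2x
    print(f"deconv{i}: {c_in} -> {c_out} channels at {side}x{side}")
```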
@@ -167,26 +171,37 @@ def forward(self, x):
         # backbone used for reshaping inputs to the desired "encoder" shape
         x = torch.stack([self.preprocess(e) for e in x], dim=0)

-        z0 = self.z_inputs(x)
+        use_skip_connection = getattr(self, "use_skip_connection", True)

         z12, from_encoder = self.encoder(x)
-        x = self.base(z12)

-        from_encoder = from_encoder[::-1]
-        z9 = self.deconv1(from_encoder[0])
+        if use_skip_connection:
+            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
+            from_encoder = from_encoder[::-1]
+            z9 = self.deconv1(from_encoder[0])

-        z6 = self.deconv1(from_encoder[1])
-        z6 = self.deconv2(z6)
+            z6 = self.deconv1(from_encoder[1])
+            z6 = self.deconv2(z6)

-        z3 = self.deconv1(from_encoder[2])
-        z3 = self.deconv2(z3)
-        z3 = self.deconv3(z3)
+            z3 = self.deconv1(from_encoder[2])
+            z3 = self.deconv2(z3)
+            z3 = self.deconv3(z3)
+
+            z0 = self.z_inputs(x)
+
+        else:
+            z9 = self.deconv1(z12)
+            z6 = self.deconv2(z9)
+            z3 = self.deconv3(z6)
+            z0 = self.deconv4(z3)

         updated_from_encoder = [z9, z6, z3]

+        x = self.base(z12)
         x = self.decoder(x, encoder_inputs=updated_from_encoder)
-        x = self.deconv4(x)
-        x = torch.cat([x, z0], dim=1)
+        x = self.deconv_out(x)
+
+        x = torch.cat([x, z0], dim=1)
         x = self.decoder_head(x)

         x = self.out_conv(x)
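One detail worth flagging in `forward`: the option is read via `getattr(self, "use_skip_connection", True)` instead of a plain attribute access. A plausible reason (an interpretation, not stated in the commit) is backward compatibility: UNETR objects restored from checkpoints saved before this change have no such attribute, and the `True` default keeps them on the old skip-connection path. Minimal illustration of the pattern:

```python
class PreCommitModel:
    """Stand-in for a UNETR instance serialized before this commit."""

old_model = PreCommitModel()
# the attribute is missing, so getattr falls back to True -> old behaviour
assert getattr(old_model, "use_skip_connection", True) is True
```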