From 41bab4a6d6787fc7f28c6db38cbc392857bf10de Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Sat, 16 Dec 2023 21:04:54 +0100
Subject: [PATCH] UNETR - Make skip connections optional (#9)

Update UNETR model to optionally use skip connections or not
---
 .../unetr/livecell/train_by_parts.py | 13 ++++--
 torch_em/model/unetr.py              | 41 +++++++++++++------
 2 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/experiments/vision-transformer/unetr/livecell/train_by_parts.py b/experiments/vision-transformer/unetr/livecell/train_by_parts.py
index 280fa730..abbd09f4 100644
--- a/experiments/vision-transformer/unetr/livecell/train_by_parts.py
+++ b/experiments/vision-transformer/unetr/livecell/train_by_parts.py
@@ -22,7 +22,9 @@ def prune_prefix(checkpoint_path):
     return updated_model_state
 
 
-def get_custom_unetr_model(device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder):
+def get_custom_unetr_model(
+    device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
+):
     if checkpoint_path is not None:
         if checkpoint_path.endswith("pt"):  # for finetuned models
             model_state = prune_prefix(checkpoint_path)
@@ -37,8 +39,10 @@ def get_custom_unetr_model(device, model_name, sam_initialization, output_channe
         out_channels=output_channels,
         use_sam_stats=sam_initialization,
         final_activation="Sigmoid",
-        encoder_checkpoint=model_state
+        encoder_checkpoint=model_state,
+        use_skip_connection=not joint_training  # if joint_training, no skip con. else, use skip con. by default
     )
+
     model.to(device)
 
     # if expected, let's freeze the image encoder
@@ -66,7 +70,7 @@ def main(args):
     # get the custom model for the training and inference on livecell dataset
     model = get_custom_unetr_model(
         device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
-        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder
+        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
     )
 
     # determining where to save the checkpoints and tensorboard logs
@@ -123,5 +127,8 @@ def main(args):
     parser.add_argument(
         "--freeze_encoder", action="store_true", help="Experiments to freeze the encoder."
     )
+    parser.add_argument(
+        "--joint_training", action="store_true", help="Uses VNETR for training"
+    )
     args = parser.parse_args()
     main(args)
diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py
index 33c35070..39aa586c 100644
--- a/torch_em/model/unetr.py
+++ b/torch_em/model/unetr.py
@@ -70,11 +70,13 @@ def __init__(
         use_mae_stats: bool = False,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
+        use_skip_connection: bool = True
     ) -> None:
         super().__init__()
 
         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
+        self.use_skip_connection = use_skip_connection
 
         print(f"Using {encoder} from {backbone.upper()}")
         self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
@@ -107,10 +109,12 @@ def __init__(
         self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
         self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
         self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
+        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
 
-        self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+
+        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
 
-        self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
         self.final_activation = self._get_activation(final_activation)
 
     def _get_activation(self, activation):
@@ -167,26 +171,37 @@ def forward(self, x):
         # backbone used for reshaping inputs to the desired "encoder" shape
         x = torch.stack([self.preprocess(e) for e in x], dim=0)
 
-        z0 = self.z_inputs(x)
+        use_skip_connection = getattr(self, "use_skip_connection", True)
 
         z12, from_encoder = self.encoder(x)
-        x = self.base(z12)
 
-        from_encoder = from_encoder[::-1]
-        z9 = self.deconv1(from_encoder[0])
+        if use_skip_connection:
+            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
+            from_encoder = from_encoder[::-1]
+            z9 = self.deconv1(from_encoder[0])
+
+            z6 = self.deconv1(from_encoder[1])
+            z6 = self.deconv2(z6)
 
-        z6 = self.deconv1(from_encoder[1])
-        z6 = self.deconv2(z6)
+            z3 = self.deconv1(from_encoder[2])
+            z3 = self.deconv2(z3)
+            z3 = self.deconv3(z3)
 
-        z3 = self.deconv1(from_encoder[2])
-        z3 = self.deconv2(z3)
-        z3 = self.deconv3(z3)
+            z0 = self.z_inputs(x)
+
+        else:
+            z9 = self.deconv1(z12)
+            z6 = self.deconv2(z9)
+            z3 = self.deconv3(z6)
+            z0 = self.deconv4(z3)
 
         updated_from_encoder = [z9, z6, z3]
+
+        x = self.base(z12)
 
         x = self.decoder(x, encoder_inputs=updated_from_encoder)
-        x = self.deconv4(x)
-        x = torch.cat([x, z0], dim=1)
+        x = self.deconv_out(x)
+        x = torch.cat([x, z0], dim=1)
 
         x = self.decoder_head(x)
         x = self.out_conv(x)
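
For context, a minimal sketch of how the new flag would be used when constructing the model. This is not part of the patch: the constructor arguments are read off the diff above, while the backbone/encoder values shown here are assumptions for illustration.

    # Hypothetical usage sketch of the new use_skip_connection flag.
    from torch_em.model import UNETR

    model = UNETR(
        backbone="sam",              # assumption: SAM-style ViT backbone, as in the LIVECell experiments
        encoder="vit_b",             # assumption: any encoder name accepted by get_vision_transformer
        out_channels=3,              # matches output_channels=3 in train_by_parts.py
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False,   # bypass the skip connections, as done for joint training
    )

With use_skip_connection=False the forward pass ignores the intermediate encoder features and instead chains deconv1-deconv4 from the final encoder output, as implemented in the unetr.py hunk above.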