From af8d300fed1863854237df2f21af1ddf70c503f2 Mon Sep 17 00:00:00 2001
From: GitHub Actions
Date: Fri, 13 Sep 2024 16:37:21 -0400
Subject: [PATCH] Fix forward method to return hidden states

---
 mmlearn/modules/encoders/vision.py | 49 +++++++++++++++++-------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/mmlearn/modules/encoders/vision.py b/mmlearn/modules/encoders/vision.py
index add723c..6b06a60 100644
--- a/mmlearn/modules/encoders/vision.py
+++ b/mmlearn/modules/encoders/vision.py
@@ -56,24 +56,41 @@ def __init__(
         if model_kwargs is None:
             model_kwargs = {}
 
-        model: nn.Module = timm.create_model(
+        self.model: VisionTransformer = timm.create_model(
             model_name,
             pretrained=pretrained,
             num_classes=projection_dim,
             **model_kwargs,
         )
-        assert isinstance(model, VisionTransformer), (
+        assert isinstance(self.model, VisionTransformer), (
             f"Model {model_name} is not a Vision Transformer. "
             "Please provide a model name that corresponds to a Vision Transformer."
         )
 
+        self._freeze_layers(freeze_layers, freeze_layer_norm)
+
+        if peft_config is not None:
+            self.model = hf_utils._wrap_peft_model(self.model, peft_config)
+
+    def _freeze_layers(
+        self, freeze_layers: Union[int, float, List[int], bool], freeze_layer_norm: bool
+    ) -> None:
+        """Freeze the layers of the model.
+
+        Parameters
+        ----------
+        freeze_layers : Union[int, float, List[int], bool]
+            Whether to freeze the layers.
+        freeze_layer_norm : bool
+            Whether to freeze the layer norm.
+        """
         if isinstance(freeze_layers, bool) and freeze_layers:
-            for name, param in model.named_parameters():
+            for name, param in self.model.named_parameters():
                 param.requires_grad = (
                     (not freeze_layer_norm) if "norm" in name else False
                 )
 
-        modules = [model.patch_embed, *model.blocks, model.norm]
+        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
         if isinstance(freeze_layers, float):
             freeze_layers = int(freeze_layers * len(modules))
         if isinstance(freeze_layers, int):
@@ -87,11 +104,6 @@ def __init__(
                     (not freeze_layer_norm) if "norm" in name else False
                 )
 
-        if peft_config is not None:
-            model = hf_utils._wrap_peft_model(model, peft_config)
-
-        self.model = model
-
     def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """Run the forward pass.
 
@@ -106,28 +118,23 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
             The output of the model.
         """
         x = inputs[Modalities.RGB]
-        x = self.model.forward_features(x)
-
-        # Separate the class token and patch embeddings
-        cls_token = x[:, 0]
-        patch_embeddings = x[:, 1:]
+        _, intermediates = self.model.forward_intermediates(x)
 
         return BaseModelOutput(
-            last_hidden_state=patch_embeddings,
-            pooler_output=cls_token,
-            hidden_states=None,
+            last_hidden_state=intermediates[-1],
+            hidden_states=intermediates,
             attentions=None,
         )
 
     def get_intermediate_layers(
-        self, x: torch.Tensor, n: int = 1
+        self, inputs: Dict[Union[str, Modality], Any], n: int = 1
     ) -> List[torch.Tensor]:
         """Get the output of the intermediate layers.
 
         Parameters
         ----------
-        x : torch.Tensor
-            The input tensor.
+        inputs : Dict[Union[str, Modality], Any]
+            The input data. The `image` will be expected under the `Modalities.RGB` key.
         n : int, default=1
             The number of intermediate layers to return.
 
@@ -136,7 +143,7 @@ def get_intermediate_layers(
         List[torch.Tensor]
             The outputs of the last n intermediate layers.
         """
-        return self.model.get_intermediate_layers(x, n)  # type: ignore
+        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
 
     def get_patch_info(self) -> Tuple[int, int]:
         """Get patch size and number of patches.