Fix forward method to return hidden states
actions-user committed Sep 13, 2024
1 parent 636c001 commit af8d300
Showing 1 changed file with 28 additions and 21 deletions.
mmlearn/modules/encoders/vision.py (49 changes: 28 additions & 21 deletions)
@@ -56,24 +56,41 @@ def __init__(
         if model_kwargs is None:
             model_kwargs = {}
 
-        model: nn.Module = timm.create_model(
+        self.model: VisionTransformer = timm.create_model(
             model_name,
             pretrained=pretrained,
             num_classes=projection_dim,
             **model_kwargs,
         )
-        assert isinstance(model, VisionTransformer), (
+        assert isinstance(self.model, VisionTransformer), (
             f"Model {model_name} is not a Vision Transformer. "
             "Please provide a model name that corresponds to a Vision Transformer."
         )
 
+        self._freeze_layers(freeze_layers, freeze_layer_norm)
+
+        if peft_config is not None:
+            self.model = hf_utils._wrap_peft_model(self.model, peft_config)
+
+    def _freeze_layers(
+        self, freeze_layers: Union[int, float, List[int], bool], freeze_layer_norm: bool
+    ) -> None:
+        """Freeze the layers of the model.
+
+        Parameters
+        ----------
+        freeze_layers : Union[int, float, List[int], bool]
+            Whether to freeze the layers.
+        freeze_layer_norm : bool
+            Whether to freeze the layer norm.
+        """
         if isinstance(freeze_layers, bool) and freeze_layers:
-            for name, param in model.named_parameters():
+            for name, param in self.model.named_parameters():
                 param.requires_grad = (
                     (not freeze_layer_norm) if "norm" in name else False
                 )
 
-        modules = [model.patch_embed, *model.blocks, model.norm]
+        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
         if isinstance(freeze_layers, float):
             freeze_layers = int(freeze_layers * len(modules))
         if isinstance(freeze_layers, int):
@@ -87,11 +104,6 @@ def __init__(
                     (not freeze_layer_norm) if "norm" in name else False
                 )
 
-        if peft_config is not None:
-            model = hf_utils._wrap_peft_model(model, peft_config)
-
-        self.model = model
-
     def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """Run the forward pass.
 
@@ -106,28 +118,23 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
             The output of the model.
         """
         x = inputs[Modalities.RGB]
-        x = self.model.forward_features(x)
-
-        # Separate the class token and patch embeddings
-        cls_token = x[:, 0]
-        patch_embeddings = x[:, 1:]
+        _, intermediates = self.model.forward_intermediates(x)
 
         return BaseModelOutput(
-            last_hidden_state=patch_embeddings,
-            pooler_output=cls_token,
-            hidden_states=None,
+            last_hidden_state=intermediates[-1],
+            hidden_states=intermediates,
             attentions=None,
         )
 
     def get_intermediate_layers(
-        self, x: torch.Tensor, n: int = 1
+        self, inputs: Dict[Union[str, Modality], Any], n: int = 1
     ) -> List[torch.Tensor]:
         """Get the output of the intermediate layers.
 
         Parameters
        ----------
-        x : torch.Tensor
-            The input tensor.
+        inputs : Dict[Union[str, Modality], Any]
+            The input data. The `image` will be expected under the `Modalities.RGB` key.
         n : int, default=1
             The number of intermediate layers to return.
@@ -136,7 +143,7 @@ def get_intermediate_layers(
         List[torch.Tensor]
             The outputs of the last n intermediate layers.
         """
-        return self.model.get_intermediate_layers(x, n)  # type: ignore
+        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
 
     def get_patch_info(self) -> Tuple[int, int]:
         """Get patch size and number of patches.
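For context, a minimal sketch (not part of this commit) of the timm behaviour the new forward method relies on: VisionTransformer.forward_intermediates returns the final normalized token sequence together with one intermediate output per transformer block, which the encoder now exposes as last_hidden_state and hidden_states. This assumes a timm release that provides forward_intermediates (roughly timm >= 1.0); the model name and the shapes mentioned in the comments are illustrative only.

import timm
import torch

# Any timm Vision Transformer works here; vit_base_patch16_224 is just an example.
model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)

x = torch.randn(2, 3, 224, 224)
# Returns (final_features, intermediates): the final normalized token sequence
# plus one intermediate output per transformer block.
final_tokens, intermediates = model.forward_intermediates(x)

print(final_tokens.shape)       # (batch, prefix_tokens + num_patches, embed_dim)
print(len(intermediates))       # one tensor per block, e.g. 12 for vit_base
print(intermediates[-1].shape)  # spatial NCHW feature maps by default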

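Similarly, a rough sketch of what a fractional freeze_layers value resolves to under the refactored _freeze_layers helper above. This is an illustration against a plain timm ViT, assuming freeze_layer_norm=False; the slicing of the module list mirrors the fractional-to-int conversion shown in the diff and is not copied from the repository.

import timm

vit = timm.create_model("vit_base_patch16_224", pretrained=False)
# Same module list as in _freeze_layers: patch embedding, transformer blocks, final norm.
modules = [vit.patch_embed, *vit.blocks, vit.norm]

freeze_layers = 0.5                               # freeze the first half of the modules
n_frozen = int(freeze_layers * len(modules))      # 14 modules for vit_base -> 7 frozen
for module in modules[:n_frozen]:                 # assumed slicing; the int branch is collapsed in the diff above
    for name, param in module.named_parameters():
        # keep layer-norm parameters trainable (freeze_layer_norm=False), freeze the rest
        param.requires_grad = "norm" in name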