No gradient calc for CLIP inference
lmeribal committed Sep 10, 2024
1 parent aee3af5 commit 2544a00
Showing 1 changed file with 9 additions and 8 deletions.
turbo_alignment/modeling/multimodal/encoders/image/clip.py (9 additions, 8 deletions)
@@ -23,14 +23,15 @@ def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, i

     @staticmethod
     def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
-        if is_pickle:
-            return inputs
-        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
-        # -2 is default value of vision_feature_layer in llava config
-        # [1:] is everything after vit [cls] token
-        return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
-            :, 1:
-        ]  # FIXME: squeeze dimension?
+        with torch.no_grad():
+            if is_pickle:
+                return inputs
+            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
+            # -2 is default value of vision_feature_layer in llava config
+            # [1:] is everything after vit [cls] token
+            return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
+                :, 1:
+            ]  # FIXME: squeeze dimension?

     def encode(self, inputs: torch.Tensor) -> torch.Tensor:
         return self._get_clip_hidden_states(self.model_clip, inputs, self.is_pickle)
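For context, here is a minimal sketch of what the `torch.no_grad()` wrapper changes at inference time. This is not code from the repository: the checkpoint name, dummy input sizes, and printed shape are illustrative assumptions based on the Hugging Face `transformers` CLIP API.

```python
# Minimal sketch, assuming a standard Hugging Face CLIP checkpoint;
# the model name and dummy batch below are illustrative, not from the commit.
import torch
from transformers import CLIPModel

model_clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
model_clip.eval()

# Dummy batch of two preprocessed 224x224 RGB images.
pixel_values = torch.randn(2, 3, 224, 224)

with torch.no_grad():  # no autograd graph is built, so activations are not kept for backward
    outputs = model_clip.vision_model(pixel_values, output_hidden_states=True)
    # hidden_states[-2] is the penultimate ViT layer (LLaVA's default vision_feature_layer);
    # [:, 1:] drops the [CLS] token and keeps only the patch embeddings.
    features = outputs.hidden_states[-2][:, 1:]

print(features.requires_grad)  # False -> nothing to backpropagate through
print(features.shape)          # torch.Size([2, 256, 1024]) for ViT-L/14 at 224x224
```

Without `torch.no_grad()`, autograd records every vision-transformer layer's activations for a backward pass that never happens when the encoder is only used for feature extraction, so the wrapper cuts memory use during encoding. `torch.inference_mode()` is a stricter alternative with the same intent.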
