Skip to content

Commit

Permalink
going back to the pt files
Browse files (browse the repository at this point in the history)
  • Loading branch information
lmeribal committed Sep 11, 2024
1 parent 43d6ba6 commit 0e51337
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
10 changes: 3 additions & 7 deletions turbo_alignment/common/data/multimodal/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path

import h5py
import torch

from turbo_alignment.common.data.multimodal.image.base import BaseImageReader
Expand All @@ -14,11 +13,8 @@
@AudioModalityReaderRegistry.register(ModalityReader.PICKLE)
@ImageModalityReaderRegistry.register(ModalityReader.PICKLE)
class FileReader(BaseImageReader):
@staticmethod
def _get_h5_file(path: Path) -> Path:
return list(path.glob('*.h5'))[0] # FIXME: What if there is more than one h5 file?
def __init__(self, **_kwargs):
...

def read(self, path: str) -> torch.Tensor:
h5_file = self._get_h5_file(Path(path).parent)
with h5py.File(h5_file, 'r') as f:
return torch.tensor(f[Path(path).name])
return torch.load(path + '.image.clip.pt')
10 changes: 5 additions & 5 deletions turbo_alignment/modeling/multimodal/encoders/image/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, i

@staticmethod
def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
if is_pickle:
return inputs
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
# -2 is default value of vision_feature_layer in llava config
# [1:] is everything after vit [cls] token
with torch.no_grad():
if is_pickle:
return inputs
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
# -2 is default value of vision_feature_layer in llava config
# [1:] is everything after vit [cls] token
return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
:, 1:
] # FIXME: squeeze dimension?
Expand Down
17 changes: 10 additions & 7 deletions turbo_alignment/pipelines/preprocessing/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,16 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings
for filename, encoded_output in batch_output.items():
torch.save(
encoded_output,
experiment_settings.output_file_path / (filename
+ '.'
+ experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.pt'
))
experiment_settings.output_file_path
/ (
filename
+ '.'
+ experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.pt'
),
)
except Exception as exc:
logger.error(f'Error reading file: {exc}')

Expand Down

0 comments on commit 0e51337

Please sign in to comment.