diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index 01e124e..8f16577 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 
-import h5py
 import torch
 
 from turbo_alignment.common.data.multimodal.image.base import BaseImageReader
@@ -14,11 +13,8 @@
 @AudioModalityReaderRegistry.register(ModalityReader.PICKLE)
 @ImageModalityReaderRegistry.register(ModalityReader.PICKLE)
 class FileReader(BaseImageReader):
-    @staticmethod
-    def _get_h5_file(path: Path) -> Path:
-        return list(path.glob('*.h5'))[0]  # FIXME: What if there is more than one h5 file?
+    def __init__(self, **_kwargs):
+        ...
 
     def read(self, path: str) -> torch.Tensor:
-        h5_file = self._get_h5_file(Path(path).parent)
-        with h5py.File(h5_file, 'r') as f:
-            return torch.tensor(f[Path(path).name])
+        return torch.load(path + '.image.clip.pt')
diff --git a/turbo_alignment/modeling/multimodal/encoders/image/clip.py b/turbo_alignment/modeling/multimodal/encoders/image/clip.py
index 4537689..32ffe32 100644
--- a/turbo_alignment/modeling/multimodal/encoders/image/clip.py
+++ b/turbo_alignment/modeling/multimodal/encoders/image/clip.py
@@ -23,12 +23,12 @@ def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, i
 
     @staticmethod
     def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
+        if is_pickle:
+            return inputs
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
+        # -2 is default value of vision_feature_layer in llava config
+        # [1:] is everything after vit [cls] token
         with torch.no_grad():
-            if is_pickle:
-                return inputs
-            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
-            # -2 is default value of vision_feature_layer in llava config
-            # [1:] is everything after vit [cls] token
             return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
                 :, 1:
             ]  # FIXME: squeeze dimension?
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index 7959fb2..249a1ef 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -68,13 +68,16 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings
                 for filename, encoded_output in batch_output.items():
                     torch.save(
                         encoded_output,
-                        experiment_settings.output_file_path / (filename
-                        + '.'
-                        + experiment_settings.modality.value
-                        + '.'
-                        + experiment_settings.encoder_settings.modality_encoder_type
-                        + '.pt'
-                        ))
+                        experiment_settings.output_file_path
+                        / (
+                            filename
+                            + '.'
+                            + experiment_settings.modality.value
+                            + '.'
+                            + experiment_settings.encoder_settings.modality_encoder_type
+                            + '.pt'
+                        ),
+                    )
             except Exception as exc:
                 logger.error(f'Error reading file: {exc}')
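
A note on how the pieces above fit together: the preprocessing pipeline now writes one `.pt` tensor per source file, named `<filename>.<modality>.<encoder_type>.pt`, and `FileReader.read` loads a tensor by appending that suffix to the path it receives. Below is a minimal sketch of that round trip; the directory, file name, and stand-in tensor are hypothetical examples, not part of the patch.

```python
# Illustrative sketch of the "<source_path>.<modality>.<encoder_type>.pt" convention.
# All paths and values here are made up for demonstration.
from pathlib import Path

import torch

output_dir = Path('/tmp/preprocessed')   # stands in for experiment_settings.output_file_path
output_dir.mkdir(parents=True, exist_ok=True)
filename = 'cat_0001.jpg'                # stands in for a source image file name
modality, encoder_type = 'image', 'clip'

# Preprocessing side: one tensor file per source file, suffixed with modality and encoder type.
encoded_output = torch.randn(256, 1024)  # stand-in for precomputed CLIP hidden states
torch.save(encoded_output, output_dir / (filename + '.' + modality + '.' + encoder_type + '.pt'))

# Reading side: FileReader.read() appends the same suffix to the path it is given.
restored = torch.load(str(output_dir / filename) + '.image.clip.pt')
assert torch.equal(restored, encoded_output)
```

Note that `FileReader.read` hardcodes the `.image.clip.pt` suffix, so it lines up with the pipeline's output only for the image modality encoded with CLIP.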