diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index fc4d9b4..cb17769 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
+import h5py
 import torch
-from safetensors import safe_open
 
 from turbo_alignment.common.data.multimodal.image.base import BaseImageReader
 from turbo_alignment.common.data.multimodal.registry import (
@@ -18,11 +18,14 @@ def __init__(self, **_kwargs):
         self.processed_tensors = None
 
     @staticmethod
-    def _get_safetensors_file(path: Path) -> Path:
-        return list(path.glob('*.safetensors'))[0]  # FIXME: What if there is more than one safetensors file?
+    def _get_h5_file(path: Path) -> Path:
+        """Return the first ``*.h5`` file that ``path.glob`` yields for directory ``path``."""
+        return list(path.glob('*.h5'))[0]  # FIXME: What if there is more than one h5 file?
 
     def read(self, path: str) -> torch.Tensor:
-        safetensors_file = self._get_safetensors_file(Path(path).parent)
-        if self.processed_tensors is None:
-            self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu')
-        return self.processed_tensors.get_tensor(Path(path).name)
+        """Load the dataset keyed by the basename of ``path`` from the h5 file next to it."""
+        h5_file = self._get_h5_file(Path(path).parent)
+        with h5py.File(h5_file, 'r') as f:
+            # ``[()]`` reads the whole dataset into a numpy array in one call, so
+            # torch.tensor copies from an in-memory array rather than the h5py Dataset.
+            return torch.tensor(f[Path(path).name][()])
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index 54fcde2..401604d 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -5,7 +5,6 @@
 import numpy as np
 import torch
 from allenai_common import Params
-from safetensors._safetensors_rust import SafetensorError
 
 from turbo_alignment.common.data.io import read_jsonl
 from turbo_alignment.common.data.multimodal import BaseModalityReader
@@ -176,7 +175,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
 
             try:
                 encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
-            except (OSError, RuntimeError, SafetensorError):
+            except (OSError, RuntimeError, KeyError):
                 outputs.append(None)
                 continue
 
@@ -235,7 +234,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
 
             try:
                 encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
-            except (OSError, RuntimeError, SafetensorError):
+            except (OSError, RuntimeError, KeyError):
                 outputs.append(None)
                 continue
 