Skip to content

Commit

Permalink
in progress
Browse files (browse the repository at this point in the history)
  • Loading branch information
lmeribal committed Sep 10, 2024
1 parent e13ddb5 commit 8c0f6cc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
13 changes: 6 additions & 7 deletions turbo_alignment/common/data/multimodal/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

import h5py
import torch
from safetensors import safe_open

from turbo_alignment.common.data.multimodal.image.base import BaseImageReader
from turbo_alignment.common.data.multimodal.registry import (
Expand All @@ -18,11 +18,10 @@ def __init__(self, **_kwargs):
self.processed_tensors = None

@staticmethod
def _get_safetensors_file(path: Path) -> Path:
return list(path.glob('*.safetensors'))[0] # FIXME: What if there is more than one safetensors file?
def _get_h5_file(path: Path) -> Path:
return list(path.glob('*.h5'))[0] # FIXME: What if there is more than one h5 file?

def read(self, path: str) -> torch.Tensor:
safetensors_file = self._get_safetensors_file(Path(path).parent)
if self.processed_tensors is None:
self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu')
return self.processed_tensors.get_tensor(Path(path).name)
h5_file = self._get_h5_file(Path(path).parent)
with h5py.File(h5_file, 'r') as f:
return torch.tensor(f[Path(path).name])
5 changes: 2 additions & 3 deletions turbo_alignment/dataset/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import torch
from allenai_common import Params
from safetensors._safetensors_rust import SafetensorError

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.data.multimodal import BaseModalityReader
Expand Down Expand Up @@ -176,7 +175,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s

try:
encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
except (OSError, RuntimeError, SafetensorError):
except (OSError, RuntimeError, KeyError):
outputs.append(None)
continue

Expand Down Expand Up @@ -235,7 +234,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s

try:
encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
except (OSError, RuntimeError, SafetensorError):
except (OSError, RuntimeError, KeyError):
outputs.append(None)
continue

Expand Down
3 changes: 2 additions & 1 deletion turbo_alignment/pipelines/preprocessing/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings
batch_output = self._process_function(reader, encoder, batch, experiment_settings, i, accelerator)
with h5py.File(output_file_name, 'a') as f:
for path, encoded_output in batch_output.items():
f.create_dataset(path, data=encoded_output.numpy())
print(path)
# f.create_dataset(path, data=encoded_output.numpy())
except Exception as exc:
logger.error(f'Error reading file: {exc}')

Expand Down

0 comments on commit 8c0f6cc

Please sign in to comment.