Skip to content

Commit

Permalink
in p
Browse files Browse the repository at this point in the history
  • Loading branch information
Elisei Rykov committed Sep 26, 2024
1 parent 5f2a1ae commit cc3a444
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion turbo_alignment/dataset/multimodal/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def torch_call(self, features):
if 'modality_inputs' in features[0].keys():
modality_inputs = [feature['modality_inputs'] for feature in features]
else:
modality_inputs = [None for feature in features]
modality_inputs = [None for _ in features]

if 'messages' in features[0].keys():
message_inputs = [feature['messages'] for feature in features]
Expand Down
13 changes: 8 additions & 5 deletions turbo_alignment/dataset/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import torch
from allenai_common import Params
import gc

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.data.multimodal import BaseModalityReader
Expand Down Expand Up @@ -196,12 +197,13 @@ def _read_modalities(self, record):
except (OSError, RuntimeError, KeyError):
return None

record['modality_inputs'] = modality_encodings
# record['modality_inputs'] = modality_encodings

if len(modality_encodings) != modality_messages_after_truncation:
return None

return record
# return record
return modality_encodings

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
Expand All @@ -213,9 +215,10 @@ def __iter__(self):
worker_id = worker_info.id
start = start + worker_id * per_worker
end = min(start + per_worker, end)
for sample in map(self._read_modalities, iter(self.records[start:end])):
if sample:
yield sample
for i, sample in enumerate(self.records[start:end]):
output = self._read_modalities(sample)
if output:
yield sample | {'modality_inputs': output}


@MultimodalDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
Expand Down

0 comments on commit cc3a444

Please sign in to comment.