Skip to content

Commit

Permalink
prettifier
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeribal committed Sep 12, 2024
1 parent b63c973 commit 90b736f
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 198 deletions.
3 changes: 2 additions & 1 deletion turbo_alignment/dataset/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _read_records(records: list[dict]) -> list[RecordT]:
@abstractmethod
def _read_records(records: 'list[dict]') -> 'list[RecordT]':
    """Parse raw record dictionaries into typed dataset records.

    Must be implemented by concrete dataset subclasses; ``records`` is the
    list of raw dicts read from the source file.
    """
    ...



class BaseIterableDataset(IterableDataset, ABC, Generic[RecordT]):
def __init__(
self,
Expand Down
8 changes: 4 additions & 4 deletions turbo_alignment/dataset/multimodal/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def torch_call(self, features):
modality_inputs = [feature['modality_inputs'] for feature in features]
else:
modality_inputs = [None for feature in features]

if 'messages' in features[0].keys():
message_inputs = [feature['messages'] for feature in features]
else:
message_inputs = [None] * len(features)

if 'messages' in features[0].keys():
modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask', 'messages')
else:
Expand All @@ -40,8 +40,8 @@ def torch_call(self, features):
assert padding_side == 'right'

batch['modality_inputs'] = modality_inputs
if 'messages' in features[0].keys():

if 'messages' in features[0].keys():
batch['messages'] = message_inputs

batch['modality_tokens_mask'] = torch.stack(
Expand Down
88 changes: 34 additions & 54 deletions turbo_alignment/dataset/multimodal/multimodal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
import math
from abc import ABC
from pathlib import Path
from typing import Any, overload

Expand Down Expand Up @@ -93,6 +93,29 @@ def __get_modality_message(self, modality: Modality) -> str:
modality_message_span = ''.join(modality_token for _ in range(self._n_modality_embeddings))
return f'{self._start_modality_token}{modality_message_span}{self._end_modality_token}'

def _read_modalities(self, record):
    """Attach encoded modality inputs to *record* in place and return it.

    Counts how many modality start-tokens survived tokenizer truncation
    (NOTE(review): the count is always measured on ``self.records[0]``'s
    input_ids, not on this record's — confirm that is intentional), keeps
    only that many file messages, and encodes each with its modality reader.
    """
    start_token_id = self._get_token_id(self._start_modality_token)
    # Number of modality messages that remain after truncation.
    kept_count = int((self.records[0]['input_ids'] == start_token_id).sum())

    file_messages = [msg for msg in record['messages'] if isinstance(msg, MultimodalFileMessage)]

    if self._truncate_top:
        # Truncated from the top: drop the excess leading messages.
        file_messages = file_messages[len(file_messages) - kept_count:]
    else:
        # Truncated from the bottom: keep only the leading messages.
        file_messages = file_messages[:kept_count]

    encoded: list[tuple[Modality, torch.Tensor]] = []
    for message in file_messages:
        encoded.append((message.type, self._modality_readers[message.type].read(message.content)))
    record['modality_inputs'] = encoded
    return record

def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord:
"""
Обычные текстовые сообщения оставляем без изменений, а мультимодальные сообщения
Expand Down Expand Up @@ -150,16 +173,6 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
outputs.append(None)
continue

# try:
# encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
# except (OSError, RuntimeError, KeyError):
# outputs.append(None)
# continue

# if len(encoded_modalities) != modality_messages_after_truncation:
# outputs.append(None)
# continue

modality_tokens_mask = torch.isin(
tokenized_record['input_ids'],
torch.tensor([self._get_token_id(token) for token in self._modality_token_mapping.values()]),
Expand All @@ -178,38 +191,26 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s

return outputs

def _read_modalities(self, record):
modality_messages_after_truncation = int((self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum())

modality_messages: list[MultimodalFileMessage] = [
m for m in record['messages'] if isinstance(m, MultimodalFileMessage)
]

messages_to_delete = len(modality_messages) - modality_messages_after_truncation

if self._truncate_top:
modality_messages = modality_messages[messages_to_delete:]
else:
modality_messages = modality_messages[:modality_messages_after_truncation]
def __iter__(self):
# try:
# encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
# except (OSError, RuntimeError, KeyError):
# outputs.append(None)
# continue

modality_encodings: list[tuple[Modality, torch.Tensor]] = []
for msg in modality_messages:
reader = self._modality_readers[msg.type]
modality_encodings.append((msg.type, reader.read(msg.content)))
record['modality_inputs'] = modality_encodings
return record
# if len(encoded_modalities) != modality_messages_after_truncation:
# outputs.append(None)
# continue

def __iter__(self):
    """Iterate over records, sharding them across DataLoader workers.

    Each worker receives a contiguous, non-overlapping chunk of
    ``self.records``; every yielded record is passed lazily through
    ``self._read_modalities``.
    """
    worker_info = torch.utils.data.get_worker_info()

    start = 0
    # Exclusive end bound: must be len(self.records), not len - 1 —
    # the slice below excludes `end`, so `len - 1` silently dropped
    # the final record in both single- and multi-worker modes.
    end = len(self.records)
    if worker_info is not None:
        # Split records as evenly as possible; the last worker may get
        # a short (or empty) chunk.
        per_worker = int(math.ceil(len(self.records) / float(worker_info.num_workers)))
        start = worker_info.id * per_worker
        end = min(start + per_worker, end)

    return map(self._read_modalities, iter(self.records[start:end]))

Expand All @@ -229,27 +230,6 @@ def __init__(
tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False
)
self._read()

def _read_modalities(self, record):
    """Encode this record's modality file messages and store them in place.

    The number of messages to keep is derived from how many modality
    start-tokens are present in ``self.records[0]['input_ids']``
    (NOTE(review): always the first record, regardless of which record is
    passed — verify this is intended). Returns the mutated record.
    """
    surviving = int(
        (self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
    )

    candidates: list[MultimodalFileMessage] = [
        item for item in record['messages'] if isinstance(item, MultimodalFileMessage)
    ]

    overflow = len(candidates) - surviving
    # Truncation side decides whether leading or trailing messages survive.
    selected = candidates[overflow:] if self._truncate_top else candidates[:surviving]

    record['modality_inputs'] = [
        (entry.type, self._modality_readers[entry.type].read(entry.content)) for entry in selected
    ]
    return record

def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
chat_records = [self._convert_to_chat(r) for r in records]
Expand Down
Loading

0 comments on commit 90b736f

Please sign in to comment.