Skip to content

Commit

Permalink
fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
Elisei Rykov committed Sep 29, 2024
1 parent d275596 commit 7846fbd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 5 additions & 3 deletions turbo_alignment/dataset/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def __get_modality_message(self, msg) -> str:
def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord:
converted_messages: list[ChatMessage] = []
for msg in record.messages:
print(msg)
print(hasattr(msg, 'modality_object_path'))

if msg.modality_object_path: # if there is a path to modality object
converted_messages.append(
ChatMessage(
Expand Down Expand Up @@ -192,11 +189,13 @@ def _read_modalities(self, record):
modality_encodings.append((Modality.IMAGE, reader.read(msg.modality_object_path)))
# modality_encodings.append((msg.type, reader.read(msg.content)))
except (OSError, RuntimeError, KeyError) as E:
# print("😆"*10, E)
return None

# record['modality_inputs'] = modality_encodings

if len(modality_encodings) != modality_messages_after_truncation:
# print("😆"*10, "len(modality_encodings) != modality_messages_after_truncation")
return None

return modality_encodings
Expand All @@ -213,8 +212,11 @@ def __iter__(self):
end = min(start + per_worker, end)
for i, sample in enumerate(self.records[start:end]):
output = self._read_modalities(sample)
# print("😇"*10, output)
if output:
yield sample | {'modality_inputs': output}
# else:
# yield sample


@MultimodalDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
Expand Down
5 changes: 3 additions & 2 deletions turbo_alignment/modeling/multimodal/lm/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def __init__(self, *args, **kwargs) -> None:

self.config = self.language_model.config

@staticmethod
def get_replica_spans(input_ids, start_token_id=32000, end_token_id=32001):
def get_replica_spans(self, input_ids, start_token='<|start_header_id|>', end_token='<|eot_id|>'):
start_token_id = self.tokenizer.encode(start_token, add_special_tokens=False)[0]
end_token_id = self.tokenizer.encode(end_token, add_special_tokens=False)[0]
spans = []
inside_replica = False
start_idx = 0
Expand Down
3 changes: 0 additions & 3 deletions turbo_alignment/pipelines/train/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ def _load_model(
) -> torch.nn.Module | PreTrainedModel:
language_model = load_model(experiment_settings.model_settings, tokenizer)

print(tokenizer)
exit()

modality_encoders = TrainMultimodalStrategy._load_modality_encoders(
experiment_settings.modality_encoder_settings_mapping,
device=language_model.device,
Expand Down

0 comments on commit 7846fbd

Please sign in to comment.