Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: 🐌 Add FSDP support #51

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions turbo_alignment/cherry_picks/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable
import torch

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase
Expand Down Expand Up @@ -36,18 +37,23 @@ def _get_dataset_metrics(
ref_model: dict = kwargs.get('ref_model', None)
sft_model: dict = kwargs.get('sft_model', None)

generator = ChatGenerator(
model=model,
tokenizer=tokenizer,
transformers_settings=self._generator_transformers_settings,
custom_generation_settings=self._custom_generation_settings,
accelerator=accelerator,
return_logits=True,
)
with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(model):
# print('👀'*5, 'model_gathered')
generator = ChatGenerator(
model=model.module,
tokenizer=tokenizer,
transformers_settings=self._generator_transformers_settings,
custom_generation_settings=self._custom_generation_settings,
accelerator=accelerator,
return_logits=True,
)

batch_size = self._generator_transformers_settings.num_return_sequences
batch_size = self._generator_transformers_settings.num_return_sequences

generations = generator.generate_from_dataset(dataset)
# print('🐷'*5, 'before generation started')
generations = generator.generate_from_dataset(dataset)
# print('🐙'*5, 'after generation ended')
# torch.distributed.barrier()

prompts = [record['prompt'] for record in dataset]
string_answers = [[answer.content for answer in g.answers] for g in generations]
Expand Down
1 change: 0 additions & 1 deletion turbo_alignment/common/tf/callbacks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,4 @@ def on_evaluate(
}

self.call_event('on_log', args, state, control, logs=table)

return control
8 changes: 4 additions & 4 deletions turbo_alignment/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def __init__(
accelerator: Accelerator | None = None,
batch: int = 1,
) -> None:
if accelerator is not None:
self._model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
else:
self._model = model
# if accelerator is not None:
# self._model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
# else:
self._model = model

self._tokenizer = tokenizer
self._accelerator = accelerator
Expand Down
3 changes: 3 additions & 0 deletions turbo_alignment/generators/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,16 @@ def _generate_from_single_record(
input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device)
attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device)

# print('🐌'*5, 'generating!!!')
output_indices = self._model.generate(
inputs=input_ids,
attention_mask=attention_mask,
generation_config=self._transformers_generator_parameters,
tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
synced_gpus=True,
)
# print('🚀'*5, 'generating ended!!!')

postprocessed_output_indices = self._postprocess(
input_indices=input_ids,
Expand Down
Loading