diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index eec477e..0c40434 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -1,4 +1,5 @@ from typing import Iterable +import torch from accelerate import Accelerator from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -36,18 +37,23 @@ def _get_dataset_metrics( ref_model: dict = kwargs.get('ref_model', None) sft_model: dict = kwargs.get('sft_model', None) - generator = ChatGenerator( - model=model, - tokenizer=tokenizer, - transformers_settings=self._generator_transformers_settings, - custom_generation_settings=self._custom_generation_settings, - accelerator=accelerator, - return_logits=True, - ) + with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(model): + # print('👀'*5, 'model_gathered') + generator = ChatGenerator( + model=model.module, + tokenizer=tokenizer, + transformers_settings=self._generator_transformers_settings, + custom_generation_settings=self._custom_generation_settings, + accelerator=accelerator, + return_logits=True, + ) - batch_size = self._generator_transformers_settings.num_return_sequences + batch_size = self._generator_transformers_settings.num_return_sequences - generations = generator.generate_from_dataset(dataset) + # print('🐷'*5, 'before generation started') + generations = generator.generate_from_dataset(dataset) + # print('🐙'*5, 'after generation ended') + # torch.distributed.barrier() prompts = [record['prompt'] for record in dataset] string_answers = [[answer.content for answer in g.answers] for g in generations] diff --git a/turbo_alignment/common/tf/callbacks/common.py b/turbo_alignment/common/tf/callbacks/common.py index 108e251..b249a27 100755 --- a/turbo_alignment/common/tf/callbacks/common.py +++ b/turbo_alignment/common/tf/callbacks/common.py @@ -89,5 +89,4 @@ def on_evaluate( } self.call_event('on_log', args, state, control, logs=table) - return control diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 5415daf..deb7b86 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -24,10 +24,10 @@ def __init__( accelerator: Accelerator | None = None, batch: int = 1, ) -> None: - if accelerator is not None: - self._model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True) - else: - self._model = model + # if accelerator is not None: + # self._model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True) + # else: + self._model = model self._tokenizer = tokenizer self._accelerator = accelerator diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index acc4337..f995e35 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -80,13 +80,16 @@ def _generate_from_single_record( input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device) attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device) + # print('🐌'*5, 'generating!!!') output_indices = self._model.generate( inputs=input_ids, attention_mask=attention_mask, generation_config=self._transformers_generator_parameters, tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, + synced_gpus=True, ) + # print('🚀'*5, 'generating ended!!!') postprocessed_output_indices = self._postprocess( input_indices=input_ids,