From f25e31f53314c3684cd772c86fa1d9b02fffed73 Mon Sep 17 00:00:00 2001
From: "a.khokhulin"
Date: Mon, 11 Nov 2024 15:59:00 +0300
Subject: [PATCH] in progress

---
 turbo_alignment/cherry_picks/chat.py | 41 ++++++++++++++++++++++++-
 turbo_alignment/generators/base.py   | 45 ++++++++++++++++++++--------
 2 files changed, 72 insertions(+), 14 deletions(-)

diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py
index eec477e..ff5c470 100755
--- a/turbo_alignment/cherry_picks/chat.py
+++ b/turbo_alignment/cherry_picks/chat.py
@@ -1,3 +1,4 @@
+import math
 from typing import Iterable
 
 from accelerate import Accelerator
@@ -41,16 +42,47 @@ def _get_dataset_metrics(
             tokenizer=tokenizer,
             transformers_settings=self._generator_transformers_settings,
             custom_generation_settings=self._custom_generation_settings,
-            accelerator=accelerator,
+            # accelerator=accelerator,
             return_logits=True,
         )
 
         batch_size = self._generator_transformers_settings.num_return_sequences
 
+        if accelerator is not None:
+            len_records_batches = len(dataset)
+            world_size = accelerator.num_processes
+            rank_device = accelerator.process_index
+            window_size = math.ceil(len_records_batches / world_size)
+
+            print(f"len records_batches", len_records_batches)
+            print("rank_device", rank_device)
+            print("world_size", world_size)
+            print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]")
+
+
+            dataset = dataset[rank_device * window_size : rank_device * window_size + window_size]
         generations = generator.generate_from_dataset(dataset)
+        print(f"len dataset {len(dataset)}")
+        print(f"first 2 prompts: {[record['prompt'] for record in dataset][:2]}")
+        print(f"len generations {len(generations)}")
+        print(f"first 2 generations: {generations[:2]}")
 
         prompts = [record['prompt'] for record in dataset]
+
+        # len_records_batches = len(prompts)
+        # world_size = accelerator.num_processes
+        # rank_device = accelerator.process_index
+        # window_size = math.ceil(len_records_batches / world_size)
+
+        # prompts =
+
+
         string_answers = [[answer.content for answer in g.answers] for g in generations]
+        print(f"prompts: {len(prompts)}, {prompts[:2]}")
+        print(f"string_answers: {len(string_answers)}, {string_answers[:2]}")
         string_labels = [[g.messages[-1].content] * len(g.answers) for g in generations]
 
         flattened_answers = [answer for g in generations for answer in g.answers]
@@ -66,6 +98,9 @@ def _get_dataset_metrics(
         if sft_model is not None:
             metrics_kwargs[KLType.SFT_MODEL] = get_logits(input_tokens_ids, answer_tokens_ids, sft_model)
 
+        print(f"len prompts {len(prompts)}")
+        print(f"batch_size {batch_size}")
+        print(f"prompt element_wise_scores {len([prompt for prompt in prompts for _ in range(batch_size)])}")
         metric_outputs = [
             MetricResults(
                 element_wise_scores=[
@@ -86,6 +121,10 @@
                 ]
             ),
         ]
+        print(f"prompts: {len(metric_outputs[0].element_wise_scores[0].values)}, {metric_outputs[0].element_wise_scores[0].values}")
+        print(f"labels: {len(metric_outputs[1].element_wise_scores[0].values)}, {metric_outputs[1].element_wise_scores[0].values}")
+        # print(f"metric_outputs: {metric_outputs}")
+
 
         for metric in self._metrics:
             metric_results = metric.compute(
diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 5415daf..e4d94ad 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -71,19 +71,38 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]:
                 )
             )
         else:
-            with self._accelerator.split_between_processes(
-                list(zip(records_batches, original_records_batches)), apply_padding=True
-            ) as accelerator_records:
-                generations = [
-                    self._generate_from_batch(
-                        records_batch,
-                        original_records_batch,
-                        dataset.source.name,
-                    )
-                    for records_batch, original_records_batch in accelerator_records
-                ][: len(records_batches)]
-
-            return sum(generations, [])
+            # with self._accelerator.split_between_processes(
+            #     list(zip(records_batches, original_records_batches)), apply_padding=True
+            # ) as accelerator_records:
+            len_records_batches = len(records_batches)
+            world_size = self._accelerator.num_processes
+            rank_device = self._accelerator.process_index
+            window_size = math.ceil(len_records_batches / world_size)
+
+            print(f"len records_batches", len(records_batches))
+            print("original_records_batches", len(original_records_batches))
+            print("rank_device", rank_device)
+            print("world_size", world_size)
+            print(f"slice [{rank_device * window_size} : {rank_device * window_size + window_size}]")
+
+            slice_records_batches = records_batches[rank_device * window_size : rank_device * window_size + window_size]
+            slice_original_records_batches = original_records_batches[rank_device * window_size : (rank_device + 1) * window_size]
+
+            flattened_records_batches = [r for batch in slice_records_batches for r in batch]
+            flattened_original_records_batches = [r for batch in slice_original_records_batches for r in batch]
+
+            generations = self._generate_from_batch(
+                flattened_records_batches,
+                flattened_original_records_batches,
+                dataset.source.name,
+            )
+            # for records_batch, original_records_batch in accelerator_records
+            # ][: len(records_batches)]
+            # ]
+            # print(generations)
+
+            # return sum(generations, [])
+            return generations
 
 
 class ChatGeneratorBase(BaseGenerator, Generic[DatasetRecordT, InferenceOutputT]):
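
Note on the sharding introduced above: the patch replaces accelerate's
split_between_processes(..., apply_padding=True) with manual contiguous
windowing, where each rank takes the slice
[rank * window : (rank + 1) * window] with window = ceil(N / world_size).
That arithmetic can be sanity-checked in a plain single-process script. The
sketch below is illustrative only: the helper name shard_bounds is
hypothetical and is not part of turbo_alignment, and it assumes base.py has
math imported the way chat.py now does.

    import math

    def shard_bounds(num_items: int, world_size: int, rank: int) -> tuple[int, int]:
        # Same ceil-division window as in the patch. Python slicing clamps
        # out-of-range indices automatically; min() makes that explicit here,
        # so the last rank may receive a short (or even empty) shard.
        window = math.ceil(num_items / world_size)
        start = rank * window
        return start, min(start + window, num_items)

    # 10 record batches across 4 processes -> shards of size 3, 3, 3, 1.
    items = list(range(10))
    for rank in range(4):
        lo, hi = shard_bounds(len(items), 4, rank)
        print(rank, items[lo:hi])

Because apply_padding is gone, shards are no longer equal-sized and the last
rank can even be empty, so anything downstream that gathers or compares
per-rank results has to tolerate uneven counts.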