Commit f25e31f
in progress
a.khokhulin committed Nov 11, 2024
1 parent a66986f commit f25e31f
Showing 2 changed files with 72 additions and 14 deletions.
41 changes: 40 additions & 1 deletion turbo_alignment/cherry_picks/chat.py
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
@@ -41,16 +42,47 @@ def _get_dataset_metrics(
            tokenizer=tokenizer,
            transformers_settings=self._generator_transformers_settings,
            custom_generation_settings=self._custom_generation_settings,
            # accelerator=accelerator,
            return_logits=True,
        )

        batch_size = self._generator_transformers_settings.num_return_sequences

        if accelerator is not None:
            # Manual sharding: each rank takes a contiguous window of
            # ceil(n / world_size) records from the dataset.
            len_records_batches = len(dataset)
            world_size = accelerator.num_processes
            rank_device = accelerator.process_index
            window_size = math.ceil(len_records_batches / world_size)

            print("len records_batches", len_records_batches)
            print("rank_device", rank_device)
            print("world_size", world_size)
            print(f"slice [{rank_device * window_size} : {(rank_device + 1) * window_size}]")

            dataset = dataset[rank_device * window_size : (rank_device + 1) * window_size]

        generations = generator.generate_from_dataset(dataset)

        print(f"len dataset {len(dataset)}")
        print(f"first 2 prompts: {[record['prompt'] for record in dataset][:2]}")
        print(f"len generations {len(generations)}")
        print(f"first 2 generations: {generations[:2]}")

        prompts = [record['prompt'] for record in dataset]

        # len_records_batches = len(prompts)
        # world_size = accelerator.num_processes
        # rank_device = accelerator.process_index
        # window_size = math.ceil(len_records_batches / world_size)

        # prompts =

        string_answers = [[answer.content for answer in g.answers] for g in generations]
        print(f"prompts: {len(prompts)}, {prompts[:2]}")
        print(f"string_answers: {len(string_answers)}, {string_answers[:2]}")
        string_labels = [[g.messages[-1].content] * len(g.answers) for g in generations]

        flattened_answers = [answer for g in generations for answer in g.answers]
@@ -66,6 +98,9 @@ def _get_dataset_metrics(
        if sft_model is not None:
            metrics_kwargs[KLType.SFT_MODEL] = get_logits(input_tokens_ids, answer_tokens_ids, sft_model)

        print(f"len prompts {len(prompts)}")
        print(f"batch_size {batch_size}")
        print(f"prompt element_wise_scores {len([prompt for prompt in prompts for _ in range(batch_size)])}")
        metric_outputs = [
            MetricResults(
                element_wise_scores=[
@@ -86,6 +121,10 @@ def _get_dataset_metrics(
                ]
            ),
        ]
        print(f"prompts: {len(metric_outputs[0].element_wise_scores[0].values)}, {metric_outputs[0].element_wise_scores[0].values}")
        print(f"labels: {len(metric_outputs[1].element_wise_scores[0].values)}, {metric_outputs[1].element_wise_scores[0].values}")
        # print(f"metric_outputs: {metric_outputs}")

        for metric in self._metrics:
            metric_results = metric.compute(
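In chat.py, the cherry-pick evaluation now shards the dataset by hand instead of passing the accelerator into the generator. A minimal, self-contained sketch of the slicing arithmetic; the helper name shard_for_rank is illustrative and not part of the repository:

import math

def shard_for_rank(items: list, rank: int, world_size: int) -> list:
    # Each rank takes a contiguous window of ceil(n / world_size) items;
    # the last rank's window may be shorter (or empty) when n is not
    # divisible by world_size.
    window = math.ceil(len(items) / world_size)
    return items[rank * window : (rank + 1) * window]

# Example: 10 records across 4 processes -> windows of 3, 3, 3, 1.
assert shard_for_rank(list(range(10)), 0, 4) == [0, 1, 2]
assert shard_for_rank(list(range(10)), 3, 4) == [9]

Unlike split_between_processes(..., apply_padding=True), this scheme never pads, so different ranks may receive different numbers of records.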
45 changes: 32 additions & 13 deletions turbo_alignment/generators/base.py
@@ -71,19 +71,38 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]:
                )
            )
        else:
            with self._accelerator.split_between_processes(
                list(zip(records_batches, original_records_batches)), apply_padding=True
            ) as accelerator_records:
                generations = [
                    self._generate_from_batch(
                        records_batch,
                        original_records_batch,
                        dataset.source.name,
                    )
                    for records_batch, original_records_batch in accelerator_records
                ][: len(records_batches)]

        return sum(generations, [])
            # with self._accelerator.split_between_processes(
            #     list(zip(records_batches, original_records_batches)), apply_padding=True
            # ) as accelerator_records:
            # Manual contiguous sharding instead of split_between_processes:
            # each rank handles ceil(n / world_size) record batches.
            len_records_batches = len(records_batches)
            world_size = self._accelerator.num_processes
            rank_device = self._accelerator.process_index
            window_size = math.ceil(len_records_batches / world_size)

            print("len records_batches", len(records_batches))
            print("original_records_batches", len(original_records_batches))
            print("rank_device", rank_device)
            print("world_size", world_size)
            print(f"slice [{rank_device * window_size} : {(rank_device + 1) * window_size}]")

            slice_records_batches = records_batches[rank_device * window_size : (rank_device + 1) * window_size]
            slice_original_records_batches = original_records_batches[rank_device * window_size : (rank_device + 1) * window_size]

            flattened_records_batches = [r for batch in slice_records_batches for r in batch]
            flattened_original_records_batches = [r for batch in slice_original_records_batches for r in batch]

            generations = self._generate_from_batch(
                flattened_records_batches,
                flattened_original_records_batches,
                dataset.source.name,
            )

        # return sum(generations, [])
        return generations


class ChatGeneratorBase(BaseGenerator, Generic[DatasetRecordT, InferenceOutputT]):
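In base.py, the new path returns only the current rank's generations: the padding trim ([: len(records_batches)]) and the sum(generations, []) flattening are gone, so results are no longer replicated across processes. A sketch of how a caller could collect the per-rank lists with accelerate's gather_object utility; the commit itself does not gather, so treat this as an assumption about downstream usage:

from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()

# Stand-in for the per-rank output of generate_from_dataset().
local_generations = [f"rank{accelerator.process_index}-answer{i}" for i in range(2)]

# gather_object concatenates the per-rank lists in rank order, so every
# process ends up with the full, unpadded list of generations.
all_generations = gather_object(local_generations)
if accelerator.is_main_process:
    print(len(all_generations))  # == 2 * accelerator.num_processes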
