Commit

Merge pull request #61 from turbo-llm/fix_clearml_cherrypicks
🍒 Fix cherrypicks
white-r0se authored Dec 13, 2024
2 parents 009574b + 0ba6c6a commit 82c1b03
Showing 10 changed files with 134 additions and 11 deletions.
6 changes: 6 additions & 0 deletions turbo_alignment/cherry_picks/base.py
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Generic, Iterable, TypeVar

from accelerate import Accelerator
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
@@ -72,3 +73,8 @@ def on_evaluate(
        model.train()

        return dataset_metrics

    @staticmethod
    @abstractmethod
    def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT:
        ...
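
For orientation, here is a minimal self-contained sketch of the contract this new abstract hook establishes: every concrete cherry-pick callback must know how to hand each process its own shard of the inference dataset. FakeAccelerator and ToyCallback below are illustrative stand-ins, not turbo_alignment classes, and only mimic the two Accelerator attributes the hook relies on.

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

InferenceDatasetT = TypeVar('InferenceDatasetT')


@dataclass
class FakeAccelerator:
    # Stand-in exposing only the attributes _get_sharded_dataset needs.
    process_index: int
    num_processes: int


class CherryPickCallbackSketch(ABC, Generic[InferenceDatasetT]):
    @staticmethod
    @abstractmethod
    def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: FakeAccelerator) -> InferenceDatasetT:
        ...


class ToyCallback(CherryPickCallbackSketch[list[int]]):
    @staticmethod
    def _get_sharded_dataset(dataset: list[int], accelerator: FakeAccelerator) -> list[int]:
        # Same contiguous ceil-sized slicing as the concrete callbacks below.
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)
        start = accelerator.process_index * slice_size
        return dataset[start:start + slice_size]


print(ToyCallback._get_sharded_dataset(list(range(10)), FakeAccelerator(process_index=1, num_processes=4)))
# -> [3, 4, 5]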
15 changes: 14 additions & 1 deletion turbo_alignment/cherry_picks/chat.py
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
@@ -41,12 +42,17 @@ def _get_dataset_metrics(
            tokenizer=tokenizer,
            transformers_settings=self._generator_transformers_settings,
            custom_generation_settings=self._custom_generation_settings,
            accelerator=accelerator,
            return_logits=True,
        )

        batch_size = self._generator_transformers_settings.num_return_sequences

        if accelerator is not None:
            dataset = self._get_sharded_dataset(
                dataset=dataset,
                accelerator=accelerator,
            )

        generations = generator.generate_from_dataset(dataset)

        prompts = [record['prompt'] for record in dataset]
@@ -104,3 +110,10 @@ def _get_dataset_metrics(

        metric_outputs.extend(metric_results)
        return metric_outputs

    @staticmethod
    def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
        rank_device = accelerator.process_index
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)

        return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
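
For reference, the per-rank arithmetic used by _get_sharded_dataset above, shown stand-alone; the 10-record, 4-process numbers are made up for illustration.

import math


def shard_bounds(dataset_len: int, num_processes: int, process_index: int) -> tuple[int, int]:
    # Contiguous, ceil-sized slices: earlier ranks may get one extra record, and the
    # last rank's upper bound can run past the end (get_slice/list slicing clamps it).
    slice_size = math.ceil(dataset_len / num_processes)
    start = process_index * slice_size
    return start, start + slice_size


for rank in range(4):
    print(rank, shard_bounds(dataset_len=10, num_processes=4, process_index=rank))
# 0 (0, 3)
# 1 (3, 6)
# 2 (6, 9)
# 3 (9, 12)  <- slicing yields just the one remaining record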
27 changes: 22 additions & 5 deletions turbo_alignment/cherry_picks/classification.py
@@ -1,29 +1,32 @@
import math
from typing import Iterable

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.chat.conversation import Conversation
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
    InferenceClassificationDataset,
)
from turbo_alignment.generators.classification import ClassificationGenerator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings
from turbo_alignment.settings.metric import ElementWiseScores, MetricResults


class ClassificationCherryPickCallback(CherryPickCallbackBase[ClassificationDataset]):
class ClassificationCherryPickCallback(CherryPickCallbackBase[InferenceClassificationDataset]):
    def __init__(
        self,
        cherry_pick_settings: ClassificationCherryPickSettings,
        datasets: Iterable[ClassificationDataset],
        datasets: Iterable[InferenceClassificationDataset],
        metrics: list[Metric],
    ) -> None:
        super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics)

    def _get_dataset_metrics(
        self,
        dataset: ClassificationDataset,
        dataset: InferenceClassificationDataset,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        **kwargs,
@@ -33,9 +36,14 @@ def _get_dataset_metrics(
        generator = ClassificationGenerator(
            model=model,
            tokenizer=tokenizer,
            accelerator=accelerator,
        )

        if accelerator is not None:
            dataset = self._get_sharded_dataset(
                dataset=dataset,
                accelerator=accelerator,
            )

        generations = generator.generate_from_dataset(dataset)
        predictions = [record.predicted_label for record in generations]
        labels = [record['labels'] for record in dataset]
@@ -64,3 +72,12 @@ def _get_dataset_metrics(
        ]

        return metric_outputs

    @staticmethod
    def _get_sharded_dataset(
        dataset: InferenceClassificationDataset, accelerator: Accelerator
    ) -> InferenceClassificationDataset:
        rank_device = accelerator.process_index
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)

        return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
2 changes: 2 additions & 0 deletions turbo_alignment/cherry_picks/multimodal.py
@@ -14,6 +14,8 @@


class MultimodalCherryPickCallback(CherryPickCallbackBase[InferenceMultimodalDataset]):
    # pylint: disable=abstract-method
    # TODO: add _get_sharded_dataset method
    def __init__(
        self,
        cherry_pick_settings: MultimodalCherryPickSettings,
16 changes: 15 additions & 1 deletion turbo_alignment/cherry_picks/rag.py
@@ -1,3 +1,5 @@
import math

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

@@ -22,9 +24,14 @@ def _get_dataset_metrics(
            tokenizer=tokenizer,
            transformers_settings=self._generator_transformers_settings,
            custom_generation_settings=self._custom_generation_settings,
            accelerator=accelerator,
        )

        if accelerator is not None:
            dataset = self._get_sharded_dataset(
                dataset=dataset,
                accelerator=accelerator,
            )

        generations = generator.generate_from_dataset(dataset)

        prompts = [dataset[i]['prompt'] for i in range(len(dataset))]
@@ -57,3 +64,10 @@ def _get_dataset_metrics(
        metric_outputs.extend(metric_results)

        return metric_outputs

    @staticmethod
    def _get_sharded_dataset(dataset, accelerator: Accelerator) -> InferenceChatDataset:
        rank_device = accelerator.process_index
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)

        return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
14 changes: 14 additions & 0 deletions turbo_alignment/cherry_picks/rm.py
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
@@ -35,6 +36,12 @@ def _get_dataset_metrics(
            accelerator=accelerator,
        )

        if accelerator is not None:
            dataset = self._get_sharded_dataset(
                dataset=dataset,
                accelerator=accelerator,
            )

        generations = generator.generate_from_dataset(dataset)
        generations_w = [gen.reward_w for gen in generations]
        generations_l = [gen.reward_l for gen in generations]
@@ -69,3 +76,10 @@ def _get_dataset_metrics(
        ]

        return metric_outputs

    @staticmethod
    def _get_sharded_dataset(dataset: PairPreferenceDataset, accelerator: Accelerator) -> PairPreferenceDataset:
        rank_device = accelerator.process_index
        slice_size = math.ceil(len(dataset) / accelerator.num_processes)

        return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
19 changes: 19 additions & 0 deletions turbo_alignment/dataset/chat/chat.py
@@ -8,6 +8,7 @@
import numpy.typing as npt
import torch
from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
@@ -361,3 +362,21 @@ def __init__(

    def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]:
        return self._encode(records, inference=True, random_cut=self._random_cut)

    def get_slice(self, start: int, end: int) -> Self:
        new_instance = self.__class__(
            source=self.source,
            settings=self.settings,
            tokenizer=self.tokenizer,
            read=False,
            random_cut=self._random_cut,
        )

        dataset_records = [self[idx] for idx in range(len(self))]

        new_instance.records = self.records[start:end]
        new_instance.original_records_map = {
            record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
        }

        return new_instance
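
A hedged stand-alone sketch of the slicing pattern these get_slice implementations follow; ToySliceableDataset and its fields are hypothetical stand-ins, not the turbo_alignment dataset classes.

from typing import Any

from typing_extensions import Self


class ToySliceableDataset:
    # Mirrors the shape used in this commit: encoded records plus an id -> original-record map.
    def __init__(self, records: list[dict[str, Any]], originals: dict[str, dict[str, Any]]) -> None:
        self.records = records
        self.original_records_map = originals

    def __len__(self) -> int:
        return len(self.records)

    def get_slice(self, start: int, end: int) -> Self:
        # The new instance reuses already-encoded records (no re-reading of the source)
        # and keeps the lookup map so original records stay reachable by id.
        return self.__class__(self.records[start:end], dict(self.original_records_map))


data = ToySliceableDataset(
    records=[{'id': str(i), 'value': i} for i in range(5)],
    originals={str(i): {'id': str(i)} for i in range(5)},
)
print(len(data.get_slice(0, 2)))  # -> 2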
17 changes: 17 additions & 0 deletions turbo_alignment/dataset/classification/classification.py
@@ -3,6 +3,7 @@
from typing import Any, overload

from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
@@ -92,3 +93,19 @@ def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
class InferenceClassificationDataset(ClassificationDataset):
    def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
        return self._encode(records, inference=True)

    def get_slice(self, start: int, end: int) -> Self:
        new_instance = self.__class__(
            source=self.source,
            settings=self.settings,
            tokenizer=self.tokenizer,
        )

        dataset_records = [self[idx] for idx in range(len(self))]

        new_instance.records = self.records[start:end]
        new_instance.original_records_map = {
            record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
        }

        return new_instance
19 changes: 19 additions & 0 deletions turbo_alignment/dataset/pair_preferences/pair_preference.py
@@ -2,6 +2,7 @@
from typing import Any, overload

from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.logging import get_project_logger
@@ -42,6 +43,7 @@ def __init__(
            read=False,
        )
        super().__init__(source=source, settings=settings, tokenizer=tokenizer)
        self.settings: PairPreferenceDatasetSettings = settings

        if read:
            self._read()
@@ -107,3 +109,20 @@ def _read_records(records) -> list[PairPreferenceRecord]:
        if isinstance(records, list):
            return [PairPreferenceRecord(**record) for record in records]
        raise NotImplementedError

    def get_slice(self, start: int, end: int) -> Self:
        new_instance = self.__class__(
            source=self.source,
            settings=self.settings,
            tokenizer=self.tokenizer,
            read=False,
        )

        dataset_records = [self[idx] for idx in range(len(self))]

        new_instance.records = self.records[start:end]
        new_instance.original_records_map = {
            record['id']: self.get_original_record_by_id(record['id']) for record in dataset_records
        }

        return new_instance
10 changes: 6 additions & 4 deletions turbo_alignment/pipelines/train/classification.py
@@ -7,7 +7,9 @@
from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
    InferenceClassificationDataset,
)
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
@@ -47,9 +49,9 @@ def _get_cherry_pick_callback(
    ) -> ClassificationCherryPickCallback:
        cherry_pick_settings = experiment_settings.cherry_pick_settings

        cherry_pick_datasets = DatasetLoader[ClassificationDataset](ClassificationDataset).load_datasets(
            cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
        )
        cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset](
            InferenceClassificationDataset
        ).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE)

        metrics = [
            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
