
Commit: refactor

a.khokhulin committed Nov 25, 2024
1 parent 4d84b45 commit d8d811c
Showing 9 changed files with 63 additions and 30 deletions.
7 changes: 2 additions & 5 deletions turbo_alignment/cherry_picks/base.py
@@ -1,5 +1,4 @@
from abc import abstractmethod
import math
from typing import Generic, Iterable, TypeVar

from accelerate import Accelerator
@@ -76,8 +75,6 @@ def on_evaluate(
return dataset_metrics

@staticmethod
@abstractmethod
def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT:
rank_device = accelerator.process_index
slice_size = math.ceil(len(dataset) / accelerator.num_processes)

return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
...
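
For readers skimming the diff: `_get_sharded_dataset` is now an abstract static method on the base callback, and each concrete callback below (chat, classification, RAG, RM) supplies the same implementation. A minimal, self-contained sketch of that shape, with illustrative names rather than the project's classes; the real methods take an `accelerate.Accelerator` and read `process_index` / `num_processes` from it, while this sketch passes those numbers directly to stay dependency-free:

```python
import math
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

DatasetT = TypeVar("DatasetT")  # stand-in for the repo's InferenceDatasetT


class ShardedCallbackSketch(ABC, Generic[DatasetT]):
    """Generic callback base that defers per-rank sharding to subclasses."""

    @staticmethod
    @abstractmethod  # abstractmethod is applied as the innermost decorator (Python 3.3+)
    def _get_sharded_dataset(dataset: DatasetT, num_processes: int, rank: int) -> DatasetT:
        """Return the contiguous slice of `dataset` this process should evaluate."""


class ListCallbackSketch(ShardedCallbackSketch[list]):
    @staticmethod
    def _get_sharded_dataset(dataset: list, num_processes: int, rank: int) -> list:
        slice_size = math.ceil(len(dataset) / num_processes)
        return dataset[rank * slice_size : rank * slice_size + slice_size]
```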
8 changes: 8 additions & 0 deletions turbo_alignment/cherry_picks/chat.py
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
@@ -109,3 +110,10 @@ def _get_dataset_metrics(

metric_outputs.extend(metric_results)
return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset:
rank_device = accelerator.process_index
slice_size = math.ceil(len(dataset) / accelerator.num_processes)

return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
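
Each concrete callback uses the same ceil-division arithmetic, so every rank gets a contiguous slice and the last rank absorbs whatever remains (possibly fewer records, or none, when processes outnumber records). A quick standalone illustration of the bounds it produces:

```python
import math


def shard_bounds(dataset_len: int, num_processes: int, rank: int) -> tuple[int, int]:
    """Start/end of the slice that a given rank would request via get_slice."""
    slice_size = math.ceil(dataset_len / num_processes)
    return rank * slice_size, rank * slice_size + slice_size


# 10 records across 4 processes -> slice_size == 3
# rank 0: (0, 3), rank 1: (3, 6), rank 2: (6, 9), rank 3: (9, 12) -> one record left
for rank in range(4):
    print(rank, shard_bounds(dataset_len=10, num_processes=4, rank=rank))
```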
20 changes: 16 additions & 4 deletions turbo_alignment/cherry_picks/classification.py
@@ -1,29 +1,32 @@
import math
from typing import Iterable

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.chat.conversation import Conversation
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
InferenceClassificationDataset,
)
from turbo_alignment.generators.classification import ClassificationGenerator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings
from turbo_alignment.settings.metric import ElementWiseScores, MetricResults


class ClassificationCherryPickCallback(CherryPickCallbackBase[ClassificationDataset]):
class ClassificationCherryPickCallback(CherryPickCallbackBase[InferenceClassificationDataset]):
def __init__(
self,
cherry_pick_settings: ClassificationCherryPickSettings,
datasets: Iterable[ClassificationDataset],
datasets: Iterable[InferenceClassificationDataset],
metrics: list[Metric],
) -> None:
super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics)

def _get_dataset_metrics(
self,
dataset: ClassificationDataset,
dataset: InferenceClassificationDataset,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
@@ -69,3 +72,12 @@ def _get_dataset_metrics(
]

return metric_outputs

@staticmethod
def _get_sharded_dataset(
dataset: InferenceClassificationDataset, accelerator: Accelerator
) -> InferenceClassificationDataset:
rank_device = accelerator.process_index
slice_size = math.ceil(len(dataset) / accelerator.num_processes)

return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
2 changes: 2 additions & 0 deletions turbo_alignment/cherry_picks/multimodal.py
@@ -14,6 +14,8 @@


class MultimodalCherryPickCallback(CherryPickCallbackBase[InferenceMultimodalDataset]):
# pylint: disable=abstract-method
# TODO: add _get_sharded_dataset method
def __init__(
self,
cherry_pick_settings: MultimodalCherryPickSettings,
9 changes: 9 additions & 0 deletions turbo_alignment/cherry_picks/rag.py
@@ -1,3 +1,5 @@
import math

from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizerBase

@@ -62,3 +64,10 @@ def _get_dataset_metrics(
metric_outputs.extend(metric_results)

return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset, accelerator: Accelerator) -> InferenceChatDataset:
rank_device = accelerator.process_index
slice_size = math.ceil(len(dataset) / accelerator.num_processes)

return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
8 changes: 8 additions & 0 deletions turbo_alignment/cherry_picks/rm.py
@@ -1,3 +1,4 @@
import math
from typing import Iterable

from accelerate import Accelerator
@@ -75,3 +76,10 @@ def _get_dataset_metrics(
]

return metric_outputs

@staticmethod
def _get_sharded_dataset(dataset: PairPreferenceDataset, accelerator: Accelerator) -> PairPreferenceDataset:
rank_device = accelerator.process_index
slice_size = math.ceil(len(dataset) / accelerator.num_processes)

return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size)
5 changes: 0 additions & 5 deletions turbo_alignment/dataset/base/base.py
@@ -6,7 +6,6 @@
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase
from typing_extensions import Self

from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.dataset.base.models import DatasetRecord
@@ -103,10 +102,6 @@ def _read_records(records: list[dict]) -> list[RecordT]:
def _read_records(records):
...

@abstractmethod
def get_slice(self, start: int, end: int) -> Self:
...


class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]):
def __init__(
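
With the abstract `get_slice` removed from `BaseDataset` (along with the now-unused `Self` import), the slicing contract lives only on the inference datasets that actually get sharded. A toy sketch of what such a method usually looks like, assuming records are held in a plain list; the repository's real implementation (see `InferenceClassificationDataset.get_slice` below) rebuilds the instance with more state than this:

```python
from typing import Any


class SliceableDatasetSketch:
    """Toy dataset used only to illustrate the get_slice contract."""

    def __init__(self, records: list[dict[str, Any]]) -> None:
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def get_slice(self, start: int, end: int) -> "SliceableDatasetSketch":
        # A fresh instance of the same class holding records[start:end];
        # list slicing clamps out-of-range bounds, so the last rank's
        # oversized end index is harmless.
        return self.__class__(self.records[start:end])
```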
24 changes: 12 additions & 12 deletions turbo_alignment/dataset/classification/classification.py
@@ -82,6 +82,18 @@ def _read_records(records) -> list[ClassificationDatasetRecord]:
return [ClassificationDatasetRecord(**record) for record in records]
raise NotImplementedError


@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN)
class TrainClassificationDataset(ClassificationDataset):
def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=False)


@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
class InferenceClassificationDataset(ClassificationDataset):
def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=True)

def get_slice(self, start: int, end: int) -> Self:
new_instance = self.__class__(
source=self.source,
@@ -97,15 +109,3 @@ def get_slice(self, start: int, end: int) -> Self:
}

return new_instance


@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN)
class TrainClassificationDataset(ClassificationDataset):
def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=False)


@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
class InferenceClassificationDataset(ClassificationDataset):
def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]:
return self._encode(records, inference=True)
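
The TRAIN/INFERENCE split above leans on a decorator-based registry keyed by `DatasetStrategy`. The sketch below is an assumption about how such a registry typically works, not the actual `ClassificationDatasetTypeRegistry`:

```python
from enum import Enum
from typing import Callable


class StrategySketch(str, Enum):
    TRAIN = "train"
    INFERENCE = "inference"


class DatasetRegistrySketch:
    """Maps a strategy to the dataset class registered under it."""

    _registry: dict[StrategySketch, type] = {}

    @classmethod
    def register(cls, strategy: StrategySketch) -> Callable[[type], type]:
        def decorator(dataset_cls: type) -> type:
            cls._registry[strategy] = dataset_cls
            return dataset_cls

        return decorator

    @classmethod
    def by_strategy(cls, strategy: StrategySketch) -> type:
        return cls._registry[strategy]


@DatasetRegistrySketch.register(StrategySketch.TRAIN)
class TrainSketch:
    ...


@DatasetRegistrySketch.register(StrategySketch.INFERENCE)
class InferenceSketch:
    ...


assert DatasetRegistrySketch.by_strategy(StrategySketch.INFERENCE) is InferenceSketch
```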
10 changes: 6 additions & 4 deletions turbo_alignment/pipelines/train/classification.py
@@ -7,7 +7,9 @@
from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.classification.classification import ClassificationDataset
from turbo_alignment.dataset.classification.classification import (
InferenceClassificationDataset,
)
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
@@ -47,9 +49,9 @@ def _get_cherry_pick_callback(
) -> ClassificationCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[ClassificationDataset](ClassificationDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
)
cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset](
InferenceClassificationDataset
).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE)

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
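
Taken together with the cherry-pick change earlier in the commit, the pipeline now wires things up roughly like this. This is a condensed reading of the diff, not the file itself: `experiment_settings` and `tokenizer` come from the surrounding pipeline, and the `DatasetStrategy` import path and the `metric_settings` attribute are assumptions.

```python
from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback
from turbo_alignment.dataset.classification.classification import (
    InferenceClassificationDataset,
)
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
from turbo_alignment.settings.datasets import DatasetStrategy  # import path assumed


def build_cherry_pick_callback(experiment_settings, tokenizer) -> ClassificationCherryPickCallback:
    cherry_pick_settings = experiment_settings.cherry_pick_settings

    # Load inference-strategy datasets with the loader parameterized by the
    # concrete inference dataset class (the change in this file).
    cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset](
        InferenceClassificationDataset
    ).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE)

    # Instantiate metrics from their registered settings, as in the diff.
    metrics = [
        Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
        for metric in cherry_pick_settings.metric_settings  # attribute name assumed
    ]

    return ClassificationCherryPickCallback(
        cherry_pick_settings=cherry_pick_settings,
        datasets=cherry_pick_datasets,
        metrics=metrics,
    )
```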
