diff --git a/turbo_alignment/cherry_picks/base.py b/turbo_alignment/cherry_picks/base.py index 995dd62..93b0ffd 100755 --- a/turbo_alignment/cherry_picks/base.py +++ b/turbo_alignment/cherry_picks/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod -import math from typing import Generic, Iterable, TypeVar from accelerate import Accelerator @@ -76,8 +75,6 @@ def on_evaluate( return dataset_metrics @staticmethod + @abstractmethod def _get_sharded_dataset(dataset: InferenceDatasetT, accelerator: Accelerator) -> InferenceDatasetT: - rank_device = accelerator.process_index - slice_size = math.ceil(len(dataset) / accelerator.num_processes) - - return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) + ... diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index c6f4631..3229940 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -109,3 +110,10 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/classification.py b/turbo_alignment/cherry_picks/classification.py index cb2e0b1..e811b0e 100755 --- a/turbo_alignment/cherry_picks/classification.py +++ b/turbo_alignment/cherry_picks/classification.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -5,25 +6,27 @@ from turbo_alignment.cherry_picks.base import CherryPickCallbackBase from turbo_alignment.dataset.chat.conversation import Conversation -from 
turbo_alignment.dataset.classification.classification import ClassificationDataset +from turbo_alignment.dataset.classification.classification import ( + InferenceClassificationDataset, +) from turbo_alignment.generators.classification import ClassificationGenerator from turbo_alignment.metrics.metric import Metric from turbo_alignment.settings.cherry_pick import ClassificationCherryPickSettings from turbo_alignment.settings.metric import ElementWiseScores, MetricResults -class ClassificationCherryPickCallback(CherryPickCallbackBase[ClassificationDataset]): +class ClassificationCherryPickCallback(CherryPickCallbackBase[InferenceClassificationDataset]): def __init__( self, cherry_pick_settings: ClassificationCherryPickSettings, - datasets: Iterable[ClassificationDataset], + datasets: Iterable[InferenceClassificationDataset], metrics: list[Metric], ) -> None: super().__init__(cherry_pick_settings=cherry_pick_settings, datasets=datasets, metrics=metrics) def _get_dataset_metrics( self, - dataset: ClassificationDataset, + dataset: InferenceClassificationDataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, **kwargs, @@ -69,3 +72,12 @@ def _get_dataset_metrics( ] return metric_outputs + + @staticmethod + def _get_sharded_dataset( + dataset: InferenceClassificationDataset, accelerator: Accelerator + ) -> InferenceClassificationDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py index 80c6022..a4f94bd 100755 --- a/turbo_alignment/cherry_picks/multimodal.py +++ b/turbo_alignment/cherry_picks/multimodal.py @@ -14,6 +14,8 @@ class MultimodalCherryPickCallback(CherryPickCallbackBase[InferenceMultimodalDataset]): + # pylint: disable=abstract-method + # TODO: add _get_sharded_dataset method def __init__( 
self, cherry_pick_settings: MultimodalCherryPickSettings, diff --git a/turbo_alignment/cherry_picks/rag.py b/turbo_alignment/cherry_picks/rag.py index 53d332e..2c0cd27 100755 --- a/turbo_alignment/cherry_picks/rag.py +++ b/turbo_alignment/cherry_picks/rag.py @@ -1,3 +1,5 @@ +import math + from accelerate import Accelerator from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -62,3 +64,10 @@ def _get_dataset_metrics( metric_outputs.extend(metric_results) return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset: InferenceChatDataset, accelerator: Accelerator) -> InferenceChatDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py index 829f221..11e79f2 100755 --- a/turbo_alignment/cherry_picks/rm.py +++ b/turbo_alignment/cherry_picks/rm.py @@ -1,3 +1,4 @@ +import math from typing import Iterable from accelerate import Accelerator @@ -75,3 +76,10 @@ def _get_dataset_metrics( ] return metric_outputs + + @staticmethod + def _get_sharded_dataset(dataset: PairPreferenceDataset, accelerator: Accelerator) -> PairPreferenceDataset: + rank_device = accelerator.process_index + slice_size = math.ceil(len(dataset) / accelerator.num_processes) + + return dataset.get_slice(rank_device * slice_size, rank_device * slice_size + slice_size) diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py index aac204f..71e353e 100755 --- a/turbo_alignment/dataset/base/base.py +++ b/turbo_alignment/dataset/base/base.py @@ -6,7 +6,6 @@ import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizerBase -from typing_extensions import Self from turbo_alignment.common.logging import get_project_logger from turbo_alignment.dataset.base.models import DatasetRecord @@ -103,10 +102,6 @@ def 
_read_records(records: list[dict]) -> list[RecordT]: def _read_records(records): ... - @abstractmethod - def get_slice(self, start: int, end: int) -> Self: - ... - class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]): def __init__( diff --git a/turbo_alignment/dataset/classification/classification.py b/turbo_alignment/dataset/classification/classification.py index 030bd69..77fac16 100755 --- a/turbo_alignment/dataset/classification/classification.py +++ b/turbo_alignment/dataset/classification/classification.py @@ -82,6 +82,18 @@ def _read_records(records) -> list[ClassificationDatasetRecord]: return [ClassificationDatasetRecord(**record) for record in records] raise NotImplementedError + +@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN) +class TrainClassificationDataset(ClassificationDataset): + def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: + return self._encode(records, inference=False) + + +@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE) +class InferenceClassificationDataset(ClassificationDataset): + def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: + return self._encode(records, inference=True) + def get_slice(self, start: int, end: int) -> Self: new_instance = self.__class__( source=self.source, @@ -97,15 +109,3 @@ def get_slice(self, start: int, end: int) -> Self: } return new_instance - - -@ClassificationDatasetTypeRegistry.register(DatasetStrategy.TRAIN) -class TrainClassificationDataset(ClassificationDataset): - def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: - return self._encode(records, inference=False) - - -@ClassificationDatasetTypeRegistry.register(DatasetStrategy.INFERENCE) -class InferenceClassificationDataset(ClassificationDataset): - def convert_records(self, records: list[ClassificationDatasetRecord]) -> list[dict[str, Any] | None]: - 
return self._encode(records, inference=True) diff --git a/turbo_alignment/pipelines/train/classification.py b/turbo_alignment/pipelines/train/classification.py index 1bc09fe..5b07f92 100755 --- a/turbo_alignment/pipelines/train/classification.py +++ b/turbo_alignment/pipelines/train/classification.py @@ -7,7 +7,9 @@ from turbo_alignment.cherry_picks.classification import ClassificationCherryPickCallback from turbo_alignment.common.logging import get_project_logger from turbo_alignment.constants import TRAINER_LOGS_FOLDER -from turbo_alignment.dataset.classification.classification import ClassificationDataset +from turbo_alignment.dataset.classification.classification import ( + InferenceClassificationDataset, +) from turbo_alignment.dataset.loader import DatasetLoader from turbo_alignment.metrics.metric import Metric from turbo_alignment.metrics.registry import MetricSettingsRegistry @@ -47,9 +49,9 @@ def _get_cherry_pick_callback( ) -> ClassificationCherryPickCallback: cherry_pick_settings = experiment_settings.cherry_pick_settings - cherry_pick_datasets = DatasetLoader[ClassificationDataset](ClassificationDataset).load_datasets( - cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE - ) + cherry_pick_datasets = DatasetLoader[InferenceClassificationDataset]( + InferenceClassificationDataset + ).load_datasets(cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE) metrics = [ Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))