From 6696370d70c8f262fd5edcfeda65c3069afd583d Mon Sep 17 00:00:00 2001
From: Boris Shaposhnikov
Date: Mon, 9 Sep 2024 12:35:08 +0300
Subject: [PATCH] reformat

---
 turbo_alignment/dataset/chat/chat.py    | 63 ++++++++++-----------
 turbo_alignment/pipelines/train/base.py | 74 +++++++++++++------------
 2 files changed, 70 insertions(+), 67 deletions(-)

diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py
index e3f282f..be22995 100755
--- a/turbo_alignment/dataset/chat/chat.py
+++ b/turbo_alignment/dataset/chat/chat.py
@@ -102,11 +102,11 @@ def enum_to_string_in_dict(conv: list[dict[str, Any]]):
 
 class ChatDataset(AlignmentDataset[ChatDatasetRecord], ABC):
     def __init__(
-            self,
-            source: DatasetSourceSettings,
-            settings: ChatDatasetSettings,
-            tokenizer: PreTrainedTokenizerBase,
-            read: bool = True,
+        self,
+        source: DatasetSourceSettings,
+        settings: ChatDatasetSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        read: bool = True,
     ) -> None:
         super().__init__(source=source, settings=settings, tokenizer=tokenizer)
         self.settings: ChatDatasetSettings = settings
@@ -129,10 +129,10 @@ def _print(self, tokens: Tokens, prefix: str | None = None):
         loguru.logger.info("{}{}", prefix, [self.tokenizer.batch_decode([id])[0] for id in tokens.input_ids])
 
     def _encode(
-            self,
-            records: list[ChatDatasetRecord],
-            inference: bool,
-            random_cut: bool,
+        self,
+        records: list[ChatDatasetRecord],
+        inference: bool,
+        random_cut: bool,
     ) -> list[dict[str, Any] | None]:
         output = []
         for record in records:
@@ -151,7 +151,7 @@ def _encode(
                 padding=False,
                 truncation=False,
                 return_tensors="np",
-                meta=record.meta
+                meta=record.meta,
             )
 
             input_ids = tokenized_conversation["input_ids"][0]
@@ -161,13 +161,13 @@ def _encode(
 
             binary_seq_str = ''.join('1' if i else '0' for i in assistant_masks)
             last_replica_start_ind = binary_seq_str.rfind('01')
-            result_str = binary_seq_str[:last_replica_start_ind + 1].replace('1', '0') + binary_seq_str[
-                last_replica_start_ind + 1:]
+            result_str = (
+                binary_seq_str[: last_replica_start_ind + 1].replace('1', '0')
+                + binary_seq_str[last_replica_start_ind + 1 :]
+            )
             assistant_masks = [i == '1' for i in result_str]
 
-            tokens = Tokens(
-                input_ids, np.where(assistant_masks, input_ids, DISABLE_LOSS_LABEL)
-            )
+            tokens = Tokens(input_ids, np.where(assistant_masks, input_ids, DISABLE_LOSS_LABEL))
 
             replica_start_indices = get_target_value_indices(input_ids, self.replica_start_ids)
             replica_end_indices = get_target_value_indices(input_ids, self.replica_end_ids)
@@ -196,9 +196,8 @@ def _encode(
 
             max_length = max_length - (len(prefix) + len(suffix))
             while (
-                    replica_start_indices[right_replica_ind_exclusive]
-                    - replica_start_indices[left_replica_ind_inclusive]
-                    > max_length
+                replica_start_indices[right_replica_ind_exclusive] - replica_start_indices[left_replica_ind_inclusive]
+                > max_length
             ):
                 if self.settings.keep_end:
                     left_replica_ind_inclusive += 1
@@ -209,11 +208,11 @@ def _encode(
 
             right_replica_ind_exclusive = random.randint(left_replica_ind_inclusive, right_replica_ind_exclusive)
             while right_replica_ind_exclusive > 0 and (
-                    (inference and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.BOT)
-                    or (
-                            not inference
-                            and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.USER
-                    )
+                (inference and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.BOT)
+                or (
+                    not inference
+                    and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.USER
+                )
             ):
                 right_replica_ind_exclusive -= 1
 
@@ -244,13 +243,11 @@ def _encode(
 
     @staticmethod
     @overload
-    def _read_records(records: Path) -> list[ChatDatasetRecord]:
-        ...
+    def _read_records(records: Path) -> list[ChatDatasetRecord]: ...
 
     @staticmethod
     @overload
-    def _read_records(records: list[dict]) -> list[ChatDatasetRecord]:
-        ...
+    def _read_records(records: list[dict]) -> list[ChatDatasetRecord]: ...
 
     @staticmethod
     def _read_records(records) -> list[ChatDatasetRecord]:
@@ -270,12 +267,12 @@ def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, An
 @ChatDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
 class InferenceChatDataset(ChatDataset):
     def __init__(
-            self,
-            source: DatasetSourceSettings,
-            settings: ChatDatasetSettings,
-            tokenizer: PreTrainedTokenizerBase,
-            read: bool = True,
-            random_cut: bool = False,
+        self,
+        source: DatasetSourceSettings,
+        settings: ChatDatasetSettings,
+        tokenizer: PreTrainedTokenizerBase,
+        read: bool = True,
+        random_cut: bool = False,
     ) -> None:
         self._random_cut = random_cut
 
diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py
index 13008e1..4bd69c9 100755
--- a/turbo_alignment/pipelines/train/base.py
+++ b/turbo_alignment/pipelines/train/base.py
@@ -31,7 +31,7 @@
 
 logger = get_project_logger()
 
-ExperimentSettingsT = TypeVar('ExperimentSettingsT', bound=BaseTrainExperimentSettings)
+ExperimentSettingsT = TypeVar("ExperimentSettingsT", bound=BaseTrainExperimentSettings)
 
 
 class BaseTrainStrategy(S3Mixin, LoggingMixin, BaseStrategy, Generic[ExperimentSettingsT]):
@@ -46,18 +46,19 @@ def _get_wandb_callback(wandb_run: Run | RunDisabled) -> BaseWandbCallback:
     @staticmethod
     @abstractmethod
     def _get_cherry_pick_callback(
-            experiment_settings: ExperimentSettingsT,
-            tokenizer: PreTrainedTokenizerBase,
-            **kwargs,
-    ) -> CherryPickCallbackBase | None:
-        ...
+        experiment_settings: ExperimentSettingsT,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> CherryPickCallbackBase | None: ...
 
     @staticmethod
     def _save_experiment_config(
-        experiment_settings: ExperimentSettingsT, model: PreTrainedModel, output_path: Path
+        experiment_settings: ExperimentSettingsT,
+        model: PreTrainedModel,
+        output_path: Path,
     ) -> None:
         model_config = model.config.to_dict()
-        model_config['experiment_config'] = experiment_settings
+        model_config["experiment_config"] = experiment_settings
         write_json(model_config, output_path)
 
     @staticmethod
@@ -67,45 +68,47 @@ def _save_experiment_metadata(experiment_metadata: ExperimentMetadata, output_pa
     @staticmethod
     @abstractmethod
     def _get_data_collator(
-        experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase, **kwargs
-    ) -> Callable:
-        ...
+        experiment_settings: ExperimentSettingsT,
+        tokenizer: PreTrainedTokenizerBase,
+        **kwargs,
+    ) -> Callable: ...
 
     @staticmethod
     @abstractmethod
     def _get_trainer(
-            training_args: TrainingArguments,
-            experiment_settings: ExperimentSettingsT,
-            model: PreTrainedModel,
-            tokenizer: PreTrainedTokenizerBase,
-            train_dataset: Dataset,
-            val_dataset: Dataset,
-            data_collator: Callable,
-    ) -> Trainer:
-        ...
+        training_args: TrainingArguments,
+        experiment_settings: ExperimentSettingsT,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        train_dataset: Dataset,
+        val_dataset: Dataset,
+        data_collator: Callable,
+    ) -> Trainer: ...
 
     @staticmethod
     def _load_model(
-            experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase
+        experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase
     ) -> torch.nn.Module | PreTrainedModel:
         return load_model(experiment_settings.model_settings, tokenizer)
 
     @staticmethod
     @abstractmethod
-    def _get_training_args(experiment_settings: ExperimentSettingsT) -> TrainingArguments:
-        ...
+    def _get_training_args(
+        experiment_settings: ExperimentSettingsT,
+    ) -> TrainingArguments: ...
 
     @staticmethod
-    def _load_tokenizer(experiment_settings: ExperimentSettingsT) -> PreTrainedTokenizerBase:
+    def _load_tokenizer(
+        experiment_settings: ExperimentSettingsT,
+    ) -> PreTrainedTokenizerBase:
         return load_tokenizer(experiment_settings.tokenizer_settings, experiment_settings.model_settings)
 
     @abstractmethod
-    def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: Callable) -> None:
-        ...
+    def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: Callable) -> None: ...
 
     @staticmethod
     def _get_additional_special_tokens(
-            experiment_settings: BaseTrainExperimentSettings,
+        experiment_settings: BaseTrainExperimentSettings,
     ) -> list[str]:
         embeddings_initialization_strategy = experiment_settings.model_settings.embeddings_initialization_strategy
         return list(embeddings_initialization_strategy.keys()) if embeddings_initialization_strategy else []
@@ -123,7 +126,7 @@ def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwa
         if experiment_settings.checkpoint_uploader_callback_parameters is not None:
             if self.trainer.is_deepspeed_enabled and self.trainer.args.deepspeed_plugin.zero_stage == 3:
                 raise NotImplementedError(
-                    'You should not use checkpoint uploader callback when DeepSpeed ZeRO stage 3 enabled'
+                    "You should not use checkpoint uploader callback when DeepSpeed ZeRO stage 3 enabled"
                 )
 
             s3_handler = self._get_s3_handler(S3HandlerParameters())
@@ -141,21 +144,21 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
 
         self.tokenizer = self._load_tokenizer(experiment_settings)
 
-        logger.info('Tokenizer is loaded!')
+        logger.info("Tokenizer is loaded!")
 
         additional_special_tokens = self._get_additional_special_tokens(experiment_settings)
-        logger.info(f'Special tokens: {additional_special_tokens}')
+        logger.info(f"Special tokens: {additional_special_tokens}")
 
         special_tokens_setter = SpecialTokensSetter(self.tokenizer)
         special_tokens_setter.set_all()
         special_tokens_setter.set_custom_tokens(additional_special_tokens)
 
-        logger.info('Special tokens added!')
+        logger.info("Special tokens added!")
 
         self.model = self._load_model(experiment_settings, self.tokenizer)
         special_tokens_setter.setup_model_config(self.model)
 
-        logger.info('Model is loaded!')
+        logger.info("Model is loaded!")
 
         train_dataset: ConcatDataset = ConcatDataset(
             datasets=DatasetLoader().load_datasets(
@@ -192,12 +195,15 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
 
         os.makedirs(self.trainer.args.output_dir, exist_ok=True)
         self._save_experiment_config(
-            experiment_settings, self.trainer.model, Path(self.trainer.args.output_dir) / 'experiment.config'
+            experiment_settings,
+            self.trainer.model,
+            Path(self.trainer.args.output_dir) / "experiment.config",
         )
 
         experiment_metadata = ExperimentMetadata()
         self._save_experiment_metadata(
-            experiment_metadata, Path(self.trainer.args.output_dir) / 'experiment_metadata.json'
+            experiment_metadata,
+            Path(self.trainer.args.output_dir) / "experiment_metadata.json",
         )
 
         self.trainer.train()