
Commit

reformat
Boris Shaposhnikov committed Sep 9, 2024
1 parent b987e9c commit 6696370
Showing 2 changed files with 70 additions and 67 deletions.
63 changes: 30 additions & 33 deletions turbo_alignment/dataset/chat/chat.py
@@ -102,11 +102,11 @@ def enum_to_string_in_dict(conv: list[dict[str, Any]]):

class ChatDataset(AlignmentDataset[ChatDatasetRecord], ABC):
def __init__(
self,
source: DatasetSourceSettings,
settings: ChatDatasetSettings,
tokenizer: PreTrainedTokenizerBase,
read: bool = True,
self,
source: DatasetSourceSettings,
settings: ChatDatasetSettings,
tokenizer: PreTrainedTokenizerBase,
read: bool = True,
) -> None:
super().__init__(source=source, settings=settings, tokenizer=tokenizer)
self.settings: ChatDatasetSettings = settings
@@ -129,10 +129,10 @@ def _print(self, tokens: Tokens, prefix: str | None = None):
loguru.logger.info("{}{}", prefix, [self.tokenizer.batch_decode([id])[0] for id in tokens.input_ids])

def _encode(
self,
records: list[ChatDatasetRecord],
inference: bool,
random_cut: bool,
self,
records: list[ChatDatasetRecord],
inference: bool,
random_cut: bool,
) -> list[dict[str, Any] | None]:
output = []
for record in records:
@@ -151,7 +151,7 @@ def _encode(
padding=False,
truncation=False,
return_tensors="np",
meta=record.meta
meta=record.meta,
)
input_ids = tokenized_conversation["input_ids"][0]

@@ -161,13 +161,13 @@ def _encode(
binary_seq_str = ''.join('1' if i else '0' for i in assistant_masks)
last_replica_start_ind = binary_seq_str.rfind('01')

result_str = binary_seq_str[:last_replica_start_ind + 1].replace('1', '0') + binary_seq_str[
last_replica_start_ind + 1:]
result_str = (
binary_seq_str[: last_replica_start_ind + 1].replace('1', '0')
+ binary_seq_str[last_replica_start_ind + 1 :]
)
assistant_masks = [i == '1' for i in result_str]

tokens = Tokens(
input_ids, np.where(assistant_masks, input_ids, DISABLE_LOSS_LABEL)
)
tokens = Tokens(input_ids, np.where(assistant_masks, input_ids, DISABLE_LOSS_LABEL))
replica_start_indices = get_target_value_indices(input_ids, self.replica_start_ids)
replica_end_indices = get_target_value_indices(input_ids, self.replica_end_ids)

@@ -196,9 +196,8 @@ def _encode(
max_length = max_length - (len(prefix) + len(suffix))

while (
replica_start_indices[right_replica_ind_exclusive]
- replica_start_indices[left_replica_ind_inclusive]
> max_length
replica_start_indices[right_replica_ind_exclusive] - replica_start_indices[left_replica_ind_inclusive]
> max_length
):
if self.settings.keep_end:
left_replica_ind_inclusive += 1
@@ -209,11 +208,11 @@ def _encode(
right_replica_ind_exclusive = random.randint(left_replica_ind_inclusive, right_replica_ind_exclusive)

while right_replica_ind_exclusive > 0 and (
(inference and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.BOT)
or (
not inference
and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.USER
)
(inference and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.BOT)
or (
not inference
and conversation.messages[right_replica_ind_exclusive - 1].role == ChatMessageRole.USER
)
):
right_replica_ind_exclusive -= 1

@@ -244,13 +243,11 @@ def _encode(

@staticmethod
@overload
def _read_records(records: Path) -> list[ChatDatasetRecord]:
...
def _read_records(records: Path) -> list[ChatDatasetRecord]: ...

@staticmethod
@overload
def _read_records(records: list[dict]) -> list[ChatDatasetRecord]:
...
def _read_records(records: list[dict]) -> list[ChatDatasetRecord]: ...

@staticmethod
def _read_records(records) -> list[ChatDatasetRecord]:
@@ -270,12 +267,12 @@ def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, An
@ChatDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
class InferenceChatDataset(ChatDataset):
def __init__(
self,
source: DatasetSourceSettings,
settings: ChatDatasetSettings,
tokenizer: PreTrainedTokenizerBase,
read: bool = True,
random_cut: bool = False,
self,
source: DatasetSourceSettings,
settings: ChatDatasetSettings,
tokenizer: PreTrainedTokenizerBase,
read: bool = True,
random_cut: bool = False,
) -> None:
self._random_cut = random_cut

74 changes: 40 additions & 34 deletions turbo_alignment/pipelines/train/base.py
@@ -31,7 +31,7 @@

logger = get_project_logger()

ExperimentSettingsT = TypeVar('ExperimentSettingsT', bound=BaseTrainExperimentSettings)
ExperimentSettingsT = TypeVar("ExperimentSettingsT", bound=BaseTrainExperimentSettings)


class BaseTrainStrategy(S3Mixin, LoggingMixin, BaseStrategy, Generic[ExperimentSettingsT]):
@@ -46,18 +46,19 @@ def _get_wandb_callback(wandb_run: Run | RunDisabled) -> BaseWandbCallback:
@staticmethod
@abstractmethod
def _get_cherry_pick_callback(
experiment_settings: ExperimentSettingsT,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> CherryPickCallbackBase | None:
...
experiment_settings: ExperimentSettingsT,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> CherryPickCallbackBase | None: ...

@staticmethod
def _save_experiment_config(
experiment_settings: ExperimentSettingsT, model: PreTrainedModel, output_path: Path
experiment_settings: ExperimentSettingsT,
model: PreTrainedModel,
output_path: Path,
) -> None:
model_config = model.config.to_dict()
model_config['experiment_config'] = experiment_settings
model_config["experiment_config"] = experiment_settings
write_json(model_config, output_path)

@staticmethod
@@ -67,45 +68,47 @@ def _save_experiment_metadata(experiment_metadata: ExperimentMetadata, output_pa
@staticmethod
@abstractmethod
def _get_data_collator(
experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase, **kwargs
) -> Callable:
...
experiment_settings: ExperimentSettingsT,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> Callable: ...

@staticmethod
@abstractmethod
def _get_trainer(
training_args: TrainingArguments,
experiment_settings: ExperimentSettingsT,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
val_dataset: Dataset,
data_collator: Callable,
) -> Trainer:
...
training_args: TrainingArguments,
experiment_settings: ExperimentSettingsT,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
val_dataset: Dataset,
data_collator: Callable,
) -> Trainer: ...

@staticmethod
def _load_model(
experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase
experiment_settings: ExperimentSettingsT, tokenizer: PreTrainedTokenizerBase
) -> torch.nn.Module | PreTrainedModel:
return load_model(experiment_settings.model_settings, tokenizer)

@staticmethod
@abstractmethod
def _get_training_args(experiment_settings: ExperimentSettingsT) -> TrainingArguments:
...
def _get_training_args(
experiment_settings: ExperimentSettingsT,
) -> TrainingArguments: ...

@staticmethod
def _load_tokenizer(experiment_settings: ExperimentSettingsT) -> PreTrainedTokenizerBase:
def _load_tokenizer(
experiment_settings: ExperimentSettingsT,
) -> PreTrainedTokenizerBase:
return load_tokenizer(experiment_settings.tokenizer_settings, experiment_settings.model_settings)

@abstractmethod
def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: Callable) -> None:
...
def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: Callable) -> None: ...

@staticmethod
def _get_additional_special_tokens(
experiment_settings: BaseTrainExperimentSettings,
experiment_settings: BaseTrainExperimentSettings,
) -> list[str]:
embeddings_initialization_strategy = experiment_settings.model_settings.embeddings_initialization_strategy
return list(embeddings_initialization_strategy.keys()) if embeddings_initialization_strategy else []
@@ -123,7 +126,7 @@ def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwa
if experiment_settings.checkpoint_uploader_callback_parameters is not None:
if self.trainer.is_deepspeed_enabled and self.trainer.args.deepspeed_plugin.zero_stage == 3:
raise NotImplementedError(
'You should not use checkpoint uploader callback when DeepSpeed ZeRO stage 3 enabled'
"You should not use checkpoint uploader callback when DeepSpeed ZeRO stage 3 enabled"
)
s3_handler = self._get_s3_handler(S3HandlerParameters())

@@ -141,21 +144,21 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:

self.tokenizer = self._load_tokenizer(experiment_settings)

logger.info('Tokenizer is loaded!')
logger.info("Tokenizer is loaded!")

additional_special_tokens = self._get_additional_special_tokens(experiment_settings)
logger.info(f'Special tokens: {additional_special_tokens}')
logger.info(f"Special tokens: {additional_special_tokens}")
special_tokens_setter = SpecialTokensSetter(self.tokenizer)
special_tokens_setter.set_all()
special_tokens_setter.set_custom_tokens(additional_special_tokens)

logger.info('Special tokens added!')
logger.info("Special tokens added!")

self.model = self._load_model(experiment_settings, self.tokenizer)

special_tokens_setter.setup_model_config(self.model)

logger.info('Model is loaded!')
logger.info("Model is loaded!")

train_dataset: ConcatDataset = ConcatDataset(
datasets=DatasetLoader().load_datasets(
@@ -192,12 +195,15 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:

os.makedirs(self.trainer.args.output_dir, exist_ok=True)
self._save_experiment_config(
experiment_settings, self.trainer.model, Path(self.trainer.args.output_dir) / 'experiment.config'
experiment_settings,
self.trainer.model,
Path(self.trainer.args.output_dir) / "experiment.config",
)

experiment_metadata = ExperimentMetadata()
self._save_experiment_metadata(
experiment_metadata, Path(self.trainer.args.output_dir) / 'experiment_metadata.json'
experiment_metadata,
Path(self.trainer.args.output_dir) / "experiment_metadata.json",
)

self.trainer.train()
