diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index d04a89c..d50792e 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -153,7 +153,7 @@ def reinforce_training( seed=0, enable_prefix_caching=False, enforce_eager=False, - max_model_len=1024, + max_model_len=experiment_settings.trainer_settings.actor_settings.max_model_len, ) ray.get(policy_models.async_fit_actor_model( diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index 9878e7e..e8ab8af 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -1,4 +1,9 @@ import torch +from liger_kernel.transformers import ( + apply_liger_kernel_to_gemma2, + apply_liger_kernel_to_llama, + apply_liger_kernel_to_qwen2, +) from peft import PeftModel, get_peft_model, prepare_model_for_int8_training from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled @@ -7,7 +12,6 @@ PeftConfigRegistry, TransformersAutoModelRegistry, ) -from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2 from turbo_alignment.settings.model import ( ModelForPeftSettings, ModelType, @@ -46,6 +50,30 @@ def load_model( model_settings: PreTrainedModelSettings, tokenizer: PreTrainedTokenizerBase, ) -> PreTrainedModel: + if model_settings.liger_kernels_settings is not None: + apply_liger_kernel_to_llama( + rope=model_settings.liger_kernels_settings.use_rope, + cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, + swiglu=model_settings.liger_kernels_settings.use_mlp, + rms_norm=model_settings.liger_kernels_settings.use_rms_norm, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, + ) + + apply_liger_kernel_to_gemma2( + rope=model_settings.liger_kernels_settings.use_rope, + 
cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, + geglu=model_settings.liger_kernels_settings.use_mlp, + rms_norm=model_settings.liger_kernels_settings.use_rms_norm, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, + ) + + apply_liger_kernel_to_qwen2( + rope=model_settings.liger_kernels_settings.use_rope, + cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, + swiglu=model_settings.liger_kernels_settings.use_mlp, + rms_norm=model_settings.liger_kernels_settings.use_rms_norm, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, + ) model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained( model_settings.model_path, @@ -53,15 +81,6 @@ def load_model( **model_settings.model_kwargs, torch_dtype=torch.bfloat16, ) - if model_settings.liger_kernels_settings is not None: - apply_liger_kernel_to_gemma2( - rope=model_settings.liger_kernels_settings.use_rope, - cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, - geglu=model_settings.liger_kernels_settings.use_geglu, - model=model - ) - import ray - print(f'node_id:{ray.get_runtime_context().get_node_id()}, gpu_id: {ray.get_runtime_context().get_accelerator_ids()["GPU"]}', flush=True) if model_settings.transformers_settings.load_in_8bit: model = prepare_model_for_int8_training(model) @@ -99,4 +118,4 @@ def load_model( ) model.base_model.model.score.weight.requires_grad = True - return model + return model \ No newline at end of file diff --git a/turbo_alignment/common/tf/special_tokens_setter.py b/turbo_alignment/common/tf/special_tokens_setter.py index b1aaeb6..5fb4353 100755 --- a/turbo_alignment/common/tf/special_tokens_setter.py +++ b/turbo_alignment/common/tf/special_tokens_setter.py @@ -77,7 +77,7 @@ def setSEP(self, sep_token: str | None) -> None: return None def set_all(self) -> None: - 
self.setBOS(bos_token=self._special_tokens_settings.bos_token) + # self.setBOS(bos_token=self._special_tokens_settings.bos_token) self.setEOS(eos_token=self._special_tokens_settings.eos_token) self.setPAD(pad_token=self._special_tokens_settings.pad_token) self.setUNK(unk_token=self._special_tokens_settings.unk_token) @@ -94,7 +94,7 @@ def set_custom_tokens(self, tokens: list[str]) -> None: assert added_tokens == len(tokens) def setup_model_config(self, model: PreTrainedModel) -> None: - model.config.bos_token_id = self._tokenizer.bos_token_id + # model.config.bos_token_id = self._tokenizer.bos_token_id model.config.eos_token_id = self._tokenizer.eos_token_id if self._tokenizer.pad_token_id is not None: diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index ebdede9..48309c0 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -247,8 +247,8 @@ def _truncate_and_merge( input_ids = input_ids[cut_index:] labels = labels[-len(input_ids) :] - input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids)) - labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels)) + # input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids)) + # labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels)) return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound) diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index bb88f67..0b6687e 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -42,7 +42,8 @@ def __init__( # if transformers_settings.num_beams > 1: # beam_search_params['use_beam_search'] = True # beam_search_params['best_of'] = transformers_settings.num_beams - + print(f'SET MAX_NEW_TOKENS={transformers_settings.max_new_tokens}', flush=True) + self._sampling_params = SamplingParams( n=transformers_settings.num_return_sequences, 
repetition_penalty=transformers_settings.repetition_penalty, diff --git a/turbo_alignment/settings/model.py b/turbo_alignment/settings/model.py index 8ae1bb0..e0221d5 100755 --- a/turbo_alignment/settings/model.py +++ b/turbo_alignment/settings/model.py @@ -1,6 +1,8 @@ from enum import Enum from pathlib import Path +from pydantic import model_validator + from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.tf.model import ModelTransformersSettings from turbo_alignment.settings.tf.peft import PEFT_TYPE @@ -15,11 +17,21 @@ class ModelType(str, Enum): class LigerKernelSettings(ExtraFieldsNotAllowedBaseModel): use_rope: bool = True - use_cross_entropy: bool = True - use_geglu: bool = True + use_cross_entropy: bool = False use_fused_linear_cross_entropy: bool = False + use_mlp: bool = True use_rms_norm: bool = False + @model_validator(mode='after') + def check_cross_entropy_kernels(self) -> 'LigerKernelSettings': + if self.use_fused_linear_cross_entropy and self.use_cross_entropy: + raise ValueError( + 'You cannot use both FusedLinearCrossEntropy and CrossEntropy kernels. ' + 'FusedLinearCrossEntropy is preferred if possible.' 
+ ) + + return self + class PreTrainedModelSettings(ExtraFieldsNotAllowedBaseModel): model_path: Path @@ -45,4 +57,4 @@ class ModelForPeftSettings(PreTrainedModelSettings): class PreTrainedMultiModalModel(PreTrainedAdaptersModelSettings): projections_path: Path - n_modality_embeddings: int + n_modality_embeddings: int \ No newline at end of file diff --git a/turbo_alignment/settings/online.py b/turbo_alignment/settings/online.py index 39bc1e4..4916aac 100644 --- a/turbo_alignment/settings/online.py +++ b/turbo_alignment/settings/online.py @@ -11,6 +11,7 @@ class vLLMActorSettings(ExtraFieldsNotAllowedBaseModel): vllm_num_engines: int = 1 vllm_tensor_parallel_size: int = 1 + max_model_len: int = 8192 class HFActorSettings(ExtraFieldsNotAllowedBaseModel): actor_type: Literal[ActorType.LOCAL_TRANSFORMERS] = ActorType.LOCAL_TRANSFORMERS diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py index b4e140a..6ac54dc 100644 --- a/turbo_alignment/settings/pipelines/train/reinforce.py +++ b/turbo_alignment/settings/pipelines/train/reinforce.py @@ -16,7 +16,7 @@ from typing import Union class REINFORCETrainerSettings(TrainerSettings): - max_tokens_count: int = 1024 + max_new_tokens: int = 1024 stop_token: str = '' penalty_reward_value: float = 0.1 diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py index 7d01b0a..5e6e2b9 100644 --- a/turbo_alignment/trainers/online/reinforce.py +++ b/turbo_alignment/trainers/online/reinforce.py @@ -47,7 +47,7 @@ # FIXME @dataclass class REINFORCETrainingArguments(TrainingArguments): - max_tokens_count: int = 1024 + max_new_tokens: int = 1024 stop_token: str = '' penalty_reward_value: float = 0.1 @@ -177,7 +177,7 @@ def __init__( self.generator_transformers_settings = GeneratorTransformersSettings( temperature=args.temperature, - max_length=args.max_tokens_count, + max_new_tokens=args.max_new_tokens, 
num_return_sequences=args.num_generations, num_beams=args.num_generations, do_sample=True, @@ -193,12 +193,10 @@ def __init__( num_generations=args.num_generations, ) start = time.time() - torch.cuda.memory._dump_snapshot("before_reward_stats.pickle") self.print_readable_stats() self.norm_reward_mean, self.norm_reward_std = self.reward_stats( model=self.model, dataloader=self.get_train_dataloader() ) - torch.cuda.memory._dump_snapshot("after_stats.pickle") logging.info(f'statictis in __init__ elapsed time:{time.time() - start}') self.print_readable_stats() @@ -275,21 +273,19 @@ def get_answers_and_rewards( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: import time import logging - for k, v in inputs.items(): - if isinstance(v, torch.Tensor): - print(f'get_answers_and_rewards:inputs: {k}:{v.shape}', flush=True) + + if torch.distributed.get_rank() == 0: + print(f'Input shape: {inputs["input_ids"].shape}', flush=True) + print(f'Input Example at index [0]: {self.tokenizer.batch_decode(inputs["input_ids"][0, :].unsqueeze(0))}') if do_broadcast: # TODO: move to generator - # print('broadcast', type(model), flush=True) - # assert isinstance(model, DeepSpeedEngine) torch.distributed.barrier() start = time.time() self._broadcast_to_vllm(model) torch.distributed.barrier() logging.info(f'broadcast elapsed time:{time.time() - start}') - # model = self.accelerator.unwrap_model(model) start = time.time() generator = self._get_chat_generator() @@ -301,13 +297,13 @@ def get_answers_and_rewards( response_ids = [ans.answer_token_ids for g in generations for ans in g.answers] response_attention_mask = [ans.answer_attention_mask for g in generations for ans in g.answers] - import logging - logging.info(f'Generated sample: {response_ids[0].shape}') - logging.info(f'Last token: {response_ids[0].squeeze(0)[-1]}') - logging.info(f'Decoded: {self.tokenizer.decode([response_ids[0].squeeze(0)[-1]])}') + if torch.distributed.get_rank() == 0: + 
print(f'Generated completion at index [0] shape: {response_ids[0].shape}', flush=True) + print(f'Completion decoded: {self.tokenizer.batch_decode(response_ids[0])}', flush=True) # Padding - max_length = 8192 # max(response_id.size(1) for response_id in response_ids) #45.89 GiB + + max_length = max([response_id.size(1) for response_id in response_ids]) logging.info(f'{max_length=}') def pad_sequences(sequences, max_length, pad_value=0): @@ -332,10 +328,6 @@ def pad_sequences(sequences, max_length, pad_value=0): 'attention_mask': response_attention_mask, 'position_ids': position_ids, } - - # for k, v in rm_inputs.items(): - # if isinstance(v, torch.Tensor): - # print(f'get_answers_and_rewards:rm_inputs: {k}, {v.device=}', flush=True) start = time.time() critic_generator = self._get_rm_generator(self.reward_model) @@ -380,7 +372,6 @@ def get_batch_loss_metrics( del inputs gc.collect() torch.cuda.empty_cache() - torch.cuda.memory._dump_snapshot("before_inference.pickle") logprobs = self.get_logprobs( model=model, @@ -445,14 +436,10 @@ def print_readable_stats(self): print(f"Reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB") def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None): - print(f'Before batch_loss metrics\n\n') - self.print_readable_stats() - import logging import time import gc - logging.info(f'{isinstance(model, DeepSpeedEngine)=}') start = time.time() loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train') @@ -461,9 +448,6 @@ def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in logging.info(f'Compute Loss elapsed time:{time.time() - start}') gc.collect() torch.cuda.empty_cache() - print(f'After batch_loss metrics\n\n') - self.print_readable_stats() - torch.cuda.memory._dump_snapshot("after_compute_loss.pickle") return (loss, metrics) if return_outputs else loss