diff --git a/tests/fixtures/configs/train/reinforce/deepspeed_cfg.json b/tests/fixtures/configs/train/reinforce/deepspeed_cfg.json index b1b609c..173c612 100644 --- a/tests/fixtures/configs/train/reinforce/deepspeed_cfg.json +++ b/tests/fixtures/configs/train/reinforce/deepspeed_cfg.json @@ -1,29 +1,29 @@ { "fp16": { - "enabled": false - }, - "bf16": { + "enabled": false + }, + "bf16": { "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + "zero_optimization": { "stage": 3, "offload_param": { @@ -34,16 +34,22 @@ "pin_memory": true }, "allgather_partitions": true, - "allgather_bucket_size": 2e8, + "allgather_bucket_size": 1e5, "overlap_comm": false, "reduce_scatter": true, - "reduce_bucket_size": 2e8, - "contiguous_gradients": true + "reduce_bucket_size": "auto", + "contiguous_gradients": true, + "sub_group_size": 1e5, + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": 0, + "stage3_max_live_parameters": 0, + "stage3_max_reuse_distance": 0, + "stage3_gather_16bit_weights_on_model_save": true }, - "gradient_accumulation_steps": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false +"gradient_accumulation_steps": "auto", +"gradient_clipping": "auto", +"steps_per_print": 2, +"train_batch_size": "auto", +"train_micro_batch_size_per_gpu": "auto", +"wall_clock_breakdown": false } \ No newline at end of file diff --git a/tests/fixtures/configs/train/reinforce/gemma.json b/tests/fixtures/configs/train/reinforce/gemma.json index 252ac3e..e6409b9 100644 --- a/tests/fixtures/configs/train/reinforce/gemma.json +++ b/tests/fixtures/configs/train/reinforce/gemma.json @@ -1,195 +1,188 @@ { - "train_dataset_settings": { - "sources": [ - { - "name": "train_chat", - "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", - "sample_rate": 1.0 - } - ], - "prompt_template": { - "role_tag_mapping": { - "bot": "assistant", - "user": "user", - "system": "system" - }, - "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", - "suffix_template": "<|eot_id|>" + "train_dataset_settings": { + "sources": [ + { + "name": "train_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" }, - "dataset_type": "chat", - "max_tokens_count": 2000, - "only_answer_loss": true - }, - "val_dataset_settings": { + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "val_dataset_settings": { + "sources": [ + { + "name": "val_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": 
"assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": false, + "stop_strings": "", + "max_new_tokens": 8 + }, + "custom_generation_settings": { + "skip_special_tokens": false + }, + "dataset_settings": { "sources": [ - { - "name": "val_chat", + { + "name": "chat_test", "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", - "sample_rate": 1.0 - } + "num_samples": 2 + } ], "prompt_template": { - "role_tag_mapping": { + "role_tag_mapping": { "bot": "assistant", "user": "user", "system": "system" - }, - "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", - "suffix_template": "<|eot_id|>" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" }, "dataset_type": "chat", - "max_tokens_count": 2000, + "max_tokens_count": 150, "only_answer_loss": true - }, - "cherry_pick_settings": { - "generator_transformers_settings": { - "num_beams": 1, - "do_sample": false, - "stop_strings": "", - "max_new_tokens": 8 - }, - "custom_generation_settings": { - "skip_special_tokens": false - }, - "dataset_settings": { - "sources": [ - { - "name": "chat_test", - "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", - "num_samples": 2 - } - ], - "prompt_template": { - "role_tag_mapping": { - "bot": "assistant", - "user": "user", - "system": "system" - }, - "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", - "suffix_template": "<|eot_id|>" - }, - "dataset_type": "chat", - "max_tokens_count": 150, - "only_answer_loss": true - }, - "metric_settings": [ - { - "type": "length", - "parameters": { - "need_average": [ - true - ] - } - }, - { - "type": "kl", - "parameters": { - "need_average": [ - true - ], - "ref_logits_type": "sft" - } - }, - { - "type": "kl", - "parameters": { - "need_average": [ - true - ], - "ref_logits_type": "reference" - } - } - ] - }, - "reward_model_settings": { - "model_path": "/from_s3/llama-3.2-1B", - "model_type": "seq_cls", - "model_kwargs": { - "num_labels": 1, - "attn_implementation": "flash_attention_2" - }, - "transformers_settings": {} - }, - "model_settings": { - "model_path": "/from_s3/llama-3.2-1B", - "model_type": "causal", - "liger_kernels_settings": { - "use_rope": true, - "use_cross_entropy": false, - "use_geglu": true, - "use_fused_linear_cross_entropy": false, - "use_rms_norm": false - }, - "model_kwargs": { - "attn_implementation": "flash_attention_2" - }, - "transformers_settings": {} - }, - "tokenizer_settings": { - "tokenizer_kwargs": { - "padding_side": "left" + }, + "metric_settings": [ + { + "type": "length", + "parameters": { + "need_average": [ + true + ] } - }, - "trainer_settings": { - "actor_settings": { - "actor_type": "distributed_vllm", - "vllm_num_engines": 6, - "vllm_tensor_parallel_size": 1 }, - "deepspeed": "tests/fixtures/configs/train/reinforce/deepspeed_cfg.json", - "gradient_checkpointing": true, - "gradient_checkpointing_kwargs": { - "use_reentrant": false + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "sft" + } }, - "critic_type": "ray_transformers", - "reward_processor_type": "rloo", - "evaluation_strategy": "steps", - "num_generations": 4, - "per_device_train_batch_size": 8, - "per_device_eval_batch_size": 
8, - "gradient_accumulation_steps": 10, - "adam_beta1": 0.9, - "adam_beta2": 0.95, - "adam_epsilon": 0.00001, - "eval_steps": 50, - "save_steps": 100, - "save_strategy": "steps", - "load_best_model_at_end": false, - "logging_steps": 1, - "learning_rate": 0.00001, - "max_steps": 1001, - "lr_scheduler_type": "constant_with_warmup", - "warmup_ratio": 0.03, - "fp16": false, - "bf16": true, - "optim": "adamw_torch", - "weight_decay": 0.0, - "max_grad_norm": 2, - "save_total_limit": 11, - "dataloader_num_workers": 12, - "stop_token": "", - "temperature": 0.7, - "non_eos_penalty": true, - "penalty_reward_value": -1, - "clip_rewards_min": -1e+8, - "clip_rewards_max": 1e+8, - "whiten_rewards": false, - "kl_coef": 0.05, - "mean_baseline_coef": 0.95, - "num_samples_for_reward_stats": 1000, - "no_cuda": false - }, - "special_tokens_settings": { - "bos_token": "", - "eos_token": "", - "pad_token": "" - }, - "logging_settings": { - "project_name": "alignment", - "run_name": "sft", - "entity": "turbo-alignment" - }, - "seed": 0, - "log_path": "train_output" + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "reference" + } + } + ] + }, + "reward_model_settings": { + "model_path": "/from_s3/llama-3.2-1B", + "model_type": "seq_cls", + "model_kwargs": { + "num_labels": 1, + "attn_implementation": "flash_attention_2" + }, + "transformers_settings": {} + }, + "model_settings": { + "model_path": "/from_s3/llama-3.2-1B", + "model_type": "causal", + "model_kwargs": { + "attn_implementation": "flash_attention_2" + }, + "transformers_settings": {}, + "liger_kernels_settings": {} + }, + "tokenizer_settings": { + "tokenizer_kwargs": { + "padding_side": "left" } - \ No newline at end of file + }, + "trainer_settings": { + "actor_settings": { + "actor_type": "distributed_vllm", + "vllm_num_engines": 2, + "vllm_tensor_parallel_size": 1 + }, + "deepspeed": "tests/fixtures/configs/train/reinforce/deepspeed_cfg.json", + "gradient_checkpointing": true, + "gradient_checkpointing_kwargs": { + "use_reentrant": false + }, + "critic_type": "ray_transformers", + "reward_processor_type": "rloo", + "evaluation_strategy": "steps", + "num_generations": 4, + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "gradient_accumulation_steps": 1, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_epsilon": 0.00001, + "eval_steps": 10, + "save_steps": 100, + "save_strategy": "steps", + "load_best_model_at_end": false, + "logging_steps": 1, + "learning_rate": 0.00001, + "max_steps": 1001, + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": 0.03, + "fp16": false, + "bf16": true, + "optim": "adamw_torch", + "weight_decay": 0.0, + "max_grad_norm": 2, + "save_total_limit": 11, + "dataloader_num_workers": 12, + "stop_token": "", + "temperature": 0.7, + "non_eos_penalty": true, + "penalty_reward_value": -1, + "clip_rewards_min": -1e+8, + "clip_rewards_max": 1e+8, + "whiten_rewards": false, + "kl_coef": 0.05, + "mean_baseline_coef": 0.95, + "num_samples_for_reward_stats": 1000, + "no_cuda": false + }, + "special_tokens_settings": { + "bos_token": "", + "eos_token": "", + "pad_token": "" + }, + "logging_settings": { + "project_name": "alignment", + "run_name": "sft", + "entity": "turbo-alignment" + }, + "seed": 0, + "log_path": "train_output" +} \ No newline at end of file diff --git a/tests/fixtures/configs/train/reinforce/stage3.json b/tests/fixtures/configs/train/reinforce/stage3.json deleted file mode 100644 index d005463..0000000 --- 
a/tests/fixtures/configs/train/reinforce/stage3.json +++ /dev/null @@ -1,55 +0,0 @@ -{ - "fp16": { - "enabled": false - }, - "bf16": { - "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "betas": "auto", - "eps": "auto", - "weight_decay": "auto" - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - - "zero_optimization": { - "stage": 3, - "offload_param": { - "device": "none" - }, - "offload_optimizer": { - "device": "none", - "pin_memory": true - }, - "allgather_partitions": true, - "allgather_bucket_size": 1e5, - "overlap_comm": false, - "reduce_scatter": true, - "reduce_bucket_size": "auto", - "contiguous_gradients": true, - "sub_group_size": 1e5, - "stage3_param_persistence_threshold": "auto", - "stage3_prefetch_bucket_size": 0, - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true - }, - "gradient_accumulation_steps": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} \ No newline at end of file diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index 7be3103..9878e7e 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -46,12 +46,6 @@ def load_model( model_settings: PreTrainedModelSettings, tokenizer: PreTrainedTokenizerBase, ) -> PreTrainedModel: - if model_settings.liger_kernels_settings is not None: - apply_liger_kernel_to_gemma2( - rope=model_settings.liger_kernels_settings.use_rope, - cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, - geglu=model_settings.liger_kernels_settings.use_geglu, - ) model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained( model_settings.model_path, @@ -59,6 +53,13 @@ def load_model( **model_settings.model_kwargs, torch_dtype=torch.bfloat16, ) + if model_settings.liger_kernels_settings is not None: + apply_liger_kernel_to_gemma2( + rope=model_settings.liger_kernels_settings.use_rope, + cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, + geglu=model_settings.liger_kernels_settings.use_geglu, + model=model + ) import ray print(f'node_id:{ray.get_runtime_context().get_node_id()}, gpu_id: {ray.get_runtime_context().get_accelerator_ids()["GPU"]}', flush=True) diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 86ddcfd..0f067ef 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -28,6 +28,7 @@ def __init__( tokenizer: PreTrainedTokenizerBase, accelerator: Accelerator | None = None, batch: int = 1, + return_logits: bool = False ) -> None: if accelerator is not None: self._model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True) diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index da0a678..bb88f67 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -81,6 +81,7 @@ def generate_from_batch_records( answers = [] for a in request_output.outputs: answer_token_ids=torch.tensor(a.token_ids).unsqueeze(0) + ans_msg = AnswerMessage( id=str(a.index), content=a.text, @@ -120,6 +121,7 @@ def generate_from_batch_records( dataset_name: 
str, records_batch: dict[str, torch.Tensor] | BatchEncoding, original_records: list[ChatDatasetRecord] | None = None, + return_logits: bool = None ) -> list[ChatInferenceOutput]: batched_input_ids = records_batch['input_ids'].to(self.device) batched_attention_mask = records_batch['attention_mask'].to(self.device) @@ -135,7 +137,7 @@ def generate_from_batch_records( logits: torch.Tensor | None = None - if self._custom_generation_settings.return_logits: + if return_logits: with torch.no_grad(): logits = self._model(output_indices).logits diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py index 48681ca..7ff069d 100755 --- a/turbo_alignment/generators/rm.py +++ b/turbo_alignment/generators/rm.py @@ -18,7 +18,7 @@ def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEn with torch.no_grad(): # TODO assuming that only one reward model records = {k: v.cuda() for k, v in records.items()} - rewards = ray.get(self._model.async_forward(records))[0].logits + rewards = ray.get(self._model.async_forward(records))[0] return rewards # .squeeze() def generate_from_batch( diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py index b6e97bc..f0a7e63 100644 --- a/turbo_alignment/pipelines/train/reinforce.py +++ b/turbo_alignment/pipelines/train/reinforce.py @@ -193,8 +193,14 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_ logger.info('Special tokens added!') + import time + + start = time.time() + self.model = self._load_model(experiment_settings, self.tokenizer) + print(f"Elapsed model load time: {time.time() - start} seconds") + special_tokens_setter.setup_model_config(self.model) logger.info('Model is loaded!') @@ -216,7 +222,9 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_ ) data_collator = self._get_data_collator(experiment_settings, self.tokenizer) - + + start = time.time() + self.trainer = self._get_trainer( vllm_engines, training_args, @@ -229,6 +237,9 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_ val_dataset, data_collator, ) + print(f"Elapsed get_trainer time: {time.time() - start} seconds") + + start = time.time() if self.trainer.accelerator.is_main_process: self._dataset_and_collator_sanity_check(train_dataset, data_collator) @@ -244,7 +255,7 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_ self._save_experiment_metadata( experiment_metadata, Path(self.trainer.args.output_dir) / 'experiment_metadata.json' ) - + print(f"Elapsed before trainer.train() time: {time.time() - start} seconds") self.trainer.train() self.trainer.save_model() diff --git a/turbo_alignment/trainers/online/ray/rayactor_group.py b/turbo_alignment/trainers/online/ray/rayactor_group.py index 6a2ac41..ad02a38 100644 --- a/turbo_alignment/trainers/online/ray/rayactor_group.py +++ b/turbo_alignment/trainers/online/ray/rayactor_group.py @@ -100,6 +100,10 @@ def _initiate_actors(self, pg, num_gpus_per_actor): ).remote(world_size, rank, local_rank, master_addr, master_port) self._actor_handlers.append(worker_actor) + def reference_forward(self, records: dict[str, torch.Tensor], temperature, loss_mask): + # TODO assuming one reference model + return self._actor_handlers[0].reference_forward.remote(records, temperature, loss_mask) + def async_forward(self, records: dict[str, torch.Tensor]): return [actor.forward.remote(records) for actor in self._actor_handlers] diff --git 
a/turbo_alignment/trainers/online/reference_actor.py b/turbo_alignment/trainers/online/reference_actor.py index 4b0769c..7aeda65 100644 --- a/turbo_alignment/trainers/online/reference_actor.py +++ b/turbo_alignment/trainers/online/reference_actor.py @@ -2,6 +2,7 @@ from turbo_alignment.trainers.online.ray.distributed_torch_ray_actor import DistributedTorchRayActor from transformers import AutoModelForCausalLM, AutoTokenizer import torch +import torch.nn.functional as F @ray.remote(num_gpus=1) class ReferenceModel(DistributedTorchRayActor): @@ -31,4 +32,21 @@ def eval(self): def forward(self, x): self.model.eval() x = {k: v.cuda() for k, v in x.items()} - return self.model(**x) \ No newline at end of file + return self.model(**x) + + @torch.no_grad + def reference_forward(self, x, temperature, loss_mask): + torch.cuda.empty_cache() + self.model.eval() + x = {k: v.cuda() for k, v in x.items()} + + logits = self.model(**x).logits[:, :-1] + logits /= temperature + all_logprob = F.log_softmax(logits, dim=-1) + + input_ids = x['input_ids'] + + logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1) + logprob[~loss_mask[:, 1:].to(torch.bool)] = 0 + + return logprob.sum(-1) \ No newline at end of file diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py index 611f7c8..e306b6b 100644 --- a/turbo_alignment/trainers/online/reinforce.py +++ b/turbo_alignment/trainers/online/reinforce.py @@ -82,6 +82,11 @@ def __init__( callbacks: list[TrainerCallback] | None = None, **kwargs, ) -> None: + import time + import logging + + start = time.time() + super().__init__( model=policy, args=args, @@ -91,9 +96,10 @@ def __init__( callbacks=callbacks, **kwargs, ) - + logging.info(f'super().__init__ elapsed time:{time.time() - start}') + self.vllm_engines = vllm_engines - + start = time.time() if self.vllm_engines is not None and torch.distributed.get_rank() == 0: master_address = ray._private.services.get_node_ip_address() with socket.socket() as sock: @@ -139,10 +145,11 @@ def __init__( ray.get(refs) torch.distributed.barrier() + logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}') self.ref_model = ref_model # TODO: delete later - ray.get(self.ref_model.async_eval()) + # ray.get(self.ref_model.async_eval()) # TODO: TODO_RLOO watch later # if self.ref_model is not None: @@ -150,7 +157,7 @@ def __init__( self.reward_model = reward_model # TODO: delete later - ray.get(self.ref_model.async_eval()) + # ray.get(self.ref_model.async_eval()) self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list)) self.kl_coef = args.kl_coef @@ -183,10 +190,11 @@ def __init__( mean_baseline_coef=args.mean_baseline_coef, num_generations=args.num_generations, ) - + start = time.time() self.norm_reward_mean, self.norm_reward_std = self.reward_stats( model=self.model, dataloader=self.get_train_dataloader() ) + logging.info(f'statictis in __init__ elapsed time:{time.time() - start}') def _broadcast_to_vllm(self, model: DeepSpeedEngine): @@ -259,7 +267,8 @@ def get_answers_and_rewards( inputs: dict[str, torch.Tensor], do_broadcast=True ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - + import time + import logging # for k, v in inputs.items(): # if isinstance(v, torch.Tensor): # print(f'get_answers_and_rewards:inputs: {k}, {v.device=}', flush=True) @@ -269,21 +278,31 @@ def get_answers_and_rewards( # print('broadcast', type(model), flush=True) # assert 
isinstance(model, DeepSpeedEngine) torch.distributed.barrier() + start = time.time() self._broadcast_to_vllm(model) torch.distributed.barrier() + logging.info(f'broadcast elapsed time:{time.time() - start}') # model = self.accelerator.unwrap_model(model) + start = time.time() generator = self._get_chat_generator() - + generations: list[ChatInferenceOutput] = generator.generate_from_batch_records( dataset_name='online', records=inputs ) - + logging.info(f'generations elapsed time:{time.time() - start}') + response_ids = [ans.answer_token_ids for g in generations for ans in g.answers] response_attention_mask = [ans.answer_attention_mask for g in generations for ans in g.answers] + import logging + logging.info(f'Generated sample: {response_ids[0].shape}') + logging.info(f'Last token: {response_ids[0].squeeze(0)[-1]}') + logging.info(f'Decoded: {self.tokenizer.decode([response_ids[0].squeeze(0)[-1]])}') + # Padding - max_length = max(response_id.size(1) for response_id in response_ids) + max_length = 2048 # max(response_id.size(1) for response_id in response_ids) + def pad_sequences(sequences, max_length, pad_value=0): padded_sequences = [] for seq in sequences: @@ -312,10 +331,11 @@ def pad_sequences(sequences, max_length, pad_value=0): # print(f'get_answers_and_rewards:rm_inputs: {k}, {v.device=}', flush=True) # accelerator.num_process_id - + start = time.time() critic_generator = self._get_rm_generator(self.reward_model) rewards = critic_generator.generate_from_batch_records(rm_inputs) # print(f'rewards: {rewards.device}', flush=True) + logging.info(f'rewards elapsed time:{time.time() - start}') return response_ids, response_attention_mask, response_tokens_mask, position_ids, rewards def get_batch_loss_metrics( @@ -324,6 +344,9 @@ def get_batch_loss_metrics( inputs: dict[str, torch.Tensor], train_eval: Literal['train', 'eval'] = 'train', ): + import logging + import time + with torch.no_grad(): ( query_response, @@ -334,8 +357,9 @@ def get_batch_loss_metrics( ) = self.get_answers_and_rewards( model=model, inputs=inputs, + do_broadcast=True if train_eval == 'train' else False ) - + start = time.time() ref_logprobs = self.get_logprobs( model=self.ref_model, input_ids=query_response, @@ -343,9 +367,10 @@ def get_batch_loss_metrics( position_ids=position_ids, loss_mask=response_tokens_mask.to(torch.float32), ) - + logging.info(f'reference elapsed time:{time.time() - start}') rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response) - + + start = time.time() logprobs = self.get_logprobs( model=model, input_ids=query_response, @@ -353,7 +378,7 @@ def get_batch_loss_metrics( position_ids=position_ids, loss_mask=response_tokens_mask, ) - + logging.info(f'policy logrobs elapsed time:{time.time() - start}') with torch.no_grad(): kl_term = logprobs.detach() - ref_logprobs regularized_rewards = rewards - self.kl_coef * kl_term @@ -402,10 +427,15 @@ def get_batch_loss_metrics( return loss.mean(), metrics def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None): + import logging + import time + logging.info(f'{isinstance(model, DeepSpeedEngine)=}') + start = time.time() + loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train') self.store_metrics(metrics=metrics, train_eval='train') - + logging.info(f'Compute Loss elapsed time:{time.time() - start}') return (loss, metrics) if return_outputs else loss def prediction_step( @@ -415,6 +445,10 @@ def prediction_step( prediction_loss_only: bool, ignore_keys: 
Optional[List[str]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + + import logging + logging.info(f'{isinstance(model, DeepSpeedEngine)=}') + with torch.no_grad(): loss, metrics = self.get_batch_loss_metrics(model, inputs, 'eval') @@ -475,7 +509,7 @@ def get_logprobs( #'use_cache': False } #TODO only one reference model -> maybe delete from group?? - logits = ray.get(model.async_forward(records))[0].logits[:, :-1] + return ray.get(model.reference_forward(records, self.args.temperature, loss_mask)) else: logits = model( input_ids=input_ids, @@ -484,13 +518,13 @@ def get_logprobs( use_cache=False, ).logits[:, :-1] - logits /= self.args.temperature - all_logprob = F.log_softmax(logits, dim=-1) - # print(f'{all_logprob.device=}, {input_ids.device=}') - logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1) - logprob[~loss_mask[:, 1:].to(torch.bool)] = 0 + logits /= self.args.temperature + all_logprob = F.log_softmax(logits, dim=-1) + # print(f'{all_logprob.device=}, {input_ids.device=}') + logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1) + logprob[~loss_mask[:, 1:].to(torch.bool)] = 0 - return logprob.sum(-1) + return logprob.sum(-1) def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]: if self.args.non_eos_penalty: diff --git a/turbo_alignment/trainers/online/reward_actor.py b/turbo_alignment/trainers/online/reward_actor.py index a3368bc..c2adceb 100644 --- a/turbo_alignment/trainers/online/reward_actor.py +++ b/turbo_alignment/trainers/online/reward_actor.py @@ -33,9 +33,10 @@ def eval(self): @torch.no_grad def forward(self, x): + torch.cuda.empty_cache() self.model.eval() x = {k: v.cuda() for k, v in x.items()} # print('\n\n\nRM model', [v.shape for k, v in x.items()], 'RM model\n\n\n') # print(self.tokenizer.decode(x['input_ids'][0], skip_special_tokens=False)) #TODO reward from eos - return self.model(**x) \ No newline at end of file + return self.model(**x).logits \ No newline at end of file
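
Below is a minimal, standalone sketch (not part of the patch) of the masked sequence log-probability that ReferenceModel.reference_forward and the local branch of the trainer's get_logprobs both compute: temperature-scale the logits, log-softmax, gather the log-probability of each realized next token, zero out non-response positions, and sum per sequence. The helper name sequence_logprob and the toy tensor shapes are illustrative assumptions; the gather/mask/sum logic mirrors the code in the diff above.

import torch
import torch.nn.functional as F

def sequence_logprob(
    logits: torch.Tensor,      # [batch, seq_len, vocab]
    input_ids: torch.Tensor,   # [batch, seq_len]
    loss_mask: torch.Tensor,   # [batch, seq_len], nonzero where tokens count toward the sum
    temperature: float,
) -> torch.Tensor:
    # Position t predicts token t+1, so drop the last logit and shift the ids by one.
    shifted = logits[:, :-1] / temperature
    all_logprob = F.log_softmax(shifted, dim=-1)
    # Log-probability assigned to the token that actually appears at each next position.
    logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    # Zero out prompt/padding positions so only response tokens contribute.
    logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
    return logprob.sum(-1)  # one scalar log-probability per sequence

# Toy usage with random tensors (shapes are illustrative only).
if __name__ == '__main__':
    batch, seq_len, vocab = 2, 8, 32
    out = sequence_logprob(
        logits=torch.randn(batch, seq_len, vocab),
        input_ids=torch.randint(vocab, (batch, seq_len)),
        loss_mask=torch.ones(batch, seq_len),
        temperature=0.7,
    )
    print(out.shape)  # torch.Size([2])

Returning this summed per-sequence value from the reference actor keeps the object shipped back through Ray at shape [batch] instead of a full [batch, seq_len, vocab] logits tensor.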