diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py
index e3580a9..d6570ad 100755
--- a/turbo_alignment/generators/chat.py
+++ b/turbo_alignment/generators/chat.py
@@ -68,15 +68,26 @@ def generate_from_batch_records(
         dataset_name: str,
         records: list[dict[str, Any]],
         original_records: bool = False,
+        time_profiler=None,
     ) -> list[ChatInferenceOutput]:
-        import os
+        import time
+        import logging
 
         #TODO Make sure that records are already splitted between ranks(Assuming micro_rollout_batch_size equal to micro_batch_size)
         input_ids = records['input_ids'].tolist()
 
         rank = torch.distributed.get_rank()
         llm = self.vllm_engines[rank % len(self.vllm_engines)]
-        request_outputs = ray.get(llm.generate.remote(sampling_params=self._sampling_params, prompt_token_ids=input_ids, lora_request=self._lora_request))
+        if torch.distributed.get_rank() == 0:
+            start = time.time()
+
+        request_outputs = ray.get(llm.generate.remote(sampling_params=self._sampling_params, prompt_token_ids=input_ids, lora_request=self._lora_request))
+
+        if torch.distributed.get_rank() == 0:
+            end = time.time()
+            logging.info(f'Generation elapsed time:{end - start}')
+            time_profiler.completions_time.append(end - start)
+
         outputs = []
         for i, request_output in enumerate(request_outputs):
             answers = []
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index fcd2a31..5f62940 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -44,6 +44,40 @@ import os
 from deepspeed.runtime.engine import DeepSpeedEngine
 import gc
+import time
+import logging
+
+class TimeProfiler:
+    def __init__(self):
+        self.broadcast_time = []
+        self.completions_time = []
+        self.reward_model_time = []
+        self.reference_model_time = []
+        self.policy_model_time = []
+        self.total_time = []
+
+    def print(self):
+        import numpy as np
+
+        additional_time = []
+
+        for i in range(len(self.total_time)):
+            # total_time is itself in __dict__, so this equals total minus the sum of the per-stage times
+            other_time = 2 * self.total_time[i] - sum(getattr(self, key)[i] for key in self.__dict__)
+            additional_time.append(other_time)
+
+        def mean_std_format(values):
+            mean = np.mean(values)
+            std = np.std(values)
+            return f"{mean:.2f} (±{std:.2f})"
+
+        print(f"Broadcast Time: {mean_std_format(self.broadcast_time)}", flush=True)
+        print(f"Completions Time: {mean_std_format(self.completions_time)}", flush=True)
+        print(f"Reward Model Time: {mean_std_format(self.reward_model_time)}", flush=True)
+        print(f"Reference Model Time: {mean_std_format(self.reference_model_time)}", flush=True)
+        print(f"Policy Model Time: {mean_std_format(self.policy_model_time)}", flush=True)
+        print(f"Total Time: {mean_std_format(self.total_time)}", flush=True)
+        print(f"Additional Time: {mean_std_format(additional_time)}", flush=True)
+
 
 
 # FIXME
 @dataclass
@@ -85,10 +119,6 @@ def __init__(
         callbacks: list[TrainerCallback] | None = None,
         **kwargs,
     ) -> None:
-        import time
-        import logging
-
-        start = time.time()
         super().__init__(
             model=policy,
@@ -99,11 +129,10 @@
             callbacks=callbacks,
             **kwargs,
         )
-        logging.info(f'super().__init__ elapsed time:{time.time() - start}')
-        print(f'{self.accelerator.num_processes=}', flush=True)
+        self.time_profiler = TimeProfiler()
 
         self.vllm_engines = vllm_engines
-        start = time.time()
+
         if self.vllm_engines is not None and torch.distributed.get_rank() == 0:
             master_address = ray._private.services.get_node_ip_address()
             with socket.socket() as sock:
@@ -149,7 +178,6 @@
             ray.get(refs)
 
         torch.distributed.barrier()
-        logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}')
 
         self.ref_model = ref_model
 
@@ -200,13 +228,10 @@
         print("Generations Params:\n" + "\n".join([f"{attr}: {getattr(self.generator_transformers_settings, attr, None)}" for attr, _ in self.generator_transformers_settings.__annotations__.items()]))
 
         start = time.time()
-        self.print_readable_stats()
         self.norm_reward_mean, self.norm_reward_std = self.reward_stats(
             model=self.model, dataloader=self.get_train_dataloader()
         )
-
-        logging.info(f'statictis in __init__ elapsed time:{time.time() - start}')
-        self.print_readable_stats()
+        logging.info(f'Statistics calculation elapsed time:{time.time() - start}')
 
     def _broadcast_to_vllm(self, model: DeepSpeedEngine):
         # avoid OOM
@@ -278,8 +303,6 @@ def get_answers_and_rewards(
         inputs: dict[str, torch.Tensor],
         do_broadcast=True
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        import time
-        import logging
 
         if torch.distributed.get_rank() == 0:
             print(f'Input shape: {inputs["input_ids"].shape}', flush=True)
@@ -288,18 +311,23 @@
         if do_broadcast:
             # TODO: move to generator
             torch.distributed.barrier()
-            start = time.time()
+            if torch.distributed.get_rank() == 0:
+                start = time.time()
+
             self._broadcast_to_vllm(model)
             torch.distributed.barrier()
-            logging.info(f'broadcast elapsed time:{time.time() - start}')
 
-        start = time.time()
+            if torch.distributed.get_rank() == 0:
+                end = time.time()
+                self.time_profiler.broadcast_time.append(end - start)
+                logging.info(f'broadcast elapsed time:{end - start}')
+
+
 
         generator = self._get_chat_generator()
         generations: list[ChatInferenceOutput] = generator.generate_from_batch_records(
-            dataset_name='online', records=inputs, original_records=True
+            dataset_name='online', records=inputs, original_records=True, time_profiler=self.time_profiler
         )
-        logging.info(f'generations elapsed time:{time.time() - start}')
 
         response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
         response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]
@@ -335,11 +363,17 @@ def pad_sequences(sequences, max_length, pad_value=0):
             'position_ids': position_ids,
         }
 
-        start = time.time()
+        if torch.distributed.get_rank() == 0:
+            start = time.time()
+
         critic_generator = self._get_rm_generator(self.reward_model)
         rewards = critic_generator.generate_from_batch_records(rm_inputs)
-        logging.info(f'rewards elapsed time:{time.time() - start}')
 
+        if torch.distributed.get_rank() == 0:
+            end = time.time()
+            self.time_profiler.reward_model_time.append(end - start)
+            logging.info(f'Rewards elapsed time:{end - start}')
+
         return response_ids, response_attention_mask, response_tokens_mask, position_ids, rewards
 
     def get_batch_loss_metrics(
@@ -348,8 +382,6 @@
         inputs: dict[str, torch.Tensor],
         train_eval: Literal['train', 'eval'] = 'train',
     ):
-        import logging
-        import time
 
         with torch.no_grad():
             (
@@ -363,7 +395,10 @@
                 inputs=inputs,
                 do_broadcast=True if train_eval == 'train' else False
             )
-        start = time.time()
+
+        if torch.distributed.get_rank() == 0:
+            start = time.time()
+
         ref_logprobs = self.get_logprobs(
             model=self.ref_model,
             input_ids=query_response,
@@ -371,16 +406,22 @@
             position_ids=position_ids,
             loss_mask=response_tokens_mask,
         )
-
-        logging.info(f'reference elapsed time:{time.time() - start}')
+
+        if torch.distributed.get_rank() == 0:
+            end = time.time()
+            self.time_profiler.reference_model_time.append(end - start)
+            logging.info(f'Reference logprobs elapsed time:{end - start}')
+
 
         rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)
 
         del inputs
         gc.collect()
         torch.cuda.empty_cache()
+
+        if torch.distributed.get_rank() == 0:
+            start = time.time()
 
-        start = time.time()
         logprobs = self.get_logprobs(
             model=model,
             input_ids=query_response,
@@ -388,12 +429,10 @@
             position_ids=position_ids,
             loss_mask=response_tokens_mask,
         )
-        logging.info(f'policy logrobs elapsed time:{time.time() - start}')
-
-        # print(f'rank {torch.distributed.get_rank()}:{ref_logprobs.shape=}, {logprobs.shape=}', flush=True)
-        # print(f'rank {torch.distributed.get_rank()}:{ref_logprobs=}, {logprobs=}', flush=True)
-        # print(f'rank {torch.distributed.get_rank()}:{(ref_logprobs[0, -1, :] - logprobs[0, -1, :]).abs().sum()=}', flush=True)
-        # assert torch.allclose(ref_logprobs, logprobs)
+        if torch.distributed.get_rank() == 0:
+            end = time.time()
+            self.time_profiler.policy_model_time.append(end - start)
+            logging.info(f'Policy logprobs elapsed time:{end - start}')
 
         with torch.no_grad():
             kl_term = logprobs.detach() - ref_logprobs
@@ -449,16 +488,20 @@ def print_readable_stats(self):
         print(f"Reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")
 
     def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None):
-        import logging
-        import time
-        import gc
-
-        start = time.time()
+
+        if torch.distributed.get_rank() == 0:
+            start = time.time()
 
         loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')
         self.store_metrics(metrics=metrics, train_eval='train')
 
-        logging.info(f'get_batch_loss_metrics elapsed time:{time.time() - start}')
+
+        if torch.distributed.get_rank() == 0:
+            end = time.time()
+            self.time_profiler.total_time.append(end - start)
+            logging.info(f'Total elapsed time:{end - start}')
+            self.time_profiler.print()
+
 
         gc.collect()
         torch.cuda.empty_cache()
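
Not part of the patch: a minimal standalone sketch, with made-up timings, of how the TimeProfiler introduced above is meant to be driven from rank 0 and how its "Additional Time" figure follows from the 2 * total - sum(...) expression (total_time is itself one of the instance's lists, so the doubled term cancels it out of the sum).

import numpy as np


class TimeProfiler:
    # Trimmed copy of the profiler added in reinforce.py: one list per timed stage,
    # appended to on rank 0 for every training step.
    def __init__(self):
        self.broadcast_time = []
        self.completions_time = []
        self.reward_model_time = []
        self.reference_model_time = []
        self.policy_model_time = []
        self.total_time = []

    def print(self):
        additional_time = []
        for i in range(len(self.total_time)):
            # __dict__ includes total_time itself, so 2 * total - sum(all lists)
            # reduces to total - (broadcast + completions + reward + reference + policy).
            additional_time.append(
                2 * self.total_time[i] - sum(getattr(self, key)[i] for key in self.__dict__)
            )

        def mean_std(values):
            return f"{np.mean(values):.2f} (±{np.std(values):.2f})"

        for name, values in [*self.__dict__.items(), ('additional_time', additional_time)]:
            print(f"{name}: {mean_std(values)}", flush=True)


# Made-up numbers for a single training step that takes 10 s in total.
profiler = TimeProfiler()
profiler.broadcast_time.append(1.0)
profiler.completions_time.append(4.0)
profiler.reward_model_time.append(1.5)
profiler.reference_model_time.append(1.0)
profiler.policy_model_time.append(2.0)
profiler.total_time.append(10.0)
profiler.print()  # additional_time = 10 - (1 + 4 + 1.5 + 1 + 2) = 0.5 s of untimed work

The per-index subtraction assumes every list has one entry per total_time entry; if evaluation passes also append to the completions or reward lists without a matching total_time entry, the Additional Time line will drift.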