Commit

time profiler
thehir0 committed Dec 5, 2024
1 parent 18016c3 commit b94156e
Showing 2 changed files with 95 additions and 41 deletions.
15 changes: 13 additions & 2 deletions turbo_alignment/generators/chat.py
@@ -68,15 +68,26 @@ def generate_from_batch_records(
dataset_name: str,
records: list[dict[str, Any]],
original_records: bool = False,
        time_profiler=None,
) -> list[ChatInferenceOutput]:
import os
import time
import logging
        # TODO: make sure that records are already split between ranks (assuming micro_rollout_batch_size equals micro_batch_size)
input_ids = records['input_ids'].tolist()

rank = torch.distributed.get_rank()
llm = self.vllm_engines[rank % len(self.vllm_engines)]
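        # Rank 0 wall-clocks the vLLM generation round trip and records it in the supplied time_profiler.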
if torch.distributed.get_rank() == 0:
start = time.time()

request_outputs = ray.get(llm.generate.remote(sampling_params=self._sampling_params, prompt_token_ids=input_ids, lora_request=self._lora_request))

if torch.distributed.get_rank() == 0:
end = time.time()
            logging.info(f'Generation elapsed time: {end - start}')
            if time_profiler is not None:
                time_profiler.completions_time.append(end - start)

outputs = []
for i, request_output in enumerate(request_outputs):
answers = []
121 changes: 82 additions & 39 deletions turbo_alignment/trainers/online/reinforce.py
@@ -44,6 +44,40 @@
import os
from deepspeed.runtime.engine import DeepSpeedEngine
import gc
import time
import logging

class TimeProfiler:
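    """Accumulate per-stage wall-clock timings (in seconds) across training steps; print() reports mean (±std) per stage."""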
def __init__(self):
self.broadcast_time = []
self.completions_time = []
self.reward_model_time = []
self.reference_model_time = []
self.policy_model_time = []
self.total_time = []

def print(self):
import numpy as np

additional_time = []

        # Time not attributed to any tracked stage: total minus the sum of the per-stage timings.
        for step in range(len(self.total_time)):
            tracked = sum(getattr(self, key)[step] for key in self.__dict__ if key != 'total_time')
            additional_time.append(self.total_time[step] - tracked)

def mean_std_format(values):
mean = np.mean(values)
std = np.std(values)
            return f"{mean:.2f} (±{std:.2f})"

print(f"Broadcast Time: {mean_std_format(self.broadcast_time)}", flush=True)
print(f"Completions Time: {mean_std_format(self.completions_time)}", flush=True)
print(f"Reward Model Time: {mean_std_format(self.reward_model_time)}", flush=True)
print(f"Reference Model Time: {mean_std_format(self.reference_model_time)}", flush=True)
print(f"Policy Model Time: {mean_std_format(self.policy_model_time)}", flush=True)
print(f"Total Time: {mean_std_format(self.total_time)}", flush=True)
print(f"Additional Time: {mean_std_format(additional_time)}", flush=True)


# FIXME
@dataclass
@@ -85,10 +119,6 @@ def __init__(
callbacks: list[TrainerCallback] | None = None,
**kwargs,
) -> None:
import time
import logging

start = time.time()

super().__init__(
model=policy,
@@ -99,11 +129,10 @@
callbacks=callbacks,
**kwargs,
)
logging.info(f'super().__init__ elapsed time:{time.time() - start}')
print(f'{self.accelerator.num_processes=}', flush=True)
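        # Shared profiler: the generation, broadcast, reward, reference and policy passes each append their rank-0 wall-clock time.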

self.time_profiler = TimeProfiler()
self.vllm_engines = vllm_engines
start = time.time()

if self.vllm_engines is not None and torch.distributed.get_rank() == 0:
master_address = ray._private.services.get_node_ip_address()
with socket.socket() as sock:
@@ -149,7 +178,6 @@ def __init__(
ray.get(refs)

torch.distributed.barrier()
logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}')

self.ref_model = ref_model

@@ -200,13 +228,10 @@ def __init__(
print("Generations Params:\n" + "\n".join([f"{attr}: {getattr(self.generator_transformers_settings, attr, None)}" for attr, _ in self.generator_transformers_settings.__annotations__.items()]))

start = time.time()
self.norm_reward_mean, self.norm_reward_std = self.reward_stats(
model=self.model, dataloader=self.get_train_dataloader()
)

        self.print_readable_stats()
        logging.info(f'Statistics calculation elapsed time: {time.time() - start}')

def _broadcast_to_vllm(self, model: DeepSpeedEngine):
# avoid OOM
@@ -278,8 +303,6 @@ def get_answers_and_rewards(
inputs: dict[str, torch.Tensor],
do_broadcast=True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
import time
import logging

if torch.distributed.get_rank() == 0:
print(f'Input shape: {inputs["input_ids"].shape}', flush=True)
@@ -288,18 +311,23 @@
if do_broadcast:
# TODO: move to generator
torch.distributed.barrier()
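            # Rank 0 times the policy-weight broadcast to the vLLM engines.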
            if torch.distributed.get_rank() == 0:
                start = time.time()

self._broadcast_to_vllm(model)
torch.distributed.barrier()

            if torch.distributed.get_rank() == 0:
                end = time.time()
                self.time_profiler.broadcast_time.append(end - start)
                logging.info(f'Broadcast elapsed time: {end - start}')


generator = self._get_chat_generator()

generations: list[ChatInferenceOutput] = generator.generate_from_batch_records(
            dataset_name='online', records=inputs, original_records=True, time_profiler=self.time_profiler
        )

response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]
@@ -335,11 +363,17 @@ def pad_sequences(sequences, max_length, pad_value=0):
'position_ids': position_ids,
}

        # Rank 0 times reward-model scoring of the generated responses.
        if torch.distributed.get_rank() == 0:
            start = time.time()

critic_generator = self._get_rm_generator(self.reward_model)
rewards = critic_generator.generate_from_batch_records(rm_inputs)

        if torch.distributed.get_rank() == 0:
            end = time.time()
            self.time_profiler.reward_model_time.append(end - start)
            logging.info(f'Rewards elapsed time: {end - start}')

return response_ids, response_attention_mask, response_tokens_mask, position_ids, rewards

def get_batch_loss_metrics(
@@ -348,8 +382,6 @@ def get_batch_loss_metrics(
inputs: dict[str, torch.Tensor],
train_eval: Literal['train', 'eval'] = 'train',
):
import logging
import time

with torch.no_grad():
(
@@ -363,37 +395,44 @@
inputs=inputs,
                do_broadcast=(train_eval == 'train'),
)
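        # Rank 0 times the reference-model logprob pass.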
        if torch.distributed.get_rank() == 0:
            start = time.time()

ref_logprobs = self.get_logprobs(
model=self.ref_model,
input_ids=query_response,
attention_mask=attention_mask,
position_ids=position_ids,
loss_mask=response_tokens_mask,
)

        if torch.distributed.get_rank() == 0:
            end = time.time()
            self.time_profiler.reference_model_time.append(end - start)
            logging.info(f'Reference logprobs elapsed time: {end - start}')

rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)


del inputs
gc.collect()
torch.cuda.empty_cache()

        if torch.distributed.get_rank() == 0:
            start = time.time()

logprobs = self.get_logprobs(
model=model,
input_ids=query_response,
attention_mask=attention_mask,
position_ids=position_ids,
loss_mask=response_tokens_mask,
)

# print(f'rank {torch.distributed.get_rank()}:{ref_logprobs.shape=}, {logprobs.shape=}', flush=True)
# print(f'rank {torch.distributed.get_rank()}:{ref_logprobs=}, {logprobs=}', flush=True)
# print(f'rank {torch.distributed.get_rank()}:{(ref_logprobs[0, -1, :] - logprobs[0, -1, :]).abs().sum()=}', flush=True)
# assert torch.allclose(ref_logprobs, logprobs)
if torch.distributed.get_rank() == 0:
end = time.time()
self.time_profiler.policy_model_time.append(end - start)
            logging.info(f'Policy logprobs elapsed time: {end - start}')

with torch.no_grad():
kl_term = logprobs.detach() - ref_logprobs
@@ -449,16 +488,20 @@ def print_readable_stats(self):
print(f"Reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None):
import logging
import time
import gc

        # Rank 0 times the full training step; TimeProfiler.print() reports per-stage means and the unattributed remainder.
        if torch.distributed.get_rank() == 0:
            start = time.time()

loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')

self.store_metrics(metrics=metrics, train_eval='train')

        if torch.distributed.get_rank() == 0:
            end = time.time()
            self.time_profiler.total_time.append(end - start)
            logging.info(f'Total elapsed time: {end - start}')
            self.time_profiler.print()

gc.collect()
torch.cuda.empty_cache()
