Skip to content

Commit

Permalink
Merge pull request #58 from instructlab/ap/async_logger
Browse files Browse the repository at this point in the history
Ap/async logger
  • Loading branch information
aldopareja authored Jun 21, 2024
2 parents 08647d7 + bf61b26 commit 9e7a50a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 23 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ torch>=2.3.0a0
transformers>=4.41.2
datasets>=2.15.0
numba
numpy
numpy==1.26.4
rich
dolomite-engine @ git+https://github.com/ibm-granite/dolomite-engine.git@main
trl>=0.9.4
Expand All @@ -16,4 +16,4 @@ pydantic>=2.7.0

# deepspeed needs to be at the bottom or it'll break during installation
deepspeed>=0.14.3

aiofiles>=23.2.1
54 changes: 54 additions & 0 deletions src/instructlab/training/async_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
import asyncio
from datetime import datetime
import aiofiles
import threading
import os

Check warning on line 6 in src/instructlab/training/async_logger.py

View workflow job for this annotation

GitHub Actions / lint

W0611: Unused import os (unused-import)


class AsyncStructuredLogger:
def __init__(self, file_name="training_log.jsonl"):
self.file_name = file_name
self.logs = []
self.loop = asyncio.new_event_loop()
t = threading.Thread(
target=self._run_event_loop, args=(self.loop,), daemon=True
)
t.start()
asyncio.run_coroutine_threadsafe(self._initialize_log_file(), self.loop)

def _run_event_loop(self, loop):
asyncio.set_event_loop(loop) #
loop.run_forever()

async def _initialize_log_file(self):
self.logs = []
try:
async with aiofiles.open(self.file_name, "r") as f:
async for line in f:
if line.strip(): # Avoid empty lines
self.logs.append(json.loads(line.strip()))
except FileNotFoundError:
# File does not exist but the first log will create it.
pass

async def log(self, data):
"""logs a dictionary as a new line in a jsonl file with a timestamp"""
if not isinstance(data, dict):
raise ValueError("Logged data must be a dictionary")
data["timestamp"] = datetime.now().isoformat()
self.logs.append(data)
await self._write_logs_to_file(data)
{{print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")}}

Check failure on line 42 in src/instructlab/training/async_logger.py

View workflow job for this annotation

GitHub Actions / lint

E1143: '{print(f'\x1b[92m{json.dumps(data, indent=4)}\x1b[0m')}' is unhashable and can't be used as a member in a set (unhashable-member)

async def _write_logs_to_file(self, data):
"""appends to the log instead of writing the whole log each time"""
async with aiofiles.open(self.file_name, "a") as f:
await f.write(json.dumps(data, indent=None) + "\n")

def log_sync(self, data: dict):
"""runs the log coroutine non-blocking"""
asyncio.run_coroutine_threadsafe(self.log(data), self.loop)

def __repr__(self):
return f"<AsyncStructuredLogger(file_name={self.file_name})>"
65 changes: 44 additions & 21 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.async_logger import AsyncStructuredLogger
from instructlab.training.utils import (
StreamablePopen,
add_noisy_embeddings,
Expand Down Expand Up @@ -316,7 +317,7 @@ def maybe_resume_training(args, model):
return model


def train(args, model, tokenizer, train_loader, grad_accum):
def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
model.train()

global_step = 1
Expand Down Expand Up @@ -394,16 +395,32 @@ def train(args, model, tokenizer, train_loader, grad_accum):
current_lr = model.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
global_grad_norm = model.get_global_grad_norm()
global_grad_norm = (
float(global_grad_norm) if global_grad_norm is not None else None
)
weight_norm = float(
model.optimizer.single_partition_of_fp32_groups[0].norm()
)

print(
f"throughput: {overall_throughput} "
f"samples/s, lr: {current_lr}, "
f"loss: {loss.item()} "
f"cuda_mem_allocated: {cuda_mem_allocated} GB "
f"cuda_malloc_retries: {cuda_malloc_retries} "
f"num_loss_counted_tokens: {num_loss_counted_tokens} "
f"batch_size: {aggregated_values[1]} "
f"total loss: {aggregated_values[2]/num_loss_counted_tokens}"
metric_logger.log_sync(
{
"epoch": epoch,
"step": global_step,
"rank": torch.distributed.get_rank(),
"loss": loss.item(),
"overall_throughput": overall_throughput,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
"num_loss_counted_tokens": int(num_loss_counted_tokens),
"batch_size": int(aggregated_values[1]),
"total_loss": float(
aggregated_values[2] / num_loss_counted_tokens
),
"gradnorm": global_grad_norm,
"weight_norm": weight_norm,
}
)

if global_step * batch_size % args.save_samples == 0:
Expand Down Expand Up @@ -435,8 +452,12 @@ def main(args):
# Third Party
import yaml

metric_logger = AsyncStructuredLogger(
args.output_dir + "/training_params_and_metrics.jsonl"
)
if os.environ["LOCAL_RANK"] == "0":
print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")
metric_logger.log_sync({"script_params": vars(args)})

setup_logger(args.log_level)
CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
Expand Down Expand Up @@ -483,22 +504,24 @@ def main(args):
)

if args.local_rank == 0:
print(
f"\033[96mnum_gpus: {torch.distributed.get_world_size()}\n"
f"avg_sample_len: {dataset.get_lengths().mean()}\n"
f"effective_batch_size: {args.effective_batch_size}\n"
f"max_batch_len_per_gpu: {args.max_batch_len}\n"
f"packing_max_batch_len: {packing_max_batch_len}\n"
f"grad_accum: {grad_accum}\n"
f"num batches: {len(train_loader)}\n"
f"avg_samples_per_batch: {len(dataset)/len(train_loader)}\n"
f"samples_per_gpu: {args.samples_per_gpu}\033[0m"
metric_logger.log_sync(
{
"num_gpus": torch.distributed.get_world_size(),
"avg_sample_len": dataset.get_lengths().mean(),
"effective_batch_size": args.effective_batch_size,
"max_batch_len_per_gpu": args.max_batch_len,
"packing_max_batch_len": packing_max_batch_len,
"grad_accum": grad_accum,
"num_batches": len(train_loader),
"avg_samples_per_batch": len(dataset) / len(train_loader),
"samples_per_gpu": args.samples_per_gpu,
}
)

model = setup_model(args, tokenizer, train_loader, grad_accum)
model = maybe_resume_training(args, model)

train(args, model, tokenizer, train_loader, grad_accum)
train(args, model, tokenizer, train_loader, grad_accum, metric_logger)

torch.distributed.barrier()
torch.distributed.destroy_process_group()
Expand Down

0 comments on commit 9e7a50a

Please sign in to comment.