From 220413b69b29a7166da29fcb1133105d988f4cbd Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 29 Feb 2024 20:33:31 +0000
Subject: [PATCH] bf16

---
 amt/data.py  |  2 +-
 amt/train.py | 27 +++++++++++----------------
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/amt/data.py b/amt/data.py
index cfbfc03..ebdb358 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -10,7 +10,7 @@
 from aria.data.midi import MidiDict
 from amt.tokenizer import AmtTokenizer
 from amt.config import load_config
-from amt.audio import pad_or_trim
+from amt.audio import pad_or_trim, AudioTransform
 
 
 def get_wav_mid_segments(
diff --git a/amt/train.py b/amt/train.py
index c79a6a2..9b714e4 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -278,13 +278,13 @@ def _bench():
         f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
         "parameters"
     )
-    logger.info("Profiling FLOP")
+    logger.info("Compiling model...")
     _bench()
-    with flop_counter:
-        _bench()
-    total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
-    logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")
+    # with flop_counter:
+    #     _bench()
+    # total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
+    # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")
 
 
 def make_checkpoint(_accelerator, _epoch: int, _step: int):
     checkpoint_dir = os.path.join(
@@ -339,8 +339,7 @@ def train_loop(
 
         wav, src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
         with torch.no_grad():
-            mel = audio_transform.mel(wav)
-            # mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
+            mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
         logits = model(mel, src)  # (b_sz, s_len, v_sz)
         logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
         loss = loss_fn(logits, tgt)
@@ -408,7 +407,7 @@ def val_loop(dataloader, _epoch: int):
     ):
         wav, src, tgt = batch
         with torch.no_grad():
-            mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
+            mel = audio_transform.mel(wav)
         logits = model(mel, src)
         logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
         loss = loss_fn(logits, tgt)
@@ -549,7 +548,7 @@ def resume_train(
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = AmtEncoderDecoder(model_config)
-    audio_transform = AudioTransform()
+    audio_transform = AudioTransform().to(accelerator.device)
     logger.info(f"Loaded model with config: {load_model_config(model_name)}")
 
     train_dataloader, val_dataloader = get_dataloaders(
@@ -580,14 +579,12 @@ def resume_train(
 
     (
         model,
-        audio_transform,
         train_dataloader,
         val_dataloader,
         optimizer,
         scheduler,
     ) = accelerator.prepare(
         model,
-        audio_transform,
         train_dataloader,
         val_dataloader,
         optimizer,
@@ -667,9 +664,9 @@ def train(
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = AmtEncoderDecoder(model_config)
-    # logger.info("Compiling model...")
-    # model = torch.compile(model)
-    audio_transform = AudioTransform()
+    audio_transform = AudioTransform().to(accelerator.device)
+    model = torch.compile(model)
+    audio_transform.compile()
     logger.info(f"Loaded model with config: {load_model_config(model_name)}")
     if mode == "finetune":
         try:
@@ -708,14 +705,12 @@ def train(
 
     (
         model,
-        audio_transform,
         train_dataloader,
         val_dataloader,
         optimizer,
         scheduler,
     ) = accelerator.prepare(
         model,
-        audio_transform,
         train_dataloader,
         val_dataloader,
         optimizer,
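
The net effect of the train.py changes: AudioTransform is no longer passed through accelerator.prepare(); it is moved to accelerator.device explicitly, compiled alongside the model, and used under torch.no_grad() to produce the mel input during the train loop. Below is a minimal sketch of that wiring, assuming bf16 is enabled via accelerate's mixed_precision flag (the patch does not show where the Accelerator is constructed); the amt.model import path, the optimizer choice, and the loop structure are illustrative, not code from this repository.

import torch
import accelerate

from amt.audio import AudioTransform
from amt.model import AmtEncoderDecoder  # assumed import path


def train_sketch(model_config, train_dataloader):
    # Assumption: bf16 autocast comes from accelerate's mixed_precision flag;
    # the patch itself only shows the device/compile changes.
    accelerator = accelerate.Accelerator(mixed_precision="bf16")

    model = AmtEncoderDecoder(model_config)
    # As in the patch: the transform is moved to the accelerator device and
    # compiled, but never passed through accelerator.prepare().
    audio_transform = AudioTransform().to(accelerator.device)
    model = torch.compile(model)
    audio_transform.compile()

    optimizer = torch.optim.AdamW(model.parameters())  # illustrative optimizer
    loss_fn = torch.nn.CrossEntropyLoss()

    model, train_dataloader, optimizer = accelerator.prepare(
        model, train_dataloader, optimizer
    )

    for wav, src, tgt in train_dataloader:
        with torch.no_grad():
            # On-device augmentation + mel spectrogram, outside the autograd graph.
            mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
        logits = model(mel, src)  # (b_sz, s_len, v_sz)
        loss = loss_fn(logits.transpose(1, 2), tgt)  # CrossEntropyLoss wants (b, v, s)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

Keeping the transform out of prepare() means it is never wrapped for distributed training; since its forward pass only runs under torch.no_grad(), it only needs to live on the right device, and compiling it separately lets torch.compile handle the on-device augmentation/mel computation independently of the model graph.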