diff --git a/amt/audio.py b/amt/audio.py
index 8ce6f4c..28913ab 100644
--- a/amt/audio.py
+++ b/amt/audio.py
@@ -69,7 +69,7 @@ def __init__(
         reduce_ratio: float = 0.01,
         detune_ratio: float = 0.1,
         detune_max_shift: float = 0.15,
-        spec_aug_ratio: float = 0.95,
+        spec_aug_ratio: float = 0.9,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -135,12 +135,13 @@ def __init__(
             n_stft=self.config["n_fft"] // 2 + 1,
         )
         self.spec_aug = torch.nn.Sequential(
+            torchaudio.transforms.TimeMasking(
+                time_mask_param=self.time_mask_param,
+                iid_masks=True,
+            ),
             torchaudio.transforms.FrequencyMasking(
                 freq_mask_param=self.freq_mask_param, iid_masks=True
             ),
-            torchaudio.transforms.TimeMasking(
-                time_mask_param=self.time_mask_param, iid_masks=True
-            ),
         )
 
     def get_params(self):
diff --git a/amt/data.py b/amt/data.py
index cde420a..4bf3ccb 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -88,9 +88,9 @@ def get_wav_mid_segments(
                 max_pedal_len_ms=15000,
             )
 
-            # Hardcoded to 2.5s
-            if _check_onset_threshold(mid_feature, 2500) is False:
-                print("No note messages after 2.5s - skipping")
+            # Hardcoded to 5s
+            if _check_onset_threshold(mid_feature, 5000) is False:
+                print("No note messages after 5s - skipping")
                 continue
 
         else:
@@ -149,7 +149,7 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
     safe_mid_path = shlex.quote(mid_path)
     safe_wav_path = shlex.quote(wav_path)
 
-    executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE"
+    executable_path = "/mnt/ssd-1/aria/pianoteq/x86-64bit/Pianoteq 8 STAGE"
     command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}'
 
     return command
diff --git a/amt/model.py b/amt/model.py
index 1b60a46..89b4da0 100644
--- a/amt/model.py
+++ b/amt/model.py
@@ -98,7 +98,7 @@ def forward(
         wv = wv.transpose(1, 2)
         wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head)
 
-        return self.out(wv), None
+        return self.out(wv)
 
 
 class ResidualAttentionBlock(nn.Module):
@@ -129,9 +129,9 @@ def forward(
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
-        x = x + self.attn(self.attn_ln(x), mask=mask)[0]
+        x = x + self.attn(self.attn_ln(x), mask=mask)
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa)
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
@@ -188,6 +188,7 @@ def __init__(
             ]
         )
         self.ln = nn.LayerNorm(n_state)
+        self.output = nn.Linear(n_state, n_vocab, bias=False)
 
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
@@ -206,9 +207,7 @@ def forward(self, x: Tensor, xa: Tensor):
             x = block(x, xa, mask=self.mask)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = self.output(x)
 
         return logits
 
@@ -245,6 +244,3 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
     @property
     def device(self):
         return next(self.parameters()).device
-
-    def get_empty_cache(self):
-        return {}
diff --git a/amt/train.py b/amt/train.py
index ebad059..a5d0736 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -2,6 +2,7 @@
 import sys
 import csv
 import random
+import traceback
 import functools
 import argparse
 import logging
@@ -24,7 +25,7 @@
 from amt.config import load_model_config
 from aria.utils import _load_weight
 
-GRADIENT_ACC_STEPS = 32
+GRADIENT_ACC_STEPS = 2
 
 # ----- USAGE -----
 #
@@ -143,7 +144,7 @@ def _get_optim(
         model.parameters(),
         lr=lr,
         weight_decay=0.1,
-        betas=(0.9, 0.98),
+        betas=(0.9, 0.95),
         eps=1e-6,
     )
 
@@ -344,6 +345,7 @@ def train_loop(
         lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
 
         model.train()
+        grad_norm = 0.0
         for __step, batch in (
             pbar := tqdm(
                 enumerate(dataloader),
@@ -378,8 +380,6 @@ def train_loop(
                     grad_norm = accelerator.clip_grad_norm_(
                         model.parameters(), 1.0
                     ).item()
-                else:
-                    grad_norm = 0
 
                 optimizer.step()
                 optimizer.zero_grad()
@@ -398,7 +398,8 @@ def train_loop(
                 pbar.set_postfix_str(
                     f"lr={lr_for_print}, "
                     f"loss={round(loss_buffer[-1], 4)}, "
-                    f"trailing={round(trailing_loss, 4)}"
+                    f"trailing={round(trailing_loss, 4)}, "
+                    f"grad_norm={round(grad_norm, 4)}"
                 )
 
                 if scheduler:
@@ -470,6 +471,7 @@ def val_loop(dataloader, _epoch: int, aug: bool):
     PAD_ID = train_dataloader.dataset.tokenizer.pad_id
     logger = get_logger(__name__)  # Accelerate logger
     loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
+
     logger.info(
         f"Model has "
         f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
@@ -522,19 +524,27 @@ def val_loop(dataloader, _epoch: int, aug: bool):
     )
 
     for epoch in range(start_epoch, epochs + start_epoch):
-        avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
-        avg_val_loss = val_loop(
-            dataloader=val_dataloader, _epoch=epoch, aug=False
-        )
-        avg_val_loss_aug = val_loop(
-            dataloader=val_dataloader, _epoch=epoch, aug=True
-        )
-        if accelerator.is_main_process:
-            epoch_writer.writerow(
-                [epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
+        try:
+            avg_train_loss = train_loop(
+                dataloader=train_dataloader, _epoch=epoch
             )
-            epoch_csv.flush()
-            make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
+            avg_val_loss = val_loop(
+                dataloader=val_dataloader, _epoch=epoch, aug=False
+            )
+            avg_val_loss_aug = val_loop(
+                dataloader=val_dataloader, _epoch=epoch, aug=True
+            )
+            if accelerator.is_main_process:
+                epoch_writer.writerow(
+                    [epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
+                )
+                epoch_csv.flush()
+                make_checkpoint(
+                    _accelerator=accelerator, _epoch=epoch + 1, _step=0
+                )
+        except Exception as e:
+            logger.debug(traceback.format_exc())
+            raise e
 
     logging.shutdown()
     if accelerator.is_main_process: