diff --git a/amt/audio.py b/amt/audio.py
index 7038d17..28913ab 100644
--- a/amt/audio.py
+++ b/amt/audio.py
@@ -4,7 +4,9 @@
 import random
 import torch
 import torchaudio
+import torch.nn.functional as F
 import torchaudio.functional as AF
+import numpy as np
 
 from amt.config import load_config
 from amt.tokenizer import AmtTokenizer
@@ -22,6 +24,34 @@
 TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token
 
 
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if torch.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = F.pad(
+                array, [pad for sizes in pad_widths[::-1] for pad in sizes]
+            )
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
 # Refactor default params are stored in config.json
 class AudioTransform(torch.nn.Module):
     def __init__(
@@ -39,7 +69,7 @@ def __init__(
         reduce_ratio: float = 0.01,
         detune_ratio: float = 0.1,
         detune_max_shift: float = 0.15,
-        spec_aug_ratio: float = 0.95,
+        spec_aug_ratio: float = 0.9,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -105,12 +135,13 @@ def __init__(
             n_stft=self.config["n_fft"] // 2 + 1,
         )
         self.spec_aug = torch.nn.Sequential(
+            torchaudio.transforms.TimeMasking(
+                time_mask_param=self.time_mask_param,
+                iid_masks=True,
+            ),
             torchaudio.transforms.FrequencyMasking(
                 freq_mask_param=self.freq_mask_param, iid_masks=True
             ),
-            torchaudio.transforms.TimeMasking(
-                time_mask_param=self.time_mask_param, iid_masks=True
-            ),
         )
 
     def get_params(self):
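A quick sanity check of the new pad_or_trim helper above (a minimal sketch, not part of the patch; it only relies on pad_or_trim and N_SAMPLES as defined in amt/audio.py):

    import numpy as np
    import torch

    from amt.audio import N_SAMPLES, pad_or_trim

    # Inputs shorter than N_SAMPLES are zero-padded on the right of the last axis.
    short = torch.randn(2, N_SAMPLES // 2)
    assert pad_or_trim(short).shape == (2, N_SAMPLES)

    # Inputs longer than N_SAMPLES are truncated to the first N_SAMPLES samples.
    long_wav = torch.randn(2, N_SAMPLES * 2)
    assert pad_or_trim(long_wav).shape == (2, N_SAMPLES)

    # The numpy branch behaves the same way as the tensor branch.
    arr = np.zeros(N_SAMPLES // 2)
    assert pad_or_trim(arr).shape == (N_SAMPLES,)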
diff --git a/amt/data.py b/amt/data.py
index f4103ae..8182c76 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -88,9 +88,9 @@ def get_wav_mid_segments(
                 max_pedal_len_ms=15000,
             )
 
-            # Hardcoded to 2.5s
-            if _check_onset_threshold(mid_feature, 2500) is False:
-                print("No note messages after 2.5s - skipping")
+            # Hardcoded to 10s
+            if _check_onset_threshold(mid_feature, 10500) is False:
+                print("No note messages after 10s - skipping")
                 continue
 
         else:
@@ -106,15 +106,18 @@
 
 def pianoteq_cmd_fn(mid_path: str, wav_path: str):
     presets = [
-        "C. Bechstein",
-        "C. Bechstein Close Mic",
-        "C. Bechstein Under Lid",
-        "C. Bechstein 440",
-        "C. Bechstein Recording",
-        "C. Bechstein Werckmeister III",
-        "C. Bechstein Neidhardt III",
-        "C. Bechstein mesotonic",
-        "C. Bechstein well tempered",
+        "C. Bechstein DG Prelude",
+        "C. Bechstein DG Sweet",
+        "C. Bechstein DG Felt I",
+        "C. Bechstein DG Felt II",
+        "C. Bechstein DG D 282",
+        "C. Bechstein DG Recording 1",
+        "C. Bechstein DG Recording 2",
+        "C. Bechstein DG Recording 3",
+        "C. Bechstein DG Cinematic",
+        "C. Bechstein DG Snappy",
+        "C. Bechstein DG Venue",
+        "C. Bechstein DG Player",
         "HB Steinway D Blues",
         "HB Steinway D Pop",
         "HB Steinway D New Age",
@@ -137,8 +140,6 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
         "HB Steinway D Cabaret",
         "HB Steinway D Bright",
         "HB Steinway D Hyper Bright",
-        "HB Steinway D Prepared",
-        "HB Steinway D Honky Tonk",
     ]
 
     preset = random.choice(presets)
@@ -148,7 +149,7 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
     safe_mid_path = shlex.quote(mid_path)
     safe_wav_path = shlex.quote(wav_path)
 
-    executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE"
+    executable_path = "/mnt/ssd-1/aria/pianoteq/x86-64bit/Pianoteq 8 STAGE"
    command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}'
 
     return command
diff --git a/amt/model.py b/amt/model.py
index 1b60a46..e2fef9f 100644
--- a/amt/model.py
+++ b/amt/model.py
@@ -54,6 +54,12 @@ def __init__(self, n_state: int, n_head: int):
         self.key = nn.Linear(n_state, n_state, bias=False)
         self.value = nn.Linear(n_state, n_state, bias=False)
         self.out = nn.Linear(n_state, n_state, bias=False)
+
+        # self.x_norm = None
+        # self.q_norm = None
+        # self.k_norm = None
+        # self.v_norm = None
+        # self.out_norm = None
 
     def forward(
         self,
@@ -78,6 +84,11 @@ def forward(
         q = q.view(batch_size, target_seq_len, self.n_head, self.d_head)
         k = k.view(batch_size, source_seq_len, self.n_head, self.d_head)
         v = v.view(batch_size, source_seq_len, self.n_head, self.d_head)
+
+        # self.x_norm = torch.norm(x, dim=-1).mean()
+        # self.q_norm = torch.norm(q, dim=-1).mean()
+        # self.k_norm = torch.norm(k, dim=-1).mean()
+        # self.v_norm = torch.norm(v, dim=-1).mean()
 
         # (bz, L, nh, dh) -> (bz, nh, L, dh)
         q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
@@ -93,12 +104,14 @@ def forward(
             value=v,
             is_causal=_is_causal,
         )
+
+        # self.out_norm = torch.norm(wv, dim=-1).mean()
 
         # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
         wv = wv.transpose(1, 2)
         wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head)
 
-        return self.out(wv), None
+        return self.out(wv)
 
 
 class ResidualAttentionBlock(nn.Module):
@@ -129,9 +142,9 @@ def forward(
         xa: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask)[0]
+        x = x + self.attn(self.attn_ln(x), mask=mask)
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa)
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
@@ -188,6 +201,7 @@ def __init__(
             ]
         )
         self.ln = nn.LayerNorm(n_state)
+        self.output = nn.Linear(n_state, n_vocab, bias=False)
 
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
@@ -206,9 +220,11 @@ def forward(self, x: Tensor, xa: Tensor):
             x = block(x, xa, mask=self.mask)
         x = self.ln(x)
 
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = self.output(x)
+
+        # logits = (
+        #     x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        # ).float()
 
         return logits
 
@@ -244,7 +260,4 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
 
     @property
     def device(self):
-        return next(self.parameters()).device
-
-    def get_empty_cache(self):
-        return {}
+        return next(self.parameters()).device
\ No newline at end of file
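The TextDecoder change above unties the output head from token_embedding: logits now come from a dedicated nn.Linear(n_state, n_vocab, bias=False) instead of multiplying by the transposed embedding matrix. A minimal sketch of the two variants for comparison (not part of the patch; n_state, n_vocab, and seq_len are placeholder sizes):

    import torch
    import torch.nn as nn

    n_state, n_vocab, seq_len = 8, 16, 4
    x = torch.randn(1, seq_len, n_state)
    token_embedding = nn.Embedding(n_vocab, n_state)

    # Old (tied) head: project with the transposed embedding matrix.
    logits_tied = (x @ token_embedding.weight.to(x.dtype).T).float()

    # New (untied) head: a separate linear layer with its own weights.
    output = nn.Linear(n_state, n_vocab, bias=False)
    logits_untied = output(x)

    # Same output shape either way; only the parameterization differs.
    assert logits_tied.shape == logits_untied.shape == (1, seq_len, n_vocab)

Separately, MultiHeadAttention.forward now returns a single tensor rather than a (tensor, None) tuple, which is why the [0] indexing was dropped at the call sites in ResidualAttentionBlock.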
diff --git a/amt/train.py b/amt/train.py
index ebad059..94a4bff 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -2,6 +2,7 @@
 import sys
 import csv
 import random
+import traceback
 import functools
 import argparse
 import logging
@@ -24,7 +25,7 @@
 from amt.config import load_model_config
 from aria.utils import _load_weight
 
-GRADIENT_ACC_STEPS = 32
+GRADIENT_ACC_STEPS = 2
 
 # ----- USAGE ----- #
@@ -143,7 +144,7 @@ def _get_optim(
        model.parameters(),
        lr=lr,
        weight_decay=0.1,
-       betas=(0.9, 0.98),
+       betas=(0.9, 0.95),
        eps=1e-6,
    )
@@ -312,6 +313,22 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):
            f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}"
        )
        _accelerator.save_state(checkpoint_dir)
+
+    def log_activation_norms(_model: AmtEncoderDecoder, _accelerator: accelerate.Accelerator):
+        for idx, block in enumerate(_model.decoder.blocks):
+            x_norm = _accelerator.gather(block.attn.x_norm).mean()
+            q_norm = _accelerator.gather(block.attn.q_norm).mean()
+            k_norm = _accelerator.gather(block.attn.k_norm).mean()
+            v_norm = _accelerator.gather(block.attn.v_norm).mean()
+            out_norm = _accelerator.gather(block.attn.out_norm).mean()
+            logger.debug(f"{idx}.attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")
+
+            x_norm = _accelerator.gather(block.cross_attn.x_norm).mean()
+            q_norm = _accelerator.gather(block.cross_attn.q_norm).mean()
+            k_norm = _accelerator.gather(block.cross_attn.k_norm).mean()
+            v_norm = _accelerator.gather(block.cross_attn.v_norm).mean()
+            out_norm = _accelerator.gather(block.cross_attn.out_norm).mean()
+            logger.debug(f"{idx}.cross_attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")
 
    def get_max_norm(named_parameters):
        max_grad_norm = {"val": 0.0}
@@ -344,6 +361,7 @@ def train_loop(
        lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
 
        model.train()
+       grad_norm = 0.0
        for __step, batch in (
            pbar := tqdm(
                enumerate(dataloader),
@@ -378,8 +396,6 @@
                    grad_norm = accelerator.clip_grad_norm_(
                        model.parameters(), 1.0
                    ).item()
-               else:
-                   grad_norm = 0
 
                optimizer.step()
                optimizer.zero_grad()
@@ -398,7 +414,8 @@
                pbar.set_postfix_str(
                    f"lr={lr_for_print}, "
                    f"loss={round(loss_buffer[-1], 4)}, "
-                   f"trailing={round(trailing_loss, 4)}"
+                   f"trailing={round(trailing_loss, 4)}, "
+                   f"grad_norm={round(grad_norm, 4)}"
                )
 
            if scheduler:
@@ -470,6 +487,7 @@ def val_loop(dataloader, _epoch: int, aug: bool):
    PAD_ID = train_dataloader.dataset.tokenizer.pad_id
    logger = get_logger(__name__)  # Accelerate logger
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
+
    logger.info(
        f"Model has "
        f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
@@ -522,19 +540,27 @@
    )
 
    for epoch in range(start_epoch, epochs + start_epoch):
-       avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
-       avg_val_loss = val_loop(
-           dataloader=val_dataloader, _epoch=epoch, aug=False
-       )
-       avg_val_loss_aug = val_loop(
-           dataloader=val_dataloader, _epoch=epoch, aug=True
-       )
-       if accelerator.is_main_process:
-           epoch_writer.writerow(
-               [epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
-           )
-           epoch_csv.flush()
-       make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
+       try:
+           avg_train_loss = train_loop(
+               dataloader=train_dataloader, _epoch=epoch
+           )
+           avg_val_loss = val_loop(
+               dataloader=val_dataloader, _epoch=epoch, aug=False
+           )
+           avg_val_loss_aug = val_loop(
+               dataloader=val_dataloader, _epoch=epoch, aug=True
+           )
+           if accelerator.is_main_process:
+               epoch_writer.writerow(
+                   [epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
+               )
+               epoch_csv.flush()
+           make_checkpoint(
+               _accelerator=accelerator, _epoch=epoch + 1, _step=0
+           )
+       except Exception as e:
+           logger.debug(traceback.format_exc())
+           raise e
 
    logging.shutdown()
    if accelerator.is_main_process:
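The epoch loop above now wraps training, validation, and checkpointing in try/except so a mid-epoch failure is written to the debug log in full before the exception propagates. A minimal, self-contained illustration of the same pattern (not part of the patch; flaky_epoch is a hypothetical stand-in for train_loop/val_loop):

    import logging
    import traceback

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("train")

    def flaky_epoch():
        # Hypothetical failure, standing in for an error raised mid-epoch.
        raise RuntimeError("simulated mid-epoch failure")

    try:
        flaky_epoch()
    except Exception as e:
        # Record the full traceback in the log, then re-raise so the run
        # still fails loudly. Running this prints the traceback via the
        # logger and then exits with the RuntimeError.
        logger.debug(traceback.format_exc())
        raise e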