
Commit

backup
loubbrad committed Apr 23, 2024
1 parent 23adab8 commit d71d5b5
Showing 4 changed files with 117 additions and 46 deletions.
39 changes: 35 additions & 4 deletions amt/audio.py
@@ -4,7 +4,9 @@
import random
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.functional as AF
import numpy as np

from amt.config import load_config
from amt.tokenizer import AmtTokenizer
@@ -22,6 +24,34 @@
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
    Pad or trim the audio array to `length` (N_SAMPLES by default), as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(
array, [pad for sizes in pad_widths[::-1] for pad in sizes]
)
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)

return array
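
A usage sketch, not part of the commit (shapes illustrative):

    # torch input: trimmed or zero-padded along the last axis to N_SAMPLES
    wav = torch.randn(2, N_SAMPLES + 1600)
    assert pad_or_trim(wav).shape[-1] == N_SAMPLES

    # numpy input takes the np.take / np.pad path instead
    arr = np.zeros(N_SAMPLES - 1600)
    assert pad_or_trim(arr).shape[-1] == N_SAMPLES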


# TODO: Refactor so that default params are stored in config.json
class AudioTransform(torch.nn.Module):
def __init__(
@@ -39,7 +69,7 @@ def __init__(
reduce_ratio: float = 0.01,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.95,
spec_aug_ratio: float = 0.9,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -105,12 +135,13 @@ def __init__(
n_stft=self.config["n_fft"] // 2 + 1,
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.TimeMasking(
time_mask_param=self.time_mask_param,
iid_masks=True,
),
torchaudio.transforms.FrequencyMasking(
freq_mask_param=self.freq_mask_param, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=self.time_mask_param, iid_masks=True
),
)

def get_params(self):
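Context, not part of the commit: the hunk above appears to reorder the augmentation pipeline so that FrequencyMasking runs before TimeMasking. A minimal sketch of applying such a pipeline, with illustrative shapes and mask parameters:

    import torch
    import torchaudio

    spec_aug = torch.nn.Sequential(
        torchaudio.transforms.FrequencyMasking(freq_mask_param=15, iid_masks=True),
        torchaudio.transforms.TimeMasking(time_mask_param=70, iid_masks=True),
    )
    mel = torch.randn(8, 1, 256, 1000)  # (batch, channel, n_mels, frames)
    out = spec_aug(mel)  # same shape; a random band per item is zeroed on each axis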
31 changes: 16 additions & 15 deletions amt/data.py
@@ -88,9 +88,9 @@ def get_wav_mid_segments(
max_pedal_len_ms=15000,
)

# Hardcoded to 2.5s
if _check_onset_threshold(mid_feature, 2500) is False:
print("No note messages after 2.5s - skipping")
# Hardcoded to 10.5s
if _check_onset_threshold(mid_feature, 10500) is False:
    print("No note messages after 10.5s - skipping")
continue

else:
@@ -106,15 +106,18 @@

def pianoteq_cmd_fn(mid_path: str, wav_path: str):
presets = [
"C. Bechstein",
"C. Bechstein Close Mic",
"C. Bechstein Under Lid",
"C. Bechstein 440",
"C. Bechstein Recording",
"C. Bechstein Werckmeister III",
"C. Bechstein Neidhardt III",
"C. Bechstein mesotonic",
"C. Bechstein well tempered",
"C. Bechstein DG Prelude",
"C. Bechstein DG Sweet",
"C. Bechstein DG Felt I",
"C. Bechstein DG Felt II",
"C. Bechstein DG D 282",
"C. Bechstein DG Recording 1",
"C. Bechstein DG Recording 2",
"C. Bechstein DG Recording 3",
"C. Bechstein DG Cinematic",
"C. Bechstein DG Snappy",
"C. Bechstein DG Venue",
"C. Bechstein DG Player",
"HB Steinway D Blues",
"HB Steinway D Pop",
"HB Steinway D New Age",
@@ -137,8 +140,6 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
"HB Steinway D Cabaret",
"HB Steinway D Bright",
"HB Steinway D Hyper Bright",
"HB Steinway D Prepared",
"HB Steinway D Honky Tonk",
]

preset = random.choice(presets)
@@ -148,7 +149,7 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
safe_mid_path = shlex.quote(mid_path)
safe_wav_path = shlex.quote(wav_path)

executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE"
executable_path = "/mnt/ssd-1/aria/pianoteq/x86-64bit/Pianoteq 8 STAGE"
command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}'

return command
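A usage sketch, not part of the commit (the subprocess invocation is an assumption; the function only builds the command string):

    import subprocess

    cmd = pianoteq_cmd_fn("example.mid", "example.wav")
    # shell=True because the command embeds its own quoting via shlex.quote
    subprocess.run(cmd, shell=True, check=True)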
33 changes: 23 additions & 10 deletions amt/model.py
@@ -54,6 +54,12 @@ def __init__(self, n_state: int, n_head: int):
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state, bias=False)
self.out = nn.Linear(n_state, n_state, bias=False)

# self.x_norm = None
# self.q_norm = None
# self.k_norm = None
# self.v_norm = None
# self.out_norm = None

def forward(
self,
@@ -78,6 +84,11 @@ def forward(
q = q.view(batch_size, target_seq_len, self.n_head, self.d_head)
k = k.view(batch_size, source_seq_len, self.n_head, self.d_head)
v = v.view(batch_size, source_seq_len, self.n_head, self.d_head)

# self.x_norm = torch.norm(x, dim=-1).mean()
# self.q_norm = torch.norm(q, dim=-1).mean()
# self.k_norm = torch.norm(k, dim=-1).mean()
# self.v_norm = torch.norm(v, dim=-1).mean()

# (bz, L, nh, dh) -> (bz, nh, L, dh)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
@@ -93,12 +104,14 @@ def forward(
value=v,
is_causal=_is_causal,
)

# self.out_norm = torch.norm(wv, dim=-1).mean()

# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2)
wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head)

return self.out(wv), None
return self.out(wv)


class ResidualAttentionBlock(nn.Module):
@@ -129,9 +142,9 @@ def forward(
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask)[0]
x = x + self.attn(self.attn_ln(x), mask=mask)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp(self.mlp_ln(x))
return x

@@ -188,6 +201,7 @@ def __init__(
]
)
self.ln = nn.LayerNorm(n_state)
self.output = nn.Linear(n_state, n_vocab, bias=False)

mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
@@ -206,9 +220,11 @@ def forward(self, x: Tensor, xa: Tensor):
x = block(x, xa, mask=self.mask)

x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
logits = self.output(x)

# logits = (
# x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
# ).float()

return logits
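
Context, not part of the commit: the decoder now produces logits with a dedicated output projection instead of reusing the token-embedding matrix (weight tying). A minimal sketch of the two variants, with illustrative dimensions:

    import torch
    import torch.nn as nn

    n_state, n_vocab = 512, 4096  # illustrative
    emb = nn.Embedding(n_vocab, n_state)
    x = torch.randn(2, 10, n_state)  # (batch, seq, n_state)

    tied_logits = x @ emb.weight.T                  # old path: reuse embeddings
    head = nn.Linear(n_state, n_vocab, bias=False)  # new path: independent weights
    untied_logits = head(x)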

@@ -244,7 +260,4 @@ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:

@property
def device(self):
return next(self.parameters()).device

def get_empty_cache(self):
return {}
return next(self.parameters()).device
60 changes: 43 additions & 17 deletions amt/train.py
@@ -2,6 +2,7 @@
import sys
import csv
import random
import traceback
import functools
import argparse
import logging
@@ -24,7 +25,7 @@
from amt.config import load_model_config
from aria.utils import _load_weight

GRADIENT_ACC_STEPS = 32
GRADIENT_ACC_STEPS = 2

# ----- USAGE -----
#
@@ -143,7 +144,7 @@ def _get_optim(
model.parameters(),
lr=lr,
weight_decay=0.1,
betas=(0.9, 0.98),
betas=(0.9, 0.95),
eps=1e-6,
)

@@ -312,6 +313,22 @@ def make_checkpoint(_accelerator, _epoch: int, _step: int):
f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}"
)
_accelerator.save_state(checkpoint_dir)

def log_activation_norms(_model: AmtEncoderDecoder, _accelerator: accelerate.Accelerator):
for idx, block in enumerate(_model.decoder.blocks):
x_norm = _accelerator.gather(block.attn.x_norm).mean()
q_norm = _accelerator.gather(block.attn.q_norm).mean()
k_norm = _accelerator.gather(block.attn.k_norm).mean()
v_norm = _accelerator.gather(block.attn.v_norm).mean()
out_norm = _accelerator.gather(block.attn.out_norm).mean()
logger.debug(f"{idx}.attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")

x_norm = _accelerator.gather(block.cross_attn.x_norm).mean()
q_norm = _accelerator.gather(block.cross_attn.q_norm).mean()
k_norm = _accelerator.gather(block.cross_attn.k_norm).mean()
v_norm = _accelerator.gather(block.cross_attn.v_norm).mean()
out_norm = _accelerator.gather(block.cross_attn.out_norm).mean()
logger.debug(f"{idx}.cross_attn - x: {x_norm}, q: {q_norm}, k: {k_norm}, v: {v_norm}, out: {out_norm}")

def get_max_norm(named_parameters):
max_grad_norm = {"val": 0.0}
@@ -344,6 +361,7 @@ def train_loop(
lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])

model.train()
grad_norm = 0.0
for __step, batch in (
pbar := tqdm(
enumerate(dataloader),
@@ -378,8 +396,6 @@
grad_norm = accelerator.clip_grad_norm_(
model.parameters(), 1.0
).item()
else:
grad_norm = 0
optimizer.step()
optimizer.zero_grad()
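
Context, not part of the commit: grad_norm is now initialized before the batch loop so the progress bar can report it even on steps where no clipping happens. A sketch of the accumulation pattern this implies (the guard around the clip/step is elided in the diff, so accelerator.sync_gradients here is an assumption):

    grad_norm = 0.0
    for step, batch in enumerate(dataloader):
        with accelerator.accumulate(model):
            logits = model(mel, tokens)  # names illustrative
            loss = loss_fn(logits.transpose(1, 2), targets)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0).item()
            optimizer.step()
            optimizer.zero_grad()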

@@ -398,7 +414,8 @@
pbar.set_postfix_str(
f"lr={lr_for_print}, "
f"loss={round(loss_buffer[-1], 4)}, "
f"trailing={round(trailing_loss, 4)}"
f"trailing={round(trailing_loss, 4)}, "
f"grad_norm={round(grad_norm, 4)}"
)

if scheduler:
@@ -470,6 +487,7 @@ def val_loop(dataloader, _epoch: int, aug: bool):
PAD_ID = train_dataloader.dataset.tokenizer.pad_id
logger = get_logger(__name__) # Accelerate logger
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

logger.info(
f"Model has "
f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
@@ -522,19 +540,27 @@ def val_loop(dataloader, _epoch: int, aug: bool):
)

for epoch in range(start_epoch, epochs + start_epoch):
avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
avg_val_loss = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=False
)
avg_val_loss_aug = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=True
)
if accelerator.is_main_process:
epoch_writer.writerow(
[epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
try:
avg_train_loss = train_loop(
dataloader=train_dataloader, _epoch=epoch
)
epoch_csv.flush()
make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)
avg_val_loss = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=False
)
avg_val_loss_aug = val_loop(
dataloader=val_dataloader, _epoch=epoch, aug=True
)
if accelerator.is_main_process:
epoch_writer.writerow(
[epoch, avg_train_loss, avg_val_loss, avg_val_loss_aug]
)
epoch_csv.flush()
make_checkpoint(
_accelerator=accelerator, _epoch=epoch + 1, _step=0
)
except Exception as e:
logger.debug(traceback.format_exc())
raise e

logging.shutdown()
if accelerator.is_main_process:
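One note on the new except clause, not part of the commit: after logging the traceback, a bare raise re-raises the active exception unchanged, while raise e also works but adds the re-raise site to the traceback. A sketch:

    try:
        avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
    except Exception:
        logger.debug(traceback.format_exc())
        raise  # preserves the original exception and traceback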
