Fix prev token bug (#25)
* fix lr

* fix prev token bug

* adj logic
loubbrad authored Apr 10, 2024
1 parent 1c9d666 commit c333406
Showing 2 changed files with 33 additions and 32 deletions.
8 changes: 8 additions & 0 deletions amt/tokenizer.py
@@ -518,6 +518,14 @@ def tensor_pitch_aug(
                     seq[i, j] = tok_to_id.get(
                         (msg_type, pitch + shift), unk_tok
                     )
+                elif (
+                    type(tok) is tuple
+                    and tok[0] == "prev"
+                    and tok[1] != "pedal"
+                ):
+                    seq[i, j] = tok_to_id.get(
+                        ("prev", tok[1] + shift), unk_tok
+                    )
                 elif tok == pad_tok:
                     break

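Note: the new elif extends pitch augmentation to previously-sounding notes, so a ("prev", pitch) token is transposed by the same shift as the note tokens, while the pitchless ("prev", "pedal") marker is left alone. Below is a minimal sketch of that rule for a single token; the helper name, the toy vocabulary, and the ("on"/"off", pitch) message types are assumptions for illustration, since tensor_pitch_aug itself operates on whole tensors of token ids.

def shift_token(tok, shift, tok_to_id, unk_tok=0):
    """Return the id of `tok` after transposing it by `shift` semitones."""
    if isinstance(tok, tuple) and tok[0] in ("on", "off"):  # assumed note tokens
        msg_type, pitch = tok
        return tok_to_id.get((msg_type, pitch + shift), unk_tok)
    # The fix: "prev" tokens also carry a pitch and must be shifted,
    # except ("prev", "pedal"), which has no pitch.
    if isinstance(tok, tuple) and tok[0] == "prev" and tok[1] != "pedal":
        return tok_to_id.get(("prev", tok[1] + shift), unk_tok)
    return tok_to_id.get(tok, unk_tok)  # everything else is left unshifted

toy_vocab = {("on", 62): 1, ("prev", 62): 2, ("prev", "pedal"): 3}
assert shift_token(("prev", 60), 2, toy_vocab) == 2       # transposed to ("prev", 62)
assert shift_token(("prev", "pedal"), 2, toy_vocab) == 3  # pedal marker untouched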
57 changes: 25 additions & 32 deletions amt/train.py
@@ -13,8 +13,6 @@
 from torch import nn as nn
 from torch.utils.data import DataLoader

-from torch.utils.flop_counter import FlopCounterMode
-from triton.testing import do_bench
 from accelerate.logging import get_logger
 from safetensors.torch import load_file
 from logging.handlers import RotatingFileHandler
@@ -27,6 +25,8 @@
 from amt.config import load_model_config
 from aria.utils import _load_weight

+GRADIENT_ACC_STEPS = 2
+
 # ----- USAGE -----
 #
 # This script is meant to be run using the huggingface accelerate cli, see:
@@ -247,7 +247,7 @@ def _collate_fn(seqs, max_pitch_shift: int):
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
-        collate_fn=functools.partial(_collate_fn, max_pitch_shift=4),
+        collate_fn=functools.partial(_collate_fn, max_pitch_shift=5),
         shuffle=True,
     )
     val_dataloader = DataLoader(
@@ -327,10 +327,11 @@ def train_loop(
     ):
         with accelerator.accumulate(model):
             step = __step + _resume_step + 1

             wav, src, tgt, pitch_shift, idxs = batch
-
-            mel = audio_transform.forward(wav, shift=pitch_shift)
+            with torch.no_grad():
+                mel = audio_transform.forward(wav, shift=pitch_shift)
+
             logits = model(mel, src)  # (b_sz, s_len, v_sz)
             logits = logits.transpose(
                 1, 2
@@ -344,13 +345,25 @@
             )
             avg_train_loss = sum(loss_buffer) / len(loss_buffer)

+            # Backwards step
+            accelerator.backward(loss)
+            if accelerator.sync_gradients:
+                grad_norm = accelerator.clip_grad_norm_(
+                    model.parameters(), 1.0
+                ).item()
+            else:
+                grad_norm = 0
+            optimizer.step()
+            optimizer.zero_grad()
+
             # Logging
             logger.debug(
                 f"EPOCH {_epoch} STEP {step}: "
                 f"lr={lr_for_print}, "
                 f"loss={round(loss_buffer[-1], 4)}, "
                 f"trailing_loss={round(trailing_loss, 4)}, "
-                f"average_loss={round(avg_train_loss, 4)}"
+                f"average_loss={round(avg_train_loss, 4)}, "
+                f"grad_norm={round(grad_norm, 4)}"
             )
             if accelerator.is_main_process:
                 loss_writer.writerow([_epoch, step, loss_buffer[-1]])
@@ -361,26 +374,6 @@
                 f"trailing={round(trailing_loss, 4)}"
             )

-            # Backwards step
-            accelerator.backward(loss)
-
-            max_grad_norm_bn = get_max_norm(model.named_parameters())
-            accelerator.clip_grad_norm_(model.parameters(), 1.0)
-            max_grad_norm_an = get_max_norm(model.named_parameters())
-
-            if max_grad_norm_bn["val"] > 1.5:
-                logger.warning(
-                    f"Seen large grad_norm {max_grad_norm_bn['name']}: {max_grad_norm_bn['val']} -> {max_grad_norm_an['val']}"
-                )
-                logger.debug(accelerator.gather(loss))
-                logger.debug(accelerator.gather(idxs))
-            elif math.isnan(trailing_loss):
-                logger.error(accelerator.gather(loss))
-                logger.error(loss_buffer)
-                logger.error(accelerator.gather(idxs))
-
-            optimizer.step()
-            optimizer.zero_grad()
             if scheduler:
                 scheduler.step()
                 lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
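Note: the two train_loop hunks above move the backward pass, gradient clipping, and optimizer step inside the accelerator.accumulate block (clipping only when gradients are actually synced) and wrap the mel transform, which has no trainable parameters, in torch.no_grad(). A condensed sketch of the resulting inner loop, assuming model, optimizer, and the dataloader were prepared with accelerator.prepare; loss_fn and the other arguments are stand-ins for the objects train.py builds elsewhere.

import torch
from accelerate import Accelerator

def run_epoch(model, audio_transform, loss_fn, dataloader, optimizer,
              accelerator: Accelerator):
    # Condensed sketch of the reworked inner loop; variable names mirror the hunk.
    for wav, src, tgt, pitch_shift, idxs in dataloader:
        with accelerator.accumulate(model):  # accelerate tracks the accumulation window
            with torch.no_grad():  # the mel transform is not trained, so skip autograd
                mel = audio_transform.forward(wav, shift=pitch_shift)

            logits = model(mel, src).transpose(1, 2)  # (b_sz, v_sz, s_len) for CE
            loss = loss_fn(logits, tgt)

            accelerator.backward(loss)
            if accelerator.sync_gradients:  # clip only on real optimizer steps
                grad_norm = accelerator.clip_grad_norm_(
                    model.parameters(), 1.0
                ).item()
            else:
                grad_norm = 0  # no meaningful norm mid-accumulation; logged as 0
            optimizer.step()       # skipped internally until the window closes
            optimizer.zero_grad()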
@@ -553,7 +546,7 @@ def resume_train(

     tokenizer = AmtTokenizer()
     accelerator = accelerate.Accelerator(
-        project_dir=project_dir, gradient_accumulation_steps=4
+        project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
     )
     if accelerator.is_main_process:
         project_dir = setup_project_dir(project_dir)
@@ -607,13 +600,13 @@ def resume_train(
         optimizer, scheduler = get_pretrain_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader),
+            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
         )
     elif mode == "finetune":
         optimizer, scheduler = get_finetune_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader),
+            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
         )
     else:
         raise Exception
@@ -686,7 +679,7 @@ def train(

     tokenizer = AmtTokenizer()
     accelerator = accelerate.Accelerator(
-        project_dir=project_dir, gradient_accumulation_steps=4
+        project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
     )
     if accelerator.is_main_process:
         project_dir = setup_project_dir(project_dir)
@@ -735,13 +728,13 @@ def train(
         optimizer, scheduler = get_pretrain_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader),
+            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
         )
     elif mode == "finetune":
         optimizer, scheduler = get_finetune_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader),
+            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
         )
     else:
         raise Exception
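Note: with gradient accumulation the optimizer, and therefore the LR scheduler, only steps once every GRADIENT_ACC_STEPS batches, which is why steps_per_epoch is now divided by the new constant. A minimal sketch of that bookkeeping; LinearLR is only a stand-in for whatever schedule get_pretrain_optim and get_finetune_optim actually build.

import torch

GRADIENT_ACC_STEPS = 2  # matches the constant added at the top of amt/train.py

def build_schedule(optimizer, train_dataloader, num_epochs):
    # The schedule length is measured in optimizer steps, not dataloader batches:
    # len(train_dataloader) batches per epoch yield only
    # len(train_dataloader) // GRADIENT_ACC_STEPS optimizer steps.
    steps_per_epoch = len(train_dataloader) // GRADIENT_ACC_STEPS
    total_steps = steps_per_epoch * num_epochs
    return torch.optim.lr_scheduler.LinearLR(  # stand-in for the project's real schedule
        optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps
    )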
