Skip to content

Commit

Permalink
bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Feb 29, 2024
1 parent e10dca0 commit 220413b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
-from amt.audio import pad_or_trim
+from amt.audio import pad_or_trim, AudioTransform


def get_wav_mid_segments(
Expand Down
27 changes: 11 additions & 16 deletions amt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,13 @@ def _bench():
f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} "
"parameters"
)
-logger.info("Profiling FLOP")
+logger.info("Compiling model...")
_bench()

-with flop_counter:
-_bench()
-total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
-logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")
+# with flop_counter:
+# _bench()
+# total_flop = sum(flop_counter.get_flop_counts()["Global"].values())
+# logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF")

def make_checkpoint(_accelerator, _epoch: int, _step: int):
checkpoint_dir = os.path.join(
Expand Down Expand Up @@ -339,8 +339,7 @@ def train_loop(

wav, src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
with torch.no_grad():
-mel = audio_transform.mel(wav)
-# mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
+mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
logits = model(mel, src) # (b_sz, s_len, v_sz)
logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss
loss = loss_fn(logits, tgt)
Expand Down Expand Up @@ -408,7 +407,7 @@ def val_loop(dataloader, _epoch: int):
):
wav, src, tgt = batch
with torch.no_grad():
-mel, (src, tgt) = audio_transform.forward(wav, src, tgt)
+mel = audio_transform.mel(wav)
logits = model(mel, src)
logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss
loss = loss_fn(logits, tgt)
Expand Down Expand Up @@ -549,7 +548,7 @@ def resume_train(
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = AmtEncoderDecoder(model_config)
-audio_transform = AudioTransform()
+audio_transform = AudioTransform().to(accelerator.device)
logger.info(f"Loaded model with config: {load_model_config(model_name)}")

train_dataloader, val_dataloader = get_dataloaders(
Expand Down Expand Up @@ -580,14 +579,12 @@ def resume_train(

(
model,
-audio_transform,
train_dataloader,
val_dataloader,
optimizer,
scheduler,
) = accelerator.prepare(
model,
-audio_transform,
train_dataloader,
val_dataloader,
optimizer,
Expand Down Expand Up @@ -667,9 +664,9 @@ def train(
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = AmtEncoderDecoder(model_config)
-# logger.info("Compiling model...")
-# model = torch.compile(model)
-audio_transform = AudioTransform()
+audio_transform = AudioTransform().to(accelerator.device)
+model = torch.compile(model)
+audio_transform.compile()
logger.info(f"Loaded model with config: {load_model_config(model_name)}")
if mode == "finetune":
try:
Expand Down Expand Up @@ -708,14 +705,12 @@ def train(

(
model,
-audio_transform,
train_dataloader,
val_dataloader,
optimizer,
scheduler,
) = accelerator.prepare(
model,
-audio_transform,
train_dataloader,
val_dataloader,
optimizer,
Expand Down

0 comments on commit 220413b

Please sign in to comment.