Skip to content

Commit

Permalink
Migrate to ariautils (#61)
Browse files Browse the repository at this point in the history
* migrate to ariautils

* add grad_acc_steps

* update tests
  • Loading branch information
loubbrad authored Dec 2, 2024
1 parent e4b13b4 commit 11f6005
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 173 deletions.
13 changes: 9 additions & 4 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from multiprocessing import Pool, Queue, Process
from typing import Callable, Tuple

from aria.data.midi import MidiDict
from ariautils.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config

Expand Down Expand Up @@ -49,7 +49,7 @@ def get_mid_segments(

start_ms = 0
while start_ms < last_note_msg_ms:
mid_feature = tokenizer._tokenize_midi_dict(
mid_feature = tokenizer.tokenize(
midi_dict=midi_dict,
start_ms=start_ms,
end_ms=start_ms + chunk_len_ms,
Expand Down Expand Up @@ -319,7 +319,7 @@ def build_synth_worker_fn(
class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_paths: str | list):
super().__init__()
self.tokenizer = AmtTokenizer(return_tensors=True)
self.tokenizer = AmtTokenizer()
self.config = load_config()["data"]
self.mixup_fn = self.tokenizer.export_msg_mixup()

Expand Down Expand Up @@ -380,7 +380,12 @@ def _format(tok):
seq_len=self.config["max_seq_len"],
)

return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx
return (
wav,
torch.tensor(self.tokenizer.encode(src)),
torch.tensor(self.tokenizer.encode(tgt)),
idx,
)

def close(self):
for buff in self.file_buffs:
Expand Down
16 changes: 9 additions & 7 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def process_segments(
prefixes = [
tokenizer.trunc_seq(prefix, MAX_BLOCK_LEN) for prefix in raw_prefixes
]
seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
seq = torch.stack(
[torch.tensor(tokenizer.encode(prefix)) for prefix in prefixes]
).cuda()
eos_idxs = torch.tensor(
[MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int
).cuda()
Expand Down Expand Up @@ -294,7 +296,7 @@ def process_segments(
logger.warning("Context length overflow when transcribing segment(s)")

results = [
tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1])
tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1].tolist())
for _idx in range(seq.shape[0])
]

Expand Down Expand Up @@ -339,7 +341,7 @@ def gpu_manager(
)

audio_transform = AudioTransform().cuda()
tokenizer = AmtTokenizer(return_tensors=True)
tokenizer = AmtTokenizer()

try:
while True:
Expand Down Expand Up @@ -526,18 +528,18 @@ def _truncate_seq(
):
# Truncates and shifts a sequence by retokenizing the underlying midi_dict
if start_ms == end_ms:
_mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict(
_mid_dict, unclosed_notes = tokenizer.detokenize(
seq, start_ms, return_unclosed_notes=True
)
random.shuffle(unclosed_notes)
return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok]
else:
_mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS)
_mid_dict = tokenizer.detokenize(seq, LEN_MS)
if len(_mid_dict.note_msgs) == 0:
return [tokenizer.bos_tok]
else:
# The end_ms - 1 is a workaround to get rid of the off msgs
res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)
res = tokenizer.tokenize(_mid_dict, start_ms, end_ms - 1)

if res[-1] == tokenizer.eos_tok:
res.pop()
Expand Down Expand Up @@ -815,7 +817,7 @@ def _save_seq(_seq: List, _save_path: str):
break

try:
mid_dict = tokenizer._detokenize_midi_dict(
mid_dict = tokenizer.detokenize(
tokenized_seq=_seq,
len_ms=last_onset,
)
Expand Down
2 changes: 1 addition & 1 deletion amt/mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import os

from aria.data.midi import MidiDict, get_duration_ms
from ariautils.midi import MidiDict, get_duration_ms


def midi_to_intervals_and_pitches(midi_file_path):
Expand Down
2 changes: 1 addition & 1 deletion amt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def transcribe(
from amt.inference.transcribe import batch_transcribe
from amt.config import load_model_config
from amt.inference.model import ModelConfig, AmtEncoderDecoder
from aria.utils import _load_weight
from amt.utils import _load_weight

assert cuda_is_available(), "CUDA device not found"
assert os.path.isfile(checkpoint_path), "model checkpoint file not found"
Expand Down
35 changes: 31 additions & 4 deletions amt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from torch import Tensor
from collections import defaultdict

from aria.data.midi import MidiDict, get_duration_ms
from aria.tokenizer import Tokenizer
from ariautils.midi import MidiDict, get_duration_ms
from ariautils.tokenizer import Tokenizer

from amt.config import load_config


Expand All @@ -17,8 +18,8 @@
class AmtTokenizer(Tokenizer):
"""MidiDict tokenizer designed for AMT"""

def __init__(self, return_tensors: bool = False):
super().__init__(return_tensors)
def __init__(self):
    """Tokenizer specialised for automatic music transcription (AMT).

    Delegates vocabulary construction to the ``ariautils`` base
    ``Tokenizer`` and then loads AMT-specific settings.
    """
    super().__init__()
    # Identifier distinguishing this tokenizer implementation.
    self.name = "amt"
    # Tokenizer hyperparameters come from the shared project config file.
    self.config = load_config()["tokenizer"]

Expand Down Expand Up @@ -239,6 +240,20 @@ def _tokenize_midi_dict(
else:
return prefix + [self.bos_tok] + tokenized_seq

def tokenize(
    self,
    midi_dict: MidiDict,
    start_ms: int,
    end_ms: int,
    max_pedal_len_ms: int | None = None,
):
    """Tokenize the slice of *midi_dict* between *start_ms* and *end_ms*.

    Public entry point required by the ``ariautils`` ``Tokenizer``
    interface; the actual work is done by the internal
    ``_tokenize_midi_dict`` implementation.
    """
    # Forward everything by keyword so the delegate's parameter order
    # is irrelevant.
    call_kwargs = {
        "midi_dict": midi_dict,
        "start_ms": start_ms,
        "end_ms": end_ms,
        "max_pedal_len_ms": max_pedal_len_ms,
    }
    return self._tokenize_midi_dict(**call_kwargs)

def _detokenize_midi_dict(
self,
tokenized_seq: list,
Expand Down Expand Up @@ -408,6 +423,18 @@ def _detokenize_midi_dict(
else:
return midi_dict

def detokenize(
    self,
    tokenized_seq: list,
    len_ms: int,
    return_unclosed_notes: bool = False,
):
    """Reconstruct a MidiDict from *tokenized_seq* (inverse of ``tokenize``).

    Public entry point required by the ``ariautils`` ``Tokenizer``
    interface; delegates to the internal ``_detokenize_midi_dict``
    implementation. When *return_unclosed_notes* is true the delegate
    also reports notes left open at the end of the sequence.
    """
    # Forward everything by keyword so the delegate's parameter order
    # is irrelevant.
    call_kwargs = {
        "tokenized_seq": tokenized_seq,
        "len_ms": len_ms,
        "return_unclosed_notes": return_unclosed_notes,
    }
    return self._detokenize_midi_dict(**call_kwargs)

def trunc_seq(self, seq: list, seq_len: int):
"""Truncate or pad sequence to feature sequence length."""
seq += [self.pad_tok] * (seq_len - len(seq))
Expand Down
39 changes: 27 additions & 12 deletions amt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from amt.audio import AudioTransform
from amt.data import AmtDataset
from amt.config import load_model_config
from aria.utils import _load_weight

GRADIENT_ACC_STEPS = 2
from amt.utils import _load_weight

# ----- USAGE -----
#
Expand Down Expand Up @@ -283,7 +281,7 @@ def _debug(wav, mel, src, tgt, idx):
plot_spec(mel[_idx].cpu(), f"debug/{idx}/mel_{_idx}.png")
tokenizer = AmtTokenizer()
src_dec = tokenizer.decode(src[_idx])
mid_dict = tokenizer._detokenize_midi_dict(src_dec, 30000)
mid_dict = tokenizer.detokenize(src_dec, 30000)
mid = mid_dict.to_midi()
mid.save(f"debug/{idx}/mid_{_idx}.mid")

Expand Down Expand Up @@ -562,6 +560,7 @@ def resume_train(
mode: str,
num_workers: int,
batch_size: int,
grad_acc_steps: int,
epochs: int,
checkpoint_dir: str,
resume_epoch: int,
Expand All @@ -582,7 +581,7 @@ def resume_train(

tokenizer = AmtTokenizer()
accelerator = accelerate.Accelerator(
project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps
)
if accelerator.is_main_process:
project_dir = setup_project_dir(project_dir)
Expand All @@ -605,7 +604,7 @@ def resume_train(
f"epochs={epochs}, "
f"num_proc={accelerator.num_processes}, "
f"batch_size={batch_size}, "
f"grad_acc_steps={GRADIENT_ACC_STEPS}, "
f"grad_acc_steps={grad_acc_steps}, "
f"num_workers={num_workers}, "
f"checkpoint_dir={checkpoint_dir}, "
f"resume_step={resume_step}, "
Expand Down Expand Up @@ -638,13 +637,13 @@ def resume_train(
optimizer, scheduler = get_pretrain_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)
elif mode == "finetune":
optimizer, scheduler = get_finetune_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)
else:
raise Exception
Expand Down Expand Up @@ -697,6 +696,7 @@ def train(
mode: str,
num_workers: int,
batch_size: int,
grad_acc_steps: int,
epochs: int,
finetune_cp_path: str | None = None, # loads ft optimizer and cp
steps_per_checkpoint: int | None = None,
Expand All @@ -716,7 +716,7 @@ def train(

tokenizer = AmtTokenizer()
accelerator = accelerate.Accelerator(
project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps
)
if accelerator.is_main_process:
project_dir = setup_project_dir(project_dir)
Expand All @@ -731,7 +731,7 @@ def train(
f"epochs={epochs}, "
f"num_proc={accelerator.num_processes}, "
f"batch_size={batch_size}, "
f"grad_acc_steps={GRADIENT_ACC_STEPS}, "
f"grad_acc_steps={grad_acc_steps}, "
f"num_workers={num_workers}"
)

Expand Down Expand Up @@ -767,13 +767,13 @@ def train(
optimizer, scheduler = get_pretrain_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)
elif mode == "finetune":
optimizer, scheduler = get_finetune_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)
else:
raise Exception
Expand Down Expand Up @@ -844,6 +844,12 @@ def parse_resume_args():
argp.add_argument("-repoch", help="resume epoch", type=int, required=True)
argp.add_argument("-epochs", help="train epochs", type=int, required=True)
argp.add_argument("-bs", help="batch size", type=int, default=32)
argp.add_argument(
"-grad_acc_steps",
help="gradient accumulation steps",
type=int,
default=1,
)
argp.add_argument("-workers", help="number workers", type=int, default=1)
argp.add_argument("-pdir", help="project dir", type=str, required=False)
argp.add_argument(
Expand All @@ -863,6 +869,12 @@ def parse_train_args():
)
argp.add_argument("-epochs", help="train epochs", type=int, required=True)
argp.add_argument("-bs", help="batch size", type=int, default=32)
argp.add_argument(
"-grad_acc_steps",
help="gradient accumulation steps",
type=int,
default=1,
)
argp.add_argument("-workers", help="number workers", type=int, default=1)
argp.add_argument("-pdir", help="project dir", type=str, required=False)
argp.add_argument(
Expand Down Expand Up @@ -895,6 +907,7 @@ def parse_train_args():
mode="pretrain",
num_workers=train_args.workers,
batch_size=train_args.bs,
grad_acc_steps=train_args.grad_acc_steps,
epochs=train_args.epochs,
steps_per_checkpoint=train_args.spc,
project_dir=train_args.pdir,
Expand All @@ -908,6 +921,7 @@ def parse_train_args():
mode="finetune",
num_workers=train_args.workers,
batch_size=train_args.bs,
grad_acc_steps=train_args.grad_acc_steps,
epochs=train_args.epochs,
finetune_cp_path=train_args.cpath,
steps_per_checkpoint=train_args.spc,
Expand All @@ -922,6 +936,7 @@ def parse_train_args():
mode="pretrain" if resume_args.resume_mode == "pt" else "finetune",
num_workers=resume_args.workers,
batch_size=resume_args.bs,
grad_acc_steps=resume_args.grad_acc_steps,
epochs=resume_args.epochs,
checkpoint_dir=resume_args.cdir,
resume_step=resume_args.rstep,
Expand Down
16 changes: 16 additions & 0 deletions amt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Contains utils."""


def _load_weight(ckpt_path: str, device="cpu"):
if ckpt_path.endswith("safetensors"):
try:
from safetensors.torch import load_file
except ImportError as e:
raise ImportError(
f"Please install safetensors in order to read from the checkpoint: {ckpt_path}"
) from e
return load_file(ckpt_path, device=device)
else:
import torch

return torch.load(ckpt_path, map_location=device)
4 changes: 0 additions & 4 deletions requirements-eval.txt

This file was deleted.

6 changes: 2 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
aria @ git+https://github.com/EleutherAI/aria.git
ariautils @ git+https://github.com/EleutherAI/aria-utils.git
torch >= 2.3
torchaudio
accelerate
psutil
librosa
mido
tqdm
orjson
mir_eval
mir_eval
Loading

0 comments on commit 11f6005

Please sign in to comment.