diff --git a/amt/data.py b/amt/data.py
index 5eeaef4..688af93 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -13,7 +13,7 @@
 from multiprocessing import Pool, Queue, Process
 from typing import Callable, Tuple
 
-from aria.data.midi import MidiDict
+from ariautils.midi import MidiDict
 
 from amt.tokenizer import AmtTokenizer
 from amt.config import load_config
@@ -49,7 +49,7 @@ def get_mid_segments(
 
     start_ms = 0
     while start_ms < last_note_msg_ms:
-        mid_feature = tokenizer._tokenize_midi_dict(
+        mid_feature = tokenizer.tokenize(
             midi_dict=midi_dict,
             start_ms=start_ms,
             end_ms=start_ms + chunk_len_ms,
@@ -319,7 +319,7 @@ def build_synth_worker_fn(
 
 class AmtDataset(torch.utils.data.Dataset):
     def __init__(self, load_paths: str | list):
         super().__init__()
-        self.tokenizer = AmtTokenizer(return_tensors=True)
+        self.tokenizer = AmtTokenizer()
         self.config = load_config()["data"]
         self.mixup_fn = self.tokenizer.export_msg_mixup()
@@ -380,7 +380,12 @@ def _format(tok):
             seq_len=self.config["max_seq_len"],
         )
 
-        return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx
+        return (
+            wav,
+            torch.tensor(self.tokenizer.encode(src)),
+            torch.tensor(self.tokenizer.encode(tgt)),
+            idx,
+        )
 
     def close(self):
         for buff in self.file_buffs:
diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py
index d8ea3e7..ecf9a3d 100644
--- a/amt/inference/transcribe.py
+++ b/amt/inference/transcribe.py
@@ -229,7 +229,9 @@ def process_segments(
     prefixes = [
         tokenizer.trunc_seq(prefix, MAX_BLOCK_LEN) for prefix in raw_prefixes
     ]
-    seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
+    seq = torch.stack(
+        [torch.tensor(tokenizer.encode(prefix)) for prefix in prefixes]
+    ).cuda()
     eos_idxs = torch.tensor(
         [MAX_BLOCK_LEN for _ in prefixes], dtype=torch.int
     ).cuda()
@@ -294,7 +296,7 @@ def process_segments(
         logger.warning("Context length overflow when transcribing segment(s)")
 
     results = [
-        tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1])
+        tokenizer.decode(seq[_idx, : eos_idxs[_idx] + 1].tolist())
        for _idx in range(seq.shape[0])
    ]
@@ -339,7 +341,7 @@ def gpu_manager(
     )
 
     audio_transform = AudioTransform().cuda()
-    tokenizer = AmtTokenizer(return_tensors=True)
+    tokenizer = AmtTokenizer()
 
     try:
         while True:
@@ -526,18 +528,18 @@ def _truncate_seq(
 ):
     # Truncates and shifts a sequence by retokenizing the underlying midi_dict
     if start_ms == end_ms:
-        _mid_dict, unclosed_notes = tokenizer._detokenize_midi_dict(
+        _mid_dict, unclosed_notes = tokenizer.detokenize(
             seq, start_ms, return_unclosed_notes=True
         )
         random.shuffle(unclosed_notes)
         return [("prev", p) for p in unclosed_notes] + [tokenizer.bos_tok]
     else:
-        _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS)
+        _mid_dict = tokenizer.detokenize(seq, LEN_MS)
         if len(_mid_dict.note_msgs) == 0:
             return [tokenizer.bos_tok]
         else:
             # The end_ms - 1 is a workaround to get rid of the off msgs
-            res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)
+            res = tokenizer.tokenize(_mid_dict, start_ms, end_ms - 1)
 
             if res[-1] == tokenizer.eos_tok:
                 res.pop()
@@ -815,7 +817,7 @@ def _save_seq(_seq: List, _save_path: str):
             break
 
     try:
-        mid_dict = tokenizer._detokenize_midi_dict(
+        mid_dict = tokenizer.detokenize(
             tokenized_seq=_seq,
             len_ms=last_onset,
         )
diff --git a/amt/mir.py b/amt/mir.py
index 2a33881..8693eb8 100644
--- a/amt/mir.py
+++ b/amt/mir.py
@@ -5,7 +5,7 @@
 import json
 import os
 
-from aria.data.midi import MidiDict, get_duration_ms
+from ariautils.midi import MidiDict, get_duration_ms
 
 
 def midi_to_intervals_and_pitches(midi_file_path):
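Note on the pattern above: with the `return_tensors` flag removed from `AmtTokenizer`, `encode` returns a plain list of ids and `decode` expects one back, so call sites now wrap with `torch.tensor(...)` and unwrap with `.tolist()`. A minimal sketch of the new round trip (the input path is hypothetical, not part of this diff):

```python
import torch

from ariautils.midi import MidiDict
from amt.tokenizer import AmtTokenizer

tokenizer = AmtTokenizer()
midi_dict = MidiDict.from_midi("sample.mid")  # hypothetical input file

seq = tokenizer.tokenize(midi_dict, start_ms=0, end_ms=30000)  # list of tokens
ids = torch.tensor(tokenizer.encode(seq))  # callers tensor-ify explicitly now

# decode() takes a plain list again, hence the .tolist() calls in this diff;
# the round trip is exact as long as no token maps to unk_tok
assert tokenizer.decode(ids.tolist()) == seq
```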
diff --git a/amt/run.py b/amt/run.py
index ed154d8..7c3d266 100644
--- a/amt/run.py
+++ b/amt/run.py
@@ -399,7 +399,7 @@ def transcribe(
     from amt.inference.transcribe import batch_transcribe
     from amt.config import load_model_config
     from amt.inference.model import ModelConfig, AmtEncoderDecoder
-    from aria.utils import _load_weight
+    from amt.utils import _load_weight
 
     assert cuda_is_available(), "CUDA device not found"
     assert os.path.isfile(checkpoint_path), "model checkpoint file not found"
diff --git a/amt/tokenizer.py b/amt/tokenizer.py
index 3146e32..da21fc8 100644
--- a/amt/tokenizer.py
+++ b/amt/tokenizer.py
@@ -6,8 +6,9 @@
 from torch import Tensor
 from collections import defaultdict
 
-from aria.data.midi import MidiDict, get_duration_ms
-from aria.tokenizer import Tokenizer
+from ariautils.midi import MidiDict, get_duration_ms
+from ariautils.tokenizer import Tokenizer
+
 from amt.config import load_config
@@ -17,8 +18,8 @@ class AmtTokenizer(Tokenizer):
     """MidiDict tokenizer designed for AMT"""
 
-    def __init__(self, return_tensors: bool = False):
-        super().__init__(return_tensors)
+    def __init__(self):
+        super().__init__()
         self.config = load_config()["tokenizer"]
         self.name = "amt"
@@ -239,6 +240,20 @@
         else:
             return prefix + [self.bos_tok] + tokenized_seq
 
+    def tokenize(
+        self,
+        midi_dict: MidiDict,
+        start_ms: int,
+        end_ms: int,
+        max_pedal_len_ms: int | None = None,
+    ):
+        return self._tokenize_midi_dict(
+            midi_dict=midi_dict,
+            start_ms=start_ms,
+            end_ms=end_ms,
+            max_pedal_len_ms=max_pedal_len_ms,
+        )
+
     def _detokenize_midi_dict(
         self,
         tokenized_seq: list,
@@ -408,6 +423,18 @@
         else:
             return midi_dict
 
+    def detokenize(
+        self,
+        tokenized_seq: list,
+        len_ms: int,
+        return_unclosed_notes: bool = False,
+    ):
+        return self._detokenize_midi_dict(
+            tokenized_seq=tokenized_seq,
+            len_ms=len_ms,
+            return_unclosed_notes=return_unclosed_notes,
+        )
+
     def trunc_seq(self, seq: list, seq_len: int):
         """Truncate or pad sequence to feature sequence length."""
         seq += [self.pad_tok] * (seq_len - len(seq))
diff --git a/amt/train.py b/amt/train.py
index 5bd8bd2..119b2cc 100644
--- a/amt/train.py
+++ b/amt/train.py
@@ -23,9 +23,7 @@
 from amt.audio import AudioTransform
 from amt.data import AmtDataset
 from amt.config import load_model_config
-from aria.utils import _load_weight
-
-GRADIENT_ACC_STEPS = 2
+from amt.utils import _load_weight
 
 # ----- USAGE -----
 #
@@ -283,7 +281,7 @@ def _debug(wav, mel, src, tgt, idx):
             plot_spec(mel[_idx].cpu(), f"debug/{idx}/mel_{_idx}.png")
             tokenizer = AmtTokenizer()
             src_dec = tokenizer.decode(src[_idx])
-            mid_dict = tokenizer._detokenize_midi_dict(src_dec, 30000)
+            mid_dict = tokenizer.detokenize(src_dec, 30000)
             mid = mid_dict.to_midi()
             mid.save(f"debug/{idx}/mid_{_idx}.mid")
@@ -562,6 +560,7 @@ def resume_train(
     mode: str,
     num_workers: int,
     batch_size: int,
+    grad_acc_steps: int,
     epochs: int,
     checkpoint_dir: str,
     resume_epoch: int,
@@ -582,7 +581,7 @@ def resume_train(
     tokenizer = AmtTokenizer()
 
     accelerator = accelerate.Accelerator(
-        project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
+        project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps
     )
     if accelerator.is_main_process:
         project_dir = setup_project_dir(project_dir)
@@ -605,7 +604,7 @@ def resume_train(
         f"epochs={epochs}, "
         f"num_proc={accelerator.num_processes}, "
         f"batch_size={batch_size}, "
-        f"grad_acc_steps={GRADIENT_ACC_STEPS}, "
+        f"grad_acc_steps={grad_acc_steps}, "
         f"num_workers={num_workers}, "
         f"checkpoint_dir={checkpoint_dir}, "
         f"resume_step={resume_step}, "
@@ -638,13 +637,13 @@ def resume_train(
         optimizer, scheduler = get_pretrain_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
+            steps_per_epoch=len(train_dataloader) // grad_acc_steps,
         )
     elif mode == "finetune":
         optimizer, scheduler = get_finetune_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
+            steps_per_epoch=len(train_dataloader) // grad_acc_steps,
         )
     else:
         raise Exception
@@ -697,6 +696,7 @@ def train(
     mode: str,
     num_workers: int,
     batch_size: int,
+    grad_acc_steps: int,
     epochs: int,
     finetune_cp_path: str | None = None,  # loads ft optimizer and cp
     steps_per_checkpoint: int | None = None,
@@ -716,7 +716,7 @@ def train(
     tokenizer = AmtTokenizer()
 
     accelerator = accelerate.Accelerator(
-        project_dir=project_dir, gradient_accumulation_steps=GRADIENT_ACC_STEPS
+        project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps
     )
     if accelerator.is_main_process:
         project_dir = setup_project_dir(project_dir)
@@ -731,7 +731,7 @@ def train(
         f"epochs={epochs}, "
         f"num_proc={accelerator.num_processes}, "
         f"batch_size={batch_size}, "
-        f"grad_acc_steps={GRADIENT_ACC_STEPS}, "
+        f"grad_acc_steps={grad_acc_steps}, "
         f"num_workers={num_workers}"
     )
@@ -767,13 +767,13 @@ def train(
         optimizer, scheduler = get_pretrain_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
+            steps_per_epoch=len(train_dataloader) // grad_acc_steps,
         )
     elif mode == "finetune":
         optimizer, scheduler = get_finetune_optim(
             model,
             num_epochs=epochs,
-            steps_per_epoch=len(train_dataloader) // GRADIENT_ACC_STEPS,
+            steps_per_epoch=len(train_dataloader) // grad_acc_steps,
         )
     else:
         raise Exception
@@ -844,6 +844,12 @@ def parse_resume_args():
     argp.add_argument("-repoch", help="resume epoch", type=int, required=True)
     argp.add_argument("-epochs", help="train epochs", type=int, required=True)
    argp.add_argument("-bs", help="batch size", type=int, default=32)
+    argp.add_argument(
+        "-grad_acc_steps",
+        help="gradient accumulation steps",
+        type=int,
+        default=1,
+    )
     argp.add_argument("-workers", help="number workers", type=int, default=1)
     argp.add_argument("-pdir", help="project dir", type=str, required=False)
     argp.add_argument(
@@ -863,6 +869,12 @@ def parse_train_args():
     )
     argp.add_argument("-epochs", help="train epochs", type=int, required=True)
     argp.add_argument("-bs", help="batch size", type=int, default=32)
+    argp.add_argument(
+        "-grad_acc_steps",
+        help="gradient accumulation steps",
+        type=int,
+        default=1,
+    )
     argp.add_argument("-workers", help="number workers", type=int, default=1)
     argp.add_argument("-pdir", help="project dir", type=str, required=False)
     argp.add_argument(
@@ -895,6 +907,7 @@ def parse_train_args():
             mode="pretrain",
             num_workers=train_args.workers,
             batch_size=train_args.bs,
+            grad_acc_steps=train_args.grad_acc_steps,
             epochs=train_args.epochs,
             steps_per_checkpoint=train_args.spc,
             project_dir=train_args.pdir,
@@ -908,6 +921,7 @@ def parse_train_args():
             mode="finetune",
             num_workers=train_args.workers,
             batch_size=train_args.bs,
+            grad_acc_steps=train_args.grad_acc_steps,
             epochs=train_args.epochs,
             finetune_cp_path=train_args.cpath,
             steps_per_checkpoint=train_args.spc,
@@ -922,6 +936,7 @@ def parse_train_args():
             mode="pretrain" if resume_args.resume_mode == "pt" else "finetune",
             num_workers=resume_args.workers,
             batch_size=resume_args.bs,
+            grad_acc_steps=resume_args.grad_acc_steps,
             epochs=resume_args.epochs,
             checkpoint_dir=resume_args.cdir,
             resume_step=resume_args.rstep,
diff --git a/amt/utils.py b/amt/utils.py
new file mode 100644
index 0000000..2bc369d
--- /dev/null
+++ b/amt/utils.py
@@ -0,0 +1,16 @@
+"""Contains utils."""
+
+
+def _load_weight(ckpt_path: str, device="cpu"):
+    if ckpt_path.endswith("safetensors"):
+        try:
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(
+                f"Please install safetensors in order to read from the checkpoint: {ckpt_path}"
+            ) from e
+        return load_file(ckpt_path, device=device)
+    else:
+        import torch
+
+        return torch.load(ckpt_path, map_location=device)
diff --git a/requirements-eval.txt b/requirements-eval.txt
deleted file mode 100644
index d3ed177..0000000
--- a/requirements-eval.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-djitw @ git+https://github.com/alex2awesome/djitw.git
-librosa
-pretty_midi
-pyfluidsynth
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9dd4f77..a8785f6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,8 @@
-aria @ git+https://github.com/EleutherAI/aria.git
+ariautils @ git+https://github.com/EleutherAI/aria-utils.git
 torch >= 2.3
 torchaudio
 accelerate
-psutil
 librosa
-mido
 tqdm
 orjson
-mir_eval
+mir_eval
\ No newline at end of file
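The new `amt/utils.py` above inlines the one helper previously imported from `aria.utils`, so the training and inference entry points no longer pull in the full aria package. A usage sketch (checkpoint path hypothetical):

```python
from amt.utils import _load_weight

# Dispatch is on the filename suffix: "*.safetensors" goes through
# safetensors.torch.load_file, everything else through torch.load
state_dict = _load_weight("checkpoints/model.safetensors", device="cpu")

# Mirrors the import sites in amt/run.py and amt/train.py, where the result
# feeds load_state_dict on a constructed AmtEncoderDecoder:
# model.load_state_dict(state_dict)
```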
diff --git a/tests/test_data.py b/tests/test_data.py
index 9b14e1e..967609e 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,8 +1,6 @@
 import unittest
 import logging
 import os
-import cProfile
-import pstats
 import torch
 import torchaudio
 import matplotlib.pyplot as plt
@@ -10,10 +8,7 @@
 from amt.data import get_paired_wav_mid_segments, AmtDataset
 from amt.tokenizer import AmtTokenizer
 from amt.audio import AudioTransform
-from amt.train import get_dataloaders
-from aria.data.midi import MidiDict
-
-from torch.utils.data import DataLoader
+from ariautils.midi import MidiDict
 
 
 logging.basicConfig(level=logging.INFO)
@@ -29,7 +24,7 @@ def plot_spec(
     onsets: list = [],
     offsets: list = [],
 ):
-    # mel tensor dimensions [height, width]
+    # mel: [height, width]
     height, width = mel.shape
 
     fig_width, fig_height = width // 100, height // 100
@@ -38,21 +33,21 @@ def plot_spec(
         mel, aspect="auto", origin="lower", cmap="viridis", interpolation="none"
     )
 
-    line_width_in_points = 1 / 100 * 72  # Convert pixel width to points
+    line_width_in_points = 1 / 100 * 72
 
     for x in onsets:
         plt.axvline(
             x=x,
             color="red",
             alpha=0.5,
-            linewidth=line_width_in_points,  # setting the correct line width
+            linewidth=line_width_in_points,
         )
     for x in offsets:
         plt.axvline(
             x=x,
             color="purple",
             alpha=0.5,
-            linewidth=line_width_in_points,  # setting the correct line width
+            linewidth=line_width_in_points,
        )
 
     plt.axis("off")
@@ -77,28 +72,10 @@ def test_wav_mid_segments(self):
                 f"tests/test_results/{idx}.wav", wav.unsqueeze(0), 16000
             )
             print(idx)
-            tokenizer._detokenize_midi_dict(seq, 30000).to_midi().save(
+            tokenizer.detokenize(seq, 30000).to_midi().save(
                 f"tests/test_results/{idx}.mid"
             )
 
-    def test_new_wav(self):
-        from amt.data import (
-            get_wav_segments,
-            get_mid_segments,
-            get_paired_wav_mid_segments,
-        )
-
-        for idx, segs in enumerate(
-            get_paired_wav_mid_segments(
-                audio_path="/home/loubb/work/aria-amt/data/audio.mp3",
-                mid_path="/home/loubb/work/aria-amt/data/audio.mid",
-                stride_factor=3,
-                pad_last=True,
-            )
-        ):
-            a, b = segs
-            print(a.shape, len(b))
-
 
 class TestAmtDataset(unittest.TestCase):
     def test_build(self):
@@ -118,13 +95,11 @@ def test_build(self):
         tokenizer = AmtTokenizer()
         for idx, (wav, src, tgt, idx) in enumerate(dataset):
             print(wav.shape, src.shape, tgt.shape)
-            src_decoded = tokenizer.decode(src)
-            tgt_decoded = tokenizer.decode(tgt)
+            src_decoded = tokenizer.decode(src.tolist())
+            tgt_decoded = tokenizer.decode(tgt.tolist())
             self.assertListEqual(src_decoded[1:], tgt_decoded[:-1])
 
-            mid = tokenizer._detokenize_midi_dict(
-                src_decoded, len_ms=30000
-            ).to_midi()
+            mid = tokenizer.detokenize(src_decoded, len_ms=30000).to_midi()
             mid.save(f"tests/test_results/trunc_{idx}.mid")
@@ -166,11 +141,13 @@ def test_maestro(self):
         dataset = AmtDataset(load_paths=MAESTRO_PATH)
         print(f"Dataset length: {len(dataset)}")
         for idx, (wav, src, tgt, __idx) in enumerate(dataset):
-            src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)
+            src_dec, tgt_dec = tokenizer.decode(src.tolist()), tokenizer.decode(
+                tgt.tolist()
+            )
 
             if idx % 7 == 0 and idx < 100:
                 print(idx)
-                src_mid_dict = tokenizer._detokenize_midi_dict(
+                src_mid_dict = tokenizer.detokenize(
                     src_dec,
                     len_ms=30000,
                 )
@@ -194,35 +171,8 @@ def test_maestro(self):
             for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
                 self.assertEqual(src_tok, tgt_tok)
 
-    def test_tensor_pitch_aug(self):
-        tokenizer = AmtTokenizer()
-        audio_transform = AudioTransform()
-        dataset = AmtDataset(load_paths=MAESTRO_PATH)
-        tensor_pitch_aug = AmtTokenizer().export_tensor_pitch_aug()
-
-        dataloader = DataLoader(
-            dataset,
-            batch_size=4,
-            num_workers=1,
-            shuffle=False,
-        )
-
-        for batch in dataloader:
-            wav, src, tgt, idxs = batch
-
-            src_p = tensor_pitch_aug(seq=src.clone(), shift=1)[0]
-            src_p_dec = tokenizer.decode(src_p)
-
-            src_np = src.clone()[0]
-            src_np_dec = tokenizer.decode(src_np)
-
-            for x, y in zip(src_p_dec, src_np_dec):
-                if x == "<P>":
-                    break
-                else:
-                    print(x, y)
-
 
+# TODO: Port these over to new spectrogram format (audio transform)
 class TestAug(unittest.TestCase):
     def test_spec(self):
         SAMPLE_RATE, CHUNK_LEN = 16000, 30
@@ -246,20 +196,20 @@ def test_spec(self):
         torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE)
 
     def test_pitch_aug(self):
-        tokenizer = AmtTokenizer(return_tensors=True)
+        tokenizer = AmtTokenizer()
         tensor_pitch_aug_fn = tokenizer.export_tensor_pitch_aug()
         mid_dict = MidiDict.from_midi("tests/test_data/maestro2.mid")
-        seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000)
-        src = tokenizer.encode(tokenizer.trunc_seq(seq, 4096))
-        tgt = tokenizer.encode(tokenizer.trunc_seq(seq[1:], 4096))
+        seq = tokenizer.tokenize(mid_dict, 0, 30000)
+        src = torch.tensor(tokenizer.encode(tokenizer.trunc_seq(seq, 4096)))
+        tgt = torch.tensor(tokenizer.encode(tokenizer.trunc_seq(seq[1:], 4096)))
         src = torch.stack((src, src, src))
         tgt = torch.stack((tgt, tgt, tgt))
 
         src_aug = tensor_pitch_aug_fn(src.clone(), shift=1)
         tgt_aug = tensor_pitch_aug_fn(tgt.clone(), shift=1)
-        src_aug_dec = tokenizer.decode(src_aug[1])
-        tgt_aug_dec = tokenizer.decode(tgt_aug[2])
+        src_aug_dec = tokenizer.decode(src_aug[1].tolist())
+        tgt_aug_dec = tokenizer.decode(tgt_aug[2].tolist())
         print(seq[:20])
         print(src_aug_dec[:20])
         print(tgt_aug_dec[:20])
@@ -298,7 +248,7 @@ def test_mels(self):
         audio_transform = AudioTransform()
         SAMPLE_RATE, N_FFT, CHUNK_LEN = (
             audio_transform.sample_rate,
-            audio_transform.n_fft,
+            1,
             30,
         )
         wav, sr = torchaudio.load("tests/test_data/maestro.wav")
@@ -306,17 +256,6 @@ def test_mels(self):
         wav = wav.mean(
             0, keepdim=True
         )[:, : SAMPLE_RATE * CHUNK_LEN]
-        # tokenizer = AmtTokenizer()
-        # mid_dict = MidiDict.from_midi("tests/test_data/maestro-test.mid")
-        # seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000, 10000)
-        # mid_dict = tokenizer._detokenize_midi_dict(seq, 30000)
-        # onsets = [msg["data"]["start"] // 10 for msg in mid_dict.note_msgs]
-        # offsets = [
-        #     msg["data"]["end"] // 10
-        #     for msg in mid_dict.note_msgs
-        #     if msg["data"]["end"] < 30000
-        # ]
-
         wavs = torch.stack((wav[0], wav[0], wav[0]))
         mels = audio_transform(wavs)
         for idx in range(mels.shape[0]):
@@ -387,27 +326,5 @@ def test_noise(self):
         torchaudio.save("tests/test_results/noise.wav", res, SAMPLE_RATE)
 
 
-class TestDataLoader(unittest.TestCase):
-    def load_data(self, dataloader, num_batches=100):
-        for idx, data in enumerate(dataloader):
-            if idx >= num_batches:
-                break
-
-    def test_profile_dl(self):
-        train_dataloader, val_dataloader = get_dataloaders(
-            train_data_paths="/weka/proj-aria/aria-amt/data/train.jsonl",
-            val_data_path="/weka/proj-aria/aria-amt/data/train.jsonl",
-            batch_size=16,
-            num_workers=0,
-        )
-
-        profiler = cProfile.Profile()
-        profiler.enable()
-        self.load_data(train_dataloader, num_batches=10)
-        profiler.disable()
-        stats = pstats.Stats(profiler).sort_stats("cumulative")
-        stats.print_stats()
-
-
 if __name__ == "__main__":
     unittest.main()
{end})...") - tokenized_seq = tokenizer._tokenize_midi_dict(midi_dict, start, end) + tokenized_seq = tokenizer.tokenize(midi_dict, start, end) tokenized_seq = tokenizer.decode(tokenizer.encode(tokenized_seq)) self.assertTrue(tokenizer.unk_tok not in tokenized_seq) - _midi_dict = tokenizer._detokenize_midi_dict(tokenized_seq, length) + _midi_dict = tokenizer.detokenize(tokenized_seq, length) _mid = _midi_dict.to_midi() _mid.save(f"tests/test_results/{start}_{end}_{mid_name}") @@ -43,7 +43,7 @@ def test_eos_tok(self): cnt = 0 while True: - seq = tokenizer._tokenize_midi_dict( + seq = tokenizer.tokenize( midi_dict, start_ms=cnt * 10000, end_ms=(cnt * 10000) + 30000 ) if len(seq) <= 2: @@ -53,40 +53,40 @@ def test_eos_tok(self): cnt += 1 def test_pitch_aug(self): - tokenizer = AmtTokenizer(return_tensors=True) + tokenizer = AmtTokenizer() tensor_pitch_aug = tokenizer.export_tensor_pitch_aug() midi_dict_1 = MidiDict.from_midi("tests/test_data/maestro1.mid") midi_dict_2 = MidiDict.from_midi("tests/test_data/maestro2.mid") midi_dict_3 = MidiDict.from_midi("tests/test_data/maestro3.mid") - seq_1 = tokenizer._tokenize_midi_dict(midi_dict_1, 0, 30000) + seq_1 = tokenizer.tokenize(midi_dict_1, 0, 30000) seq_1 = tokenizer.trunc_seq(seq_1, 2048) seq_2 = tokenizer.trunc_seq( - tokenizer._tokenize_midi_dict(midi_dict_2, 0, 30000), 2048 + tokenizer.tokenize(midi_dict_2, 0, 30000), 2048 ) seq_2 = tokenizer.trunc_seq(seq_2, 2048) seq_3 = tokenizer.trunc_seq( - tokenizer._tokenize_midi_dict(midi_dict_3, 0, 30000), 2048 + tokenizer.tokenize(midi_dict_3, 0, 30000), 2048 ) seq_3 = tokenizer.trunc_seq(seq_3, 2048) seqs = torch.stack( ( - tokenizer.encode(seq_1), - tokenizer.encode(seq_2), - tokenizer.encode(seq_3), + torch.tensor(tokenizer.encode(seq_1)), + torch.tensor(tokenizer.encode(seq_2)), + torch.tensor(tokenizer.encode(seq_3)), ) ) aug_seqs = tensor_pitch_aug(seqs, shift=2) - midi_dict_1_aug = tokenizer._detokenize_midi_dict( - tokenizer.decode(aug_seqs[0]), 30000 + midi_dict_1_aug = tokenizer.detokenize( + tokenizer.decode(aug_seqs[0].tolist()), 30000 ) - midi_dict_2_aug = tokenizer._detokenize_midi_dict( - tokenizer.decode(aug_seqs[1]), 30000 + midi_dict_2_aug = tokenizer.detokenize( + tokenizer.decode(aug_seqs[1].tolist()), 30000 ) - midi_dict_3_aug = tokenizer._detokenize_midi_dict( - tokenizer.decode(aug_seqs[2]), 30000 + midi_dict_3_aug = tokenizer.detokenize( + tokenizer.decode(aug_seqs[2].tolist()), 30000 ) midi_dict_1_aug.to_midi().save("tests/test_results/pitch1.mid") midi_dict_2_aug.to_midi().save("tests/test_results/pitch2.mid") @@ -94,7 +94,7 @@ def test_pitch_aug(self): def test_aug(self): def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int): - _tokenized_seq = tokenizer._tokenize_midi_dict( + _tokenized_seq = tokenizer.tokenize( midi_dict=_midi_dict, start_ms=_start_ms, end_ms=_end_ms, @@ -117,15 +117,9 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int): ) self.assertEqual( + len(tokenizer.detokenize(tokenized_seq, DELTA_MS).note_msgs), len( - tokenizer._detokenize_midi_dict( - tokenized_seq, DELTA_MS - ).note_msgs - ), - len( - tokenizer._detokenize_midi_dict( - aug_tokenized_seq, DELTA_MS - ).note_msgs + tokenizer.detokenize(aug_tokenized_seq, DELTA_MS).note_msgs ), ) @@ -134,15 +128,11 @@ def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int): f"msg mixup: {tokenized_seq} ->\n{aug_tokenized_seq}" ) - _midi_dict = tokenizer._detokenize_midi_dict( - tokenized_seq, DELTA_MS - ) + _midi_dict = tokenizer.detokenize(tokenized_seq, DELTA_MS) _mid = _midi_dict.to_midi() 
_mid.save(f"tests/test_results/maestro2_orig.mid") - _midi_dict = tokenizer._detokenize_midi_dict( - aug_tokenized_seq, DELTA_MS - ) + _midi_dict = tokenizer.detokenize(aug_tokenized_seq, DELTA_MS) _mid = _midi_dict.to_midi() _mid.save(f"tests/test_results/maestro2_aug.mid")