diff --git a/amt/assets/mel_filters.npz b/amt/assets/mel_filters.npz index 28ea269..c57535f 100644 Binary files a/amt/assets/mel_filters.npz and b/amt/assets/mel_filters.npz differ diff --git a/amt/audio.py b/amt/audio.py index f9bc27e..d03b1a3 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -116,7 +116,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor: mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) """ - assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + assert n_mels in {80, 128, 256}, f"Unsupported n_mels: {n_mels}" filters_path = os.path.join( os.path.dirname(__file__), "assets", "mel_filters.npz" @@ -127,7 +127,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor: def log_mel_spectrogram( audio: Union[str, np.ndarray, torch.Tensor], - n_mels: int = 80, + n_mels: int = 256, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): diff --git a/amt/data.py b/amt/data.py index ec90141..f10bbaa 100644 --- a/amt/data.py +++ b/amt/data.py @@ -1,11 +1,9 @@ import mmap import os -import logging -import json -import jsonlines +import shutil +import orjson import torch -from typing import Callable from multiprocessing import Pool from aria.data.midi import MidiDict @@ -17,36 +15,19 @@ N_FRAMES, ) -config = load_config()["data"] -STRIDE_FACTOR = config["stride_factor"] +config = load_config() +STRIDE_FACTOR = config["data"]["stride_factor"] -def setup_logger(): - # Get logger and reset all handlers - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - - logger.propagate = False - logger.setLevel(logging.INFO) - formatter = logging.Formatter( - "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", - ) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return logger - - -def get_features(audio_path: str, mid_path: str = ""): +def get_features( + audio_path: str, mid_path: str = "", return_json: bool = False +): """This function yields tuples of matched log mel spectrograms and tokenized sequences (np.array, list). If it is given only an audio path then it will return an empty list for the mid_feature """ tokenizer = AmtTokenizer() + n_mels = config["audio"]["n_mels"] if not os.path.isfile(audio_path): return None @@ -57,7 +38,7 @@ def get_features(audio_path: str, mid_path: str = ""): return None try: - log_spec = log_mel_spectrogram(audio=audio_path) + log_spec = log_mel_spectrogram(audio=audio_path, n_mels=n_mels) if mid_path != "": midi_dict = MidiDict.from_midi(mid_path) else: @@ -79,19 +60,37 @@ def get_features(audio_path: str, mid_path: str = ""): else: mid_feature = [] + if return_json is True: + audio_feature = audio_feature.tolist() + res.append((audio_feature, mid_feature)) return res -def get_features_mp(args): - """Multiprocessing wrapper for get_features""" - res = get_features(*args) +def write_features(args): + audio_path, mid_path, save_path = args + features = get_features( + audio_path=audio_path, + mid_path=mid_path, + return_json=False, + ) + dirname, basename = os.path.split(save_path) + proc_save_path = os.path.join(dirname, str(os.getpid()) + basename) + + with open(proc_save_path, mode="ab") as file: + for mel, seq in features: + file.write( + orjson.dumps( + mel.numpy(), + option=orjson.OPT_SERIALIZE_NUMPY, + ) + ) + file.write(b"\n") + file.write(orjson.dumps(seq)) + file.write(b"\n") - if res is None: - return False, None - else: - return True, res + return proc_save_path class AmtDataset(torch.utils.data.Dataset): @@ -127,9 +126,9 @@ def _format(tok): self.file_mmap.seek(self.index[idx]) # Load data from line - spec, _seq = json.loads(self.file_mmap.readline()) + mel = torch.tensor(orjson.loads(self.file_mmap.readline())) + _seq = orjson.loads(self.file_mmap.readline()) - spec = torch.tensor(spec) # Format spectrogram into tensor _seq = [_format(tok) for tok in _seq] # Format seq _seq = self.aug_fn(_seq) # Data augmentation @@ -142,15 +141,15 @@ def _format(tok): seq_len=self.config["max_seq_len"], ) - return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt) + return mel, self.tokenizer.encode(src), self.tokenizer.encode(tgt) def _build_index(self): self.file_mmap.seek(0) index = [] while True: pos = self.file_mmap.tell() - line_buffer = self.file_mmap.readline() - if line_buffer == b"": + self.file_mmap.readline() + if self.file_mmap.readline() == b"": break else: index.append(pos) @@ -162,33 +161,33 @@ def build( cls, matched_load_paths: list[tuple[str, str]], save_path: str, - num_processes: int = 4, + num_processes: int = 1, ): - def _get_features(_matched_load_paths: list): - num_paths = len(_matched_load_paths) - for idx, entry in enumerate(_matched_load_paths): - success, res = get_features_mp(entry) + assert os.path.isfile(save_path) is False, f"{save_path} already exists" + num_paths = len(matched_load_paths) + with Pool(processes=num_processes) as pool: + sharded_save_paths = [] + res = pool.imap_unordered( + write_features, + ((ap, mp, save_path) for ap, mp in matched_load_paths), + ) + for idx, proc_save_path in enumerate(res): if idx % 10 == 0 and idx != 0: - print(f"Processed audio-mid pairs: {idx}/{num_paths}") - if success == False: - continue - for _audio_feature, _mid_feature in res: - yield _audio_feature.tolist(), _mid_feature - - # MP CODE DOESN'T WORK FOR SOME REASON !! - - # with Pool(num_processes) as pool: - # results = pool.imap(get_features_mp, _matched_load_paths) - # num_paths = len(_matched_load_paths) - # for idx, (success, res) in enumerate(results): - # if idx % 10 == 0 and idx != 0: - # print(f"Processed audio-mid pairs: {idx}/{num_paths}") - - # if success == False: - # continue - # for _audio_feature, _mid_feature in res: - # yield _audio_feature.tolist(), _mid_feature - - with jsonlines.open(save_path, mode="w") as writer: - for audio_feature, mid_feature in _get_features(matched_load_paths): - writer.write([audio_feature, mid_feature]) + print(f"Finished {idx}/{num_paths}") + if proc_save_path not in sharded_save_paths: + sharded_save_paths.append(proc_save_path) + + # This is bad, however cat is fast + if shutil.which("cat") is None: + print("The GNU cat command is not available") + else: + print("Concatinating sharded dataset files") + shell_cmd = f"cat " + for _path in sharded_save_paths: + shell_cmd += f"{_path} " + print() + shell_cmd += f">> {save_path}" + + os.system(shell_cmd) + for _path in sharded_save_paths: + os.remove(_path) diff --git a/amt/inference.py b/amt/inference.py index 83fd61d..31eca8f 100644 --- a/amt/inference.py +++ b/amt/inference.py @@ -17,6 +17,9 @@ # sort of branching to make sure that we don't miss notes, ect... Implement this # next week -- Exciting problem (checkout other inference algos) +# Implement maximum note len =5s +# Implement either beam search or decoding initial onset note on first + def greedy_sample( model: AmtEncoderDecoder, @@ -38,6 +41,7 @@ def _process_segment( audio_seg = audio_seg.unsqueeze(0).to(device) seq = tokenizer.encode(tokenizer.trunc_seq(prefix, MAX_SEQ_LEN)) seq = torch.tensor(seq).unsqueeze(0).to(device) + audio_feature = model.embed_audio(mel=audio_seg) for idx in ( pbar := tqdm( @@ -46,21 +50,14 @@ def _process_segment( leave=False, ) ): - logits = model.forward(mel=audio_seg, tokens=seq[:, :idx]) + logits = model.logits( + audio_features=audio_feature, tokens=seq[:, :idx] + ) next_tok_id = torch.argmax(logits[0, -1], dim=-1) - # probs = torch.softmax(logits[0, -1], dim=-1) - # next_tok_id = torch.argmax(probs, dim=-1) - - # Debug logging: - # print(f"input seq shape: {seq[:, :idx].shape}") - # print(f"logits shape: {logits.shape}") - # print(f"probs shape: {probs.shape}") - # print(int(next_tok_id), tokenizer.id_to_tok[int(next_tok_id)]) + seq[0, idx] = next_tok_id if next_tok_id == pad_id or next_tok_id == eos_id: break - else: - seq[0, idx] = next_tok_id if idx == MAX_SEQ_LEN - 2: print("WARNING: Ran out of context when generating sequence") @@ -81,7 +78,7 @@ def _process_segment( model.eval() tokenizer = AmtTokenizer() _unclosed_notes = [] - concat_seq = [] + concat_seq = [tokenizer.bos_tok] _onset_adj = 0 for idx, _audio_seg in enumerate(audio_segments): _seq = [("prev", p) for p in _unclosed_notes] + [tokenizer.bos_tok] @@ -99,14 +96,17 @@ def _process_segment( __midi = __midi_dict.to_midi() __midi.save(f"/weka/proj-aria/aria-amt/samples/res{idx}.mid") - print(f"Done {idx}/{len(audio_segments)}:\n{_seq}") - + print(f"Done {idx + 1}/{len(audio_segments)}") for tok in _seq: if type(tok) is tuple and tok[0] == "onset": _onset_orig = tok[1] _onset_adj = _onset_orig + (idx * LEN_MS) concat_seq.append(("onset", _onset_adj)) - elif tok is tokenizer.pad_tok: + elif type(tok) is tuple and tok[0] == "prev": + continue + elif tok is tokenizer.bos_tok: + continue + elif tok is tokenizer.pad_tok or tok is tokenizer.eos_tok: break else: concat_seq.append(tok) diff --git a/amt/run.py b/amt/run.py index 0c164dd..ec0eb9f 100644 --- a/amt/run.py +++ b/amt/run.py @@ -40,12 +40,15 @@ def build_maestro(args): assert os.path.isdir(args.dir), "MAESTRO directory not found" assert os.path.isfile(args.csv), "MAESTRO csv not found" - if ( - os.path.isfile(args.train) - or os.path.isfile(args.val) - or os.path.isfile(args.test) - ): - print("Dataset files already exist - overwriting") + if os.path.isfile(args.train): + print(f"Dataset file already exists at {args.train} - removing") + os.remove(args.train) + if os.path.isfile(args.val): + print(f"Dataset file already exists at {args.val} - removing") + os.remove(args.val) + if os.path.isfile(args.test): + print(f"Dataset file already exists at {args.test} - removing") + os.remove(args.test) matched_paths_train = [] matched_paths_val = [] diff --git a/amt/tokenizer.py b/amt/tokenizer.py index 180ab6c..3a8d58b 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -241,7 +241,6 @@ def _detokenize_midi_dict( if DEBUG: raise Exception else: - notes_to_close[tok_1_data] = (tok_2_data, tok_3_data) elif tok_1_type == "off": if tok_2_type != "onset": @@ -322,6 +321,9 @@ def export_data_aug(self): def export_msg_mixup(self): def msg_mixup(src: list): + def round_to_base(n, base=150): + return base * round(n / base) + # Process bos, eos, and pad tokens orig_len = len(src) seen_pad_tok = False diff --git a/amt/train.py b/amt/train.py index a6a846e..ed473bb 100644 --- a/amt/train.py +++ b/amt/train.py @@ -138,9 +138,9 @@ def _get_optim( optimizer = torch.optim.AdamW( model.parameters(), lr=lr, - weight_decay=0.01, - betas=(0.9, 0.95), - eps=1e-5, + weight_decay=0.1, + betas=(0.9, 0.98), + eps=1e-6, ) warmup_lrs = torch.optim.lr_scheduler.LinearLR( @@ -365,6 +365,9 @@ def train_loop( # Backwards step accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() optimizer.zero_grad() if scheduler: diff --git a/config/config.json b/config/config.json index 0e77778..be4cfd7 100644 --- a/config/config.json +++ b/config/config.json @@ -11,12 +11,13 @@ }, "audio": { "sample_rate": 16000, - "n_fft": 400, + "n_fft": 2048, "hop_len": 160, - "chunk_len": 30 + "chunk_len": 30, + "n_mels": 256 }, "data": { - "stride_factor": 3, + "stride_factor": 1, "max_seq_len": 4096 } } \ No newline at end of file diff --git a/config/models/medium.json b/config/models/medium.json index 9d93f9b..45c0de6 100644 --- a/config/models/medium.json +++ b/config/models/medium.json @@ -1,5 +1,5 @@ { - "n_mels": 80, + "n_mels": 256, "n_audio_ctx": 1500, "n_audio_state": 512, "n_audio_head": 8, diff --git a/config/models/small.json b/config/models/small.json index fd29fa3..1c87733 100644 --- a/config/models/small.json +++ b/config/models/small.json @@ -1,5 +1,5 @@ { - "n_mels": 80, + "n_mels": 256, "n_audio_ctx": 1500, "n_audio_state": 384, "n_audio_head": 6, diff --git a/config/models/test.json b/config/models/test.json index 5ad3f27..93c0f16 100644 --- a/config/models/test.json +++ b/config/models/test.json @@ -1,5 +1,5 @@ { - "n_mels": 80, + "n_mels": 256, "n_audio_ctx": 1500, "n_audio_state": 64, "n_audio_head": 4, diff --git a/requirements.txt b/requirements.txt index 571a4d0..ebf748f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ torchaudio accelerate mido tqdm -jsonlines +orjson diff --git a/tests/test_data.py b/tests/test_data.py index 7e2d85d..4ba3d4a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,6 +1,7 @@ import unittest import logging import os +import time from amt.data import get_features, AmtDataset from amt.tokenizer import AmtTokenizer @@ -26,7 +27,13 @@ def test_feature_gen(self): class TestAmtDataset(unittest.TestCase): def test_build(self): - matched_paths = [("tests/test_data/147.wav", "tests/test_data/147.mid")] + matched_paths = [ + ("tests/test_data/147.wav", "tests/test_data/147.mid") + for _ in range(3) + ] + if os.path.isfile("tests/test_results/dataset.jsonl"): + os.remove("tests/test_results/dataset.jsonl") + AmtDataset.build( matched_load_paths=matched_paths, save_path="tests/test_results/dataset.jsonl", @@ -53,7 +60,7 @@ def test_maestro(self): dataset = AmtDataset(load_path=MAESTRO_PATH) for idx, (mel, src, tgt) in enumerate(dataset): src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt) - if (idx + 1) % 200 == 0: + if (idx + 1) % 100 == 0: break if idx % 7 == 0: src_mid_dict = tokenizer._detokenize_midi_dict(