diff --git a/.gitignore b/.gitignore index bf11599..55cfd9a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,28 @@ +# data files +*.csv +*.json +*.xls +*.xlsx +*.pkl +*.h5 +*.sqlite +*.db +*.dbf +*.shp +*.shx +*.prj +*.cpg +*.xml +*.html +*.htm +*.mid +*.midi +*.wav +*.mp3 + +.idea/ + + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/amt/audio.py b/amt/audio.py index 5151e4e..1ddad3e 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -190,6 +190,10 @@ def __init__( max_snr: int = 50, max_dist_gain: int = 25, min_dist_gain: int = 0, + # ratios for the reduction of the audio quality + distort_ratio: float = 0.2, + reduce_ratio: float = 0.2, + spec_aug_ratio: float = 0.2, ): super().__init__() self.tokenizer = AmtTokenizer() @@ -204,6 +208,10 @@ def __init__( self.chunk_len = self.config["chunk_len"] self.num_samples = self.sample_rate * self.chunk_len + self.dist_ratio = distort_ratio + self.reduce_ratio = reduce_ratio + self.spec_aug_ratio = spec_aug_ratio + # Audio aug impulse_paths = self._get_paths( os.path.join(os.path.dirname(__file__), "assets", "impulse") @@ -313,6 +321,16 @@ def apply_noise(self, wav: torch.tensor): return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs) + def apply_reduction(self, wav: torch.tensor): + """ + Limit the high-band pass filter, the low-band pass filter and the sample rate + Designed to mimic the effect of recording on a low-quality microphone or phone. + """ + wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=1200) + wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=1400) + resample_rate = 6000 + return AF.resample(wav, orig_freq=self.sample_rate, new_freq=resample_rate, lowpass_filter_width=3) + def apply_distortion(self, wav: torch.tensor): gain = random.randint(self.min_dist_gain, self.max_dist_gain) colour = random.randint(5, 95) @@ -345,13 +363,20 @@ def shift_spec(self, specs: torch.Tensor, shift: int): return shifted_specs def aug_wav(self, wav: torch.Tensor): - # Only apply distortion in 20% of cases - if random.random() > 0.20: - return self.apply_reverb(self.apply_noise(wav)) - else: - return self.apply_reverb( - self.apply_distortion(self.apply_noise(wav)) - ) + """ + pipeline for audio augmentation: + 1. apply noise + 2. apply distortion (x% of the time) + 3. apply reduction (x% of the time) + 4. apply reverb + """ + + wav = self.apply_noise(wav) + if random.random() < self.dist_ratio: + wav = self.apply_distortion(wav) + if random.random() < self.reduce_ratio: + wav = self.apply_reduction(wav) + return self.apply_reverb(wav) def norm_mel(self, mel_spec: torch.Tensor): log_spec = torch.clamp(mel_spec, min=1e-10).log10() @@ -374,14 +399,14 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None): return log_spec def forward(self, wav: torch.Tensor, shift: int = 0): - # Reverb & noise + # noise, distortion, reduction and reverb wav = self.aug_wav(wav) # Spec & pitch shift log_mel = self.log_mel(wav, shift) - # Spec aug in 20% of cases - if random.random() > 0.20: + # Spec aug in 20% of the cases + if random.random() < self.spec_aug_ratio: log_mel = self.spec_aug(log_mel) return log_mel diff --git a/amt/evaluate.py b/amt/evaluate.py new file mode 100644 index 0000000..4b59dd0 --- /dev/null +++ b/amt/evaluate.py @@ -0,0 +1,117 @@ +import glob +from tqdm.auto import tqdm +import pretty_midi +import numpy as np +import mir_eval +import json +import os + +def midi_to_intervals_and_pitches(midi_file_path): + """ + This function reads a MIDI file and extracts note intervals and pitches + suitable for use with mir_eval's transcription evaluation functions. + """ + # Load the MIDI file + midi_data = pretty_midi.PrettyMIDI(midi_file_path) + + # Prepare lists to collect note intervals and pitches + notes = [] + for instrument in midi_data.instruments: + # Skip drum instruments + if not instrument.is_drum: + for note in instrument.notes: + notes.append([note.start, note.end, note.pitch]) + notes = sorted(notes, key=lambda x: x[0]) + notes = np.array(notes) + intervals, pitches = notes[:, :2], notes[:, 2] + intervals -= intervals[0][0] + return intervals, pitches + + +def midi_to_hz(note, shift=0): + """ + Convert MIDI to HZ. + + Shift, if != 0, is subtracted from the MIDI note. + Use "2" for the hFT augmented model transcriptions, else pitches won't match. + """ + # the one used in hFT transformer + return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12) + # a = 440 # frequency of A (common value is 440Hz) + # return (a / 32) * (2 ** ((note - 9) / 12)) + + +def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): + """ + Evaluate the estimated pitches against the reference pitches using mir_eval. + """ + # Evaluate the estimated pitches against the reference pitches + ref_midi_files = glob.glob(f"{ref_dir}/*.mid*") + est_midi_files = glob.glob(f"{est_dir}/*.mid*") + + est_ref_pairs = [] + for est_fpath in est_midi_files: + ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath)) + if ref_fpath in ref_midi_files: + est_ref_pairs.append((est_fpath, ref_fpath)) + if ref_fpath.replace(".mid", ".midi") in ref_midi_files: + est_ref_pairs.append((est_fpath, ref_fpath.replace(".mid", ".midi"))) + else: + print(f"Reference file not found for {est_fpath} (ref file: {ref_fpath})") + + output_fhandle = open(output_stats_file, "w") if output_stats_file is not None else None + + for est_file, ref_file in tqdm(est_ref_pairs): + ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file) + est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file) + ref_pitches_hz = midi_to_hz(ref_pitches) + est_pitches_hz = midi_to_hz(est_pitches, est_shift) + scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz) + if output_fhandle is not None: + output_fhandle.write(json.dumps(scores)) + output_fhandle.write("\n") + else: + print(json.dumps(scores, indent=4)) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(usage="evaluate []") + parser.add_argument( + "--est-dir", + type=str, + help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed." + ) + parser.add_argument( + "--ref-dir", + type=str, + help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw)." + ) + parser.add_argument( + '--output-stats-file', + default=None, + type=str, help="Path to the file to save the evaluation stats" + ) + + # add mir_eval and dtw subparsers + subparsers = parser.add_subparsers(help="sub-command help") + mir_eval_parse = subparsers.add_parser("run_mir_eval", help="Run standard mir_eval evaluation on MAESTRO test set.") + mir_eval_parse.add_argument('--shift', type=int, default=0, help="Shift to apply to the estimated pitches.") + + # to come + dtw_eval_parse = subparsers.add_parser("run_dtw", help="Run dynamic time warping evaluation on a specified dataset.") + + args = parser.parse_args() + if not hasattr(args, "command"): + parser.print_help() + print("Unrecognized command") + exit(1) + + # todo: should we add an option to run transcription again every time we wish to evaluate? + # that way, we can run both tests with a range of different audio augmentations right here. + # -> We expect that baseline methods will fall flat on these, while aria-amt will be OK. + + if args.command == "run_mir_eval": + evaluate_mir_eval(args.est_dir, args.ref_dir, args.output_stats_file, args.shift) + elif args.command == "run_dtw": + pass diff --git a/amt/infer.py b/amt/infer.py index 5f4dba1..b471b08 100644 --- a/amt/infer.py +++ b/amt/infer.py @@ -23,7 +23,6 @@ # TODO: Profile and fix gpu util - def calculate_vel( logits: torch.Tensor, init_vel: int, @@ -88,7 +87,23 @@ def calculate_onset( return tokenizer.tok_to_id[("onset", new_onset)] -@torch.autocast("cuda", dtype=torch.bfloat16) +from functools import wraps +from torch.cuda import is_bf16_supported +def optional_bf16_autocast(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Assuming 'check_bfloat16_support()' returns True if bfloat16 is supported + if is_bf16_supported(): + with torch.autocast("cuda", dtype=torch.bfloat16): + return func(*args, **kwargs) + else: + # Call the function with float16 if bfloat16 is not supported + with torch.autocast("cuda", dtype=torch.float16): + return func(*args, **kwargs) + return wrapper + + +@optional_bf16_autocast def process_segments( tasks: list, model: AmtEncoderDecoder, diff --git a/amt/run.py b/amt/run.py index 2199236..be33161 100644 --- a/amt/run.py +++ b/amt/run.py @@ -9,66 +9,60 @@ # TODO: Implement a way of inferring the tokenizer name automatically -def _parse_maestro_args(): - argp = argparse.ArgumentParser(prog="amt maestro") - argp.add_argument("dir", help="MAESTRO directory path") - argp.add_argument("csv", help="MAESTRO csv path") - argp.add_argument("-train", help="train save path", required=True) - argp.add_argument("-val", help="val save path", required=True) - argp.add_argument("-test", help="test save path", required=True) - argp.add_argument( +def _add_maestro_args(subparser): + subparser.add_argument("dir", help="MAESTRO directory path") + subparser.add_argument("csv", help="MAESTRO csv path") + subparser.add_argument("-train", help="train save path", required=True) + subparser.add_argument("-val", help="val save path", required=True) + subparser.add_argument("-test", help="test save path", required=True) + subparser.add_argument( "-mp", help="number of processes to use", type=int, required=False, ) - return argp.parse_args(sys.argv[2:]) - -def _parse_transcribe_args(): - argp = argparse.ArgumentParser(prog="amt transcribe") - argp.add_argument("model_name", help="name of model config file") - argp.add_argument("cp", help="checkpoint path") - argp.add_argument("-load_path", help="path to mp3/wav file", required=False) - argp.add_argument( +def _add_transcribe_args(subparser): + subparser.add_argument("model_name", help="name of model config file") + subparser.add_argument('checkpoint_path', help="checkpoint path") + subparser.add_argument("-load_path", help="path to mp3/wav file", required=False) + subparser.add_argument( "-load_dir", help="dir containing mp3/wav files", required=False ) - argp.add_argument("-save_dir", help="dir to save midi files", required=True) - argp.add_argument( + subparser.add_argument("-save_dir", help="dir to save midi files", required=True) + subparser.add_argument( "-multi_gpu", help="use all GPUs", action="store_true", default=False ) - argp.add_argument("-bs", help="batch size", type=int, default=16) - - return argp.parse_args(sys.argv[2:]) + subparser.add_argument("-bs", help="batch size", type=int, default=16) -def build_maestro(args): +def build_maestro(maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs): from amt.data import AmtDataset - assert os.path.isdir(args.dir), "MAESTRO directory not found" - assert os.path.isfile(args.csv), "MAESTRO csv not found" - if os.path.isfile(args.train): - print(f"Dataset file already exists at {args.train} - removing") - os.remove(args.train) - if os.path.isfile(args.val): - print(f"Dataset file already exists at {args.val} - removing") - os.remove(args.val) - if os.path.isfile(args.test): - print(f"Dataset file already exists at {args.test} - removing") - os.remove(args.test) + assert os.path.isdir(maestro_dir), "MAESTRO directory not found" + assert os.path.isfile(maestro_csv_file), "MAESTRO csv not found" + if os.path.isfile(train_file): + print(f"Dataset file already exists at {train_file} - removing") + os.remove(train_file) + if os.path.isfile(val_file): + print(f"Dataset file already exists at {val_file} - removing") + os.remove(val_file) + if os.path.isfile(test_file): + print(f"Dataset file already exists at {test_file} - removing") + os.remove(test_file) matched_paths_train = [] matched_paths_val = [] matched_paths_test = [] - with open(args.csv, "r") as f: + with open(maestro_csv_file, "r") as f: dict_reader = DictReader(f) for entry in dict_reader: audio_path = os.path.normpath( - os.path.join(args.dir, entry["audio_filename"]) + os.path.join(maestro_dir, entry["audio_filename"]) ) midi_path = os.path.normpath( - os.path.join(args.dir, entry["midi_filename"]) + os.path.join(maestro_dir, entry["midi_filename"]) ) if not os.path.isfile(audio_path) or not os.path.isfile(audio_path): @@ -86,27 +80,53 @@ def build_maestro(args): else: print("Invalid split") - print(f"Building {args.train}") + print(f"Building {train_file}") AmtDataset.build( matched_load_paths=matched_paths_train, - save_path=args.train, - num_processes=args.mp, + save_path=train_file, + num_processes=num_procs, ) - print(f"Building {args.val}") + print(f"Building {val_file}") AmtDataset.build( matched_load_paths=matched_paths_val, - save_path=args.val, - num_processes=args.mp, + save_path=val_file, + num_processes=num_procs, ) - print(f"Building {args.test}") + print(f"Building {test_file}") AmtDataset.build( matched_load_paths=matched_paths_test, - save_path=args.test, - num_processes=args.mp, + save_path=test_file, + num_processes=num_procs, ) -def transcribe(args): +def transcribe( + model_name, checkpoint_path, save_dir, load_path=None, load_dir=None, + batch_size=16, multi_gpu=False, + augment=None, +): + """ + Transcribe audio files to midi using the given model and checkpoint. + + Parameters + ---------- + model_name : str + Name of the model config file + checkpoint_path : str + Path to the model checkpoint + save_dir : str + Directory to save the transcribed midi files + load_path : str + Name of the audio file to transcribe (if specified, don't specify load_dir) + load_dir : str + Directory containing audio files to transcribe (if specified, don't specify load_path) + batch_size : int + Batch size to use for transcription + multi_gpu : bool + Use all available GPUs for transcription + augment : str + Augment the audio files before transcribing. This is used for evaluation. This tests the robustness of the model. + """ import torch from torch.cuda import is_available as cuda_is_available from amt.tokenizer import AmtTokenizer @@ -116,26 +136,26 @@ def transcribe(args): from aria.utils import _load_weight assert cuda_is_available(), "CUDA device not found" - assert os.path.isfile(args.cp), "model checkpoint file not found" - assert args.load_path or args.load_dir, "must give either load path or dir" - if args.load_path: + assert os.path.isfile(checkpoint_path), "model checkpoint file not found" + assert load_path or load_dir, "must give either load path or dir" + if load_path: assert os.path.isfile( - args.load_path - ), f"audio file not found: {args.load_path}" + load_path + ), f"audio file not found: {load_path}" trans_mode = "single" - if args.load_dir: - assert os.path.isdir(args.load_dir), "load directory doesn't exist" + if load_dir: + assert os.path.isdir(load_dir), "load directory doesn't exist" trans_mode = "batch" - if not os.path.exists(args.save_dir): - os.makedirs(args.save_dir) - assert os.path.isdir(args.save_dir), "save dir doesn't exist" + if not os.path.exists(save_dir): + os.makedirs(save_dir) + assert os.path.isdir(save_dir), "save dir doesn't exist" # Setup model tokenizer = AmtTokenizer() - model_config = ModelConfig(**load_model_config(args.model_name)) + model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) model = AmtEncoderDecoder(model_config) - model_state = _load_weight(ckpt_path=args.cp) + model_state = _load_weight(ckpt_path=checkpoint_path) # Fix keys in compiled model checkpoint _model_state = {} @@ -150,17 +170,17 @@ def transcribe(args): if trans_mode == "batch": found_wav = glob.glob( - os.path.join(args.load_dir, "**/*.wav"), recursive=True + os.path.join(load_dir, "**/*.wav"), recursive=True ) found_mp3 = glob.glob( - os.path.join(args.load_dir, "**/*.mp3"), recursive=True + os.path.join(load_dir, "**/*.mp3"), recursive=True ) print(f"Found {len(found_mp3)} mp3 and {len(found_wav)} wav files") file_paths = found_mp3 + found_wav else: - file_paths = [args.load_path] + file_paths = [load_path] - if args.multi_gpu: + if multi_gpu: # Generate chunks gpu_ids = [ int(id) for id in os.getenv("CUDA_VISIBLE_DEVICES").split(",") @@ -185,10 +205,10 @@ def transcribe(args): args=( chunk, model, - args.save_dir, - args.bs, + save_dir, + batch_size, gpu_ids[idx], - args.load_dir, + load_dir, ), ) process.start() @@ -201,33 +221,47 @@ def transcribe(args): batch_transcribe( file_paths=file_paths, model=model, - save_dir=args.save_dir, - batch_size=args.bs, - input_dir=args.load_dir, + save_dir=save_dir, + batch_size=batch_size, + input_dir=load_dir, ) def main(): # Nested argparse inspired by - https://shorturl.at/kuKW0 parser = argparse.ArgumentParser(usage="amt []") - parser.add_argument( - "command", - help="command to run", - choices=("maestro", "transcribe"), - ) + subparsers = parser.add_subparsers(help="sub-command help") + # add maestro and transcribe subparsers + subparser_maestro = subparsers.add_parser("maestro", help="Commands to build the maestro dataset.") + subparser_transcribe = subparsers.add_parser("transcribe", help="Commands to run transcription.") + _add_maestro_args(subparser_maestro) + _add_transcribe_args(subparser_transcribe) - # parse_args defaults to [1:] for args, but you need to - # exclude the rest of the args too, or validation will fail - args = parser.parse_args(sys.argv[1:2]) + args = parser.parse_args() if not hasattr(args, "command"): parser.print_help() print("Unrecognized command") exit(1) elif args.command == "maestro": - build_maestro(args=_parse_maestro_args()) + build_maestro( + maestro_dir=args.dir, + maestro_csv_file=args.csv, + train_file=args.train, + val_file=args.val, + test_file=args.test, + num_procs=args.mp, + ) elif args.command == "transcribe": - transcribe(args=_parse_transcribe_args()) + transcribe( + model_name=args.model_name, + checkpoint_path=args.checkpoint_path, + load_path=args.load_path, + load_dir=args.load_dir, + save_dir=args.save_dir, + batch_size=args.bs, + multi_gpu=args.multi_gpu, + ) else: print("Unrecognized command") parser.print_help() diff --git a/notebooks/2024-03-06__test-alignment-methods.ipynb b/notebooks/2024-03-06__test-alignment-methods.ipynb new file mode 100644 index 0000000..8e4d229 --- /dev/null +++ b/notebooks/2024-03-06__test-alignment-methods.ipynb @@ -0,0 +1,886 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "de9ab3d6-09d8-42ba-aa3a-a892f03f376a", + "metadata": {}, + "source": [ + "# Test how to induce phone effect" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5864bcb4-6da0-4f22-9515-1395dfa9d56d", + "metadata": {}, + "outputs": [], + "source": [ + "import IPython\n", + "import torchaudio\n", + "import torchaudio.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6c4f9fa0-45a7-4faa-b4b6-a02ca7239deb", + "metadata": {}, + "outputs": [], + "source": [ + "fpath = 'test-transcription/hft-transcribed__02_R1_2004_05_Track05.wav'\n", + "waveform, sample_rate = torchaudio.load(fpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c91b8863-de1c-4c09-8106-9c4a8e0ab11c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "44100" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_rate" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "183078c0-5db2-423e-831a-2da570bb5830", + "metadata": {}, + "outputs": [], + "source": [ + "# IPython.display.Audio(data=waveform, rate=sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "0623242c-e8aa-4a39-a318-bf61d67490f4", + "metadata": {}, + "outputs": [], + "source": [ + "phone_wav = F.highpass_biquad(waveform, sample_rate, cutoff_freq=1200)\n", + "phone_wav = F.lowpass_biquad(phone_wav, sample_rate, cutoff_freq=1400)\n", + "resample_rate = 6000\n", + "phone_wav = F.resample(phone_wav, orig_freq=sample_rate, new_freq=resample_rate, lowpass_filter_width=3)" + ] + }, + { + "cell_type": "markdown", + "id": "f17b5a6d-968b-4b53-ada1-496f4a13efdf", + "metadata": {}, + "source": [ + "# MIR_EVAL" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "ac2937f2-9852-4478-9820-bbc504b8c24f", + "metadata": {}, + "outputs": [], + "source": [ + "import pretty_midi\n", + "import numpy as np \n", + "import mir_eval\n", + "\n", + "def midi_to_intervals_and_pitches(midi_file_path):\n", + " \"\"\"\n", + " This function reads a MIDI file and extracts note intervals and pitches\n", + " suitable for use with mir_eval's transcription evaluation functions.\n", + " \"\"\"\n", + " # Load the MIDI file\n", + " midi_data = pretty_midi.PrettyMIDI(midi_file_path)\n", + " \n", + " # Prepare lists to collect note intervals and pitches\n", + " notes = []\n", + " for instrument in midi_data.instruments:\n", + " # Skip drum instruments\n", + " if not instrument.is_drum:\n", + " for note in instrument.notes:\n", + " notes.append([note.start, note.end, note.pitch])\n", + " notes = sorted(notes, key=lambda x: x[0])\n", + " notes = np.array(notes)\n", + " intervals, pitches = notes[:, :2], notes[:, 2]\n", + " intervals -= intervals[0][0]\n", + " return intervals, pitches\n", + "\n", + "def midi_to_hz(note, shift=0):\n", + " \"\"\"\n", + " Convert MIDI to HZ.\n", + "\n", + " Shift, if != 0, is subtracted from the MIDI note. Use \"2\" for the hFT augmented model transcriptions, else pitches won't match.\n", + " \"\"\"\n", + " # the one used in hFT transformer\n", + " return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12)\n", + " a = 440 # frequency of A (common value is 440Hz)\n", + " # return (a / 32) * (2 ** ((note - 9) / 12))" + ] + }, + { + "cell_type": "markdown", + "id": "f552f9d8-1fe1-4de3-8954-ae41b568a153", + "metadata": {}, + "source": [ + "# Kong's Alignment Method" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a390d3bf-0816-4b5e-9e85-a9e409b4359b", + "metadata": {}, + "outputs": [], + "source": [ + "import csv\n", + "def get_stats(csv_path):\n", + " \"\"\"Parse aligned results csv file to get results.\n", + "\n", + " Args:\n", + " csv_path: str, aligned result path, e.g., xx_corresp.txt\n", + "\n", + " Returns:\n", + " stat_dict, dict, keys: \n", + " true positive (TP), \n", + " deletion (D), \n", + " insertion (I), \n", + " substitution (S), \n", + " error rate (ER), \n", + " ground truth number (N)\n", + " \"\"\"\n", + " with open(csv_path, 'r') as fr:\n", + " reader = csv.reader(fr, delimiter='\\t')\n", + " lines = list(reader)\n", + "\n", + " lines = lines[1 :]\n", + "\n", + " TP, D, I, S = 0, 0, 0, 0\n", + " align_counter = []\n", + " ref_counter = []\n", + "\n", + " for line in lines:\n", + " line = line[0 : -1]\n", + " [alignID, _, _, alignPitch, _, refID, _, _, refPitch, _] = line\n", + "\n", + " if alignID != '*' and refID != '*':\n", + " if alignPitch == refPitch:\n", + " TP += 1\n", + " else:\n", + " S += 1\n", + "\n", + " if alignID == '*':\n", + " D += 1\n", + "\n", + " if refID == '*':\n", + " I += 1\n", + "\n", + " N = TP + D + S\n", + " ER = (D + I + S) / N\n", + " stat_dict = {'TP': TP, 'D': D, 'I': I, 'S': S, 'ER': ER, 'N': N}\n", + " return stat_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "a82c99f6-eac2-46fc-8c36-3b20c893c26a", + "metadata": {}, + "outputs": [], + "source": [ + "import os \n", + "def align_files(ref_fp, est_fp):\n", + " align_tools_dir = '../../2017_midi_alignment'\n", + " ref_fn = os.path.basename(ref_fp)\n", + " est_fn = os.path.basename(est_fp)\n", + " ref_fn_name, ext = os.path.splitext(ref_fn)\n", + " est_fn_name, ext = os.path.splitext(est_fn)\n", + " \n", + " # Copy MIDI files\n", + " cmd = f'cp \"{ref_fp}\" \"{align_tools_dir}/{ref_fn}\"; '\n", + " cmd += f'cp \"{est_fp}\" \"{align_tools_dir}/{est_fn}\"; '\n", + " print(cmd)\n", + " os.system(cmd)\n", + " \n", + " # Align\n", + " cmd = f'cd {align_tools_dir}; '\n", + " # cmd += f'./MIDIToMIDIAlign.sh {ref_fn_name} {est_fn_name}; '\n", + " cmd += f'./MIDIToMIDIAlign.sh {ref_fn} {est_fn}; '\n", + " print(cmd)\n", + " os.system(cmd)" + ] + }, + { + "cell_type": "markdown", + "id": "8d8f5331-92e3-49ae-8def-6af5834a5a9b", + "metadata": {}, + "source": [ + "# Test Sample MIDI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35ee30b1-bed9-42f9-8011-086431aa60ae", + "metadata": {}, + "outputs": [], + "source": [ + "from importlib import reload\n", + "import sys\n", + "sys.path.insert(0, '../../aria-dl/hFT-Transformer/evaluation/')\n", + "import transcribe_new_files as t\n", + "import glob\n", + "import aria.utils\n", + "from importlib import reload \n", + "reload(aria)\n", + "import IPython\n", + "\n", + "all_maestro_files = sorted(glob.glob('../../corpus/maestro-v3.0.0/2004/*'))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "51c500b6-20b9-4a4a-b677-8c14d58fd2aa", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "input_wav_file = 'test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav'\n", + "output_midi_file = 'test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi'\n", + "# t.transcribe_file(input_wav_file, output_midi_file)\n", + "gold_truth_midi_file = 'test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44129b9c-4fb1-49c7-8d10-e186395ad8b0", + "metadata": {}, + "outputs": [], + "source": [ + "aria.utils.midi_to_audio(\"test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi\")" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "c94d739a-6583-4ee4-875b-3d0e8493220f", + "metadata": {}, + "outputs": [], + "source": [ + "import IPython\n", + "# IPython.display.Audio(data='test-transcription/hft-transcribed__02_R1_2004_05_Track05.wav', rate=44100)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "3385a328-0728-4b60-bd1b-e03957ac79b5", + "metadata": {}, + "outputs": [], + "source": [ + "import IPython\n", + "# IPython.display.Audio(data=input_wav_file, rate=44100)" + ] + }, + { + "cell_type": "markdown", + "id": "43746faf-a7bd-45a8-8a0c-4238e68b3d34", + "metadata": {}, + "source": [ + "#### evaluate using mir_eval" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "cf506615-9e3c-4f20-bf08-b12dd74a0670", + "metadata": {}, + "outputs": [], + "source": [ + "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gold_truth_midi_file)\n", + "est_intervals, est_pitches = midi_to_intervals_and_pitches(output_midi_file)\n", + "\n", + "ref_pitches_hz = midi_to_hz(ref_pitches)\n", + "est_pitches_hz = midi_to_hz(est_pitches, shift=2) ## shift=2 because hFT transcribes 2 notes above, for some reason" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "58c4b8af-4729-4c06-8a63-85198222bb0c", + "metadata": {}, + "outputs": [], + "source": [ + "scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "34cbfa89-371a-4466-a4b2-ebe1900e5ae9", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "matched_onsets = mir_eval.transcription.match_note_onsets(ref_intervals, est_intervals)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "cf1847d5-d253-4062-b8cf-9ed5142707ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([71., 55., 71., 59., 62., 72.])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ref_pitches[[0, 1, 2, 3, 4, 5]]" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "93c07152-104c-4c6d-8f6c-f7472c31a7ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([73., 57., 73., 61., 64., 74.])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "est_pitches[[0, 1, 2, 3, 4, 5]]" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "187602e4-269d-448d-b300-f7e9087fd7a7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'{\\n \"Precision\": 0.7708092856226754,\\n \"Recall\": 0.7613377248543197,\\n \"F-measure\": 0.7660442291759608,\\n \"Average_Overlap_Ratio\": 0.8455788515638166,\\n \"Precision_no_offset\": 0.9976914197768373,\\n \"Recall_no_offset\": 0.9854319736508741,\\n \"F-measure_no_offset\": 0.9915238034542095,\\n \"Average_Overlap_Ratio_no_offset\": 0.7557460905388416,\\n \"Onset_Precision\": 0.9980761831473643,\\n \"Onset_Recall\": 0.9858120091208513,\\n \"Onset_F-measure\": 0.9919061882607865,\\n \"Offset_Precision\": 0.8117224573553931,\\n \"Offset_Recall\": 0.8017481631618951,\\n \"Offset_F-measure\": 0.806704480275317\\n}'" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import json\n", + "json.dumps(scores, indent=4)" + ] + }, + { + "cell_type": "markdown", + "id": "94c5ab54-c4be-45a5-92f4-9b99a40ffc50", + "metadata": {}, + "source": [ + "### evaluate using kong's method" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "77bb09e5-e989-4042-8d6d-977c232ac817", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_midi_file" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "ba68de3b-b1a6-4f3d-815b-09403862a2a1", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cp \"test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\" \"../../2017_midi_alignment/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\"; cp \"test-transcription/hft-transcribed__02_R1_2004_05_Track05.midi\" \"../../2017_midi_alignment/hft-transcribed__02_R1_2004_05_Track05.midi\"; \n", + "cd ../../2017_midi_alignment; ./MIDIToMIDIAlign.sh MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi hft-transcribed__02_R1_2004_05_Track05.midi; \n", + "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", + "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_hmm.txt\n", + "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", + "File not found: ./MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi_fmt3x.txt\n", + "File not found: ./hft-transcribed__02_R1_2004_05_Track05.midi_match.txt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Assertion failed: (ifs.is_open()), function ReadFile, file Midi_v170101.hpp, line 177.\n", + "./MIDIToMIDIAlign.sh: line 14: 68574 Abort trap: 6 $ProgramFolder/midi2pianoroll 0 $RelCurrentFolder/${I1}\n", + "Assertion failed: (ifs.is_open()), function ReadFile, file Midi_v170101.hpp, line 177.\n", + "./MIDIToMIDIAlign.sh: line 15: 68575 Abort trap: 6 $ProgramFolder/midi2pianoroll 0 $RelCurrentFolder/${I2}\n", + "./MIDIToMIDIAlign.sh: line 17: 68576 Segmentation fault: 11 $ProgramFolder/SprToFmt3x $RelCurrentFolder/${I1}_spr.txt $RelCurrentFolder/${I1}_fmt3x.txt\n", + "Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", + "./MIDIToMIDIAlign.sh: line 18: 68577 Abort trap: 6 $ProgramFolder/Fmt3xToHmm $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt\n", + "Assertion failed: (false), function ReadFile, file Hmm_v170225.hpp, line 69.\n", + "./MIDIToMIDIAlign.sh: line 20: 68578 Abort trap: 6 $ProgramFolder/ScorePerfmMatcher $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_spr.txt $RelCurrentFolder/${I2}_pre_match.txt 0.001\n", + "Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", + "./MIDIToMIDIAlign.sh: line 21: 68579 Abort trap: 6 $ProgramFolder/ErrorDetection $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_pre_match.txt $RelCurrentFolder/${I2}_err_match.txt 0\n", + "Assertion failed: (false), function ReadFile, file Fmt3x_v170225.hpp, line 252.\n", + "./MIDIToMIDIAlign.sh: line 22: 68580 Abort trap: 6 $ProgramFolder/RealignmentMOHMM $RelCurrentFolder/${I1}_fmt3x.txt $RelCurrentFolder/${I1}_hmm.txt $RelCurrentFolder/${I2}_err_match.txt $RelCurrentFolder/${I2}_realigned_match.txt 0.3\n", + "cp: ./hft-transcribed__02_R1_2004_05_Track05.midi_realigned_match.txt: No such file or directory\n", + "Assertion failed: (false), function ReadFile, file ScorePerfmMatch_v170503.hpp, line 86.\n", + "./MIDIToMIDIAlign.sh: line 25: 68582 Abort trap: 6 $ProgramFolder/MatchToCorresp $RelCurrentFolder/${I2}_match.txt $RelCurrentFolder/${I1}_spr.txt $RelCurrentFolder/${I2}_corresp.txt\n", + "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_realigned_match.txt: No such file or directory\n", + "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_err_match.txt: No such file or directory\n", + "rm: ./hft-transcribed__02_R1_2004_05_Track05.midi_pre_match.txt: No such file or directory\n" + ] + } + ], + "source": [ + "align_files(gold_truth_midi_file, output_midi_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "434270a9-91c5-4b75-b56f-d2f0cda631ab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mCode\u001b[m\u001b[m/\n", + "LICENCE.txt\n", + "MANUAL.pdf\n", + "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV.mid\n", + "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_corresp.txt\n", + "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_match.txt\n", + "MIDI-Unprocessed_02_R1_2009_03-06_ORIG_MID--AUDIO_02_R1_2009_02_R1_2009_04_WAV_spr.txt\n", + "MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\n", + "\u001b[31mMIDIToMIDIAlign.sh\u001b[m\u001b[m*\n", + "\u001b[31mMusicXMLToMIDIAlign.sh\u001b[m\u001b[m*\n", + "\u001b[34mPrograms\u001b[m\u001b[m/\n", + "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4.mid\n", + "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_corresp.txt\n", + "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_match.txt\n", + "Scriabin_op_42_Hf4EIJB4DGc_cut_no_4_spr.txt\n", + "\u001b[31mcompile.sh\u001b[m\u001b[m*\n", + "ex_align1.mid\n", + "\u001b[31mex_align2.mid\u001b[m\u001b[m*\n", + "ex_ref.musx\n", + "ex_ref.pdf\n", + "ex_ref.xml\n", + "ex_ref_fmt3x.txt\n", + "ex_ref_hmm.txt\n", + "hft-transcribed__02_R1_2004_05_Track05.midi\n", + "scriabin_etude_op_42_no_4_dery.mid\n", + "scriabin_etude_op_42_no_4_dery_fmt3x.txt\n", + "scriabin_etude_op_42_no_4_dery_hmm.txt\n", + "scriabin_etude_op_42_no_4_dery_spr.txt\n" + ] + } + ], + "source": [ + "ls ../../2017_midi_alignment/" + ] + }, + { + "cell_type": "markdown", + "id": "4e1f906a-6e8a-46c4-88fe-fadac564ba39", + "metadata": {}, + "source": [ + "# Test Kong's Samples" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cbcd4410-403e-4565-96ec-c7067b9b4b76", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '..')\n", + "import amt.audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07e5a978-fc0f-444d-8aeb-851dfda1085a", + "metadata": {}, + "outputs": [], + "source": [ + "audio_transform = amt.audio.AudioTransform()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "365ccc37-2ab3-4f58-89bd-bb972b39075c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Chopin, Frédéric, Études, Op.10, g0hoN6_HDVU.mid\n", + "Handel, George Frideric, Air in E major, HWV 425, bNzVz5byPqk.mid\n", + "Liszt, Franz, Hungarian Rhapsody No.2, S.244_2, LdH1hSWGFGU.mid\n", + "Ravel, Maurice, Jeux d'eau, v-QmwrhO3ec.mid\n" + ] + } + ], + "source": [ + "ls ../../GiantMIDI-Piano/midis_preview/" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "27d20498-f77f-4c34-b368-40a3ccefc871", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os \n", + "df = pd.read_csv('../../GiantMIDI-Piano/midis_for_evaluation/groundtruth_maestro_giantmidi-piano.csv', sep='\\t')" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "81c4dabd-be84-4d51-82a2-b2426f001ef6", + "metadata": {}, + "outputs": [], + "source": [ + "gt_folder = '../../GiantMIDI-Piano/midis_for_evaluation/ground_truth/'\n", + "giant_midi_folder = '../../GiantMIDI-Piano/midis_for_evaluation/giantmidi-piano/'\n", + "maestro_midi_folder = '../../GiantMIDI-Piano/midis_for_evaluation/maestro/'\n", + "gt_fn, giant_midi_fn, maestro_fn = df[['GroundTruth', 'GiantMIDI-Piano', 'Maestro']].iloc[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "0c8b2c1a-6dbb-4191-9a6f-43982b1dd65f", + "metadata": {}, + "outputs": [], + "source": [ + "gt_fp = os.path.join(gt_folder, gt_fn)\n", + "giant_midi_fp = os.path.join(giant_midi_folder, giant_midi_fn)\n", + "maestro_midi_fp = os.path.join(maestro_midi_folder, maestro_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4313bfc4-ffa4-4182-8f17-5519421aa126", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import mirdata\n", + "import mido \n", + "mido.MidiFile(filename=gt_fp)" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "id": "f1b4c7b9-b3b5-48e0-a5cf-43e704c6fb51", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gt_fp)\n", + "est_intervals, est_pitches = midi_to_intervals_and_pitches(maestro_midi_fp)\n", + "\n", + "ref_pitches_hz = midi_to_hz(ref_pitches)\n", + "est_pitches_hz = midi_to_hz(est_pitches)" + ] + }, + { + "cell_type": "code", + "execution_count": 161, + "id": "b619d23b-080a-4ef1-b2f7-b3241a18e9cb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(1, 1), (2, 2), (23, 23), (67, 66), (71, 70), (336, 385), (677, 709)]" + ] + }, + "execution_count": 161, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mir_eval.transcription.match_notes(\n", + " ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "id": "bcb0f59e-e097-455f-90c3-fa28a1b78a91", + "metadata": {}, + "outputs": [], + "source": [ + "# mir_eval.transcription.precision_recall_f1_overlap()\n", + "# mir_eval.transcription.evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "7096c4ac-a857-40d3-b14b-5b91443134fe", + "metadata": {}, + "outputs": [], + "source": [ + "scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "a70c27b3-f9f7-4bd6-8471-4493c9fd8a89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('Precision', 0.008939974457215836),\n", + " ('Recall', 0.008816120906801008),\n", + " ('F-measure', 0.008877615726062145),\n", + " ('Average_Overlap_Ratio', 0.8771661491453748),\n", + " ('Precision_no_offset', 0.04469987228607918),\n", + " ('Recall_no_offset', 0.04408060453400504),\n", + " ('F-measure_no_offset', 0.04438807863031072),\n", + " ('Average_Overlap_Ratio_no_offset', 0.5049151558206666),\n", + " ('Onset_Precision', 0.2771392081736909),\n", + " ('Onset_Recall', 0.27329974811083124),\n", + " ('Onset_F-measure', 0.2752060875079264),\n", + " ('Offset_Precision', 0.4623243933588761),\n", + " ('Offset_Recall', 0.45591939546599497),\n", + " ('Offset_F-measure', 0.45909955611921366)])" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "id": "273e921e-fdf6-4385-a17a-ab1e43b2f2f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.0012091898428053204, Recall: 0.0012594458438287153, F-measure: 0.0012338062924120913\n" + ] + } + ], + "source": [ + "import pretty_midi\n", + "import numpy as np\n", + "import mir_eval\n", + "\n", + "def midi_to_intervals_and_pitches(midi_file_path):\n", + " # Load the MIDI file\n", + " midi_data = pretty_midi.PrettyMIDI(midi_file_path)\n", + " \n", + " intervals = []\n", + " pitches = []\n", + " \n", + " for instrument in midi_data.instruments:\n", + " if not instrument.is_drum:\n", + " for note in instrument.notes:\n", + " start_time = note.start\n", + " end_time = note.end\n", + " intervals.append([start_time, end_time])\n", + " pitches.append(note.pitch)\n", + " \n", + " intervals = np.array(intervals)\n", + " pitches = np.array(pitches)\n", + " \n", + " return intervals, pitches\n", + "\n", + "# Load your reference and estimated MIDI files\n", + "ref_intervals, ref_pitches = midi_to_intervals_and_pitches(gt_fp)\n", + "est_intervals, est_pitches = midi_to_intervals_and_pitches(giant_midi_fp)\n", + "ref_pitches_hz = midi_to_hz(ref_pitches)\n", + "est_pitches_hz = midi_to_hz(est_pitches)\n", + "\n", + "# Evaluate using mir_eval\n", + "precision, recall, f_measure, _ = mir_eval.transcription.precision_recall_f1_overlap(\n", + " ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz\n", + ")\n", + "\n", + "print(f\"Precision: {precision}, Recall: {recall}, F-measure: {f_measure}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3392980e-8d35-4764-a9f4-4d8e322791c5", + "metadata": {}, + "source": [ + "# Try using the GiantMIDI Method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffc236a1-8837-47d5-a02b-56eded2be3a4", + "metadata": {}, + "outputs": [], + "source": [ + "csv_path = f'{align_tools_dir}/{maestro_fn[: -4]}_corresp.txt'\n", + "maestro_stats = get_stats(csv_path)\n", + "\n", + "csv_path = f'{align_tools_dir}/{giant_midi_fn[: -4]}_corresp.txt'\n", + "giantmidi_stats = get_stats(csv_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "id": "2128d51d-9845-47ec-871a-2fff47fd9640", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'TP': 780, 'D': 8, 'I': 41, 'S': 6, 'ER': 0.06926952141057935, 'N': 794}" + ] + }, + "execution_count": 155, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "giantmidi_stats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d045113-f996-4959-a615-9edc909b08a5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d37949b3-2ab3-426f-827b-7e77fe728a3c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ceee36a-3022-4c84-9e09-84cbd709cbae", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "dad2f979-56cd-4af6-acf1-64420ab61255", + "metadata": {}, + "outputs": [], + "source": [ + "# IPython.display.Audio(data=phone_wav, rate=resample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "feaced10-0bc3-44b3-8f9d-1647f203c89c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb b/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb new file mode 100644 index 0000000..9ff83d0 --- /dev/null +++ b/notebooks/2024-03-07__run-aria-amt-and-evals.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T02:01:24.262700Z", + "start_time": "2024-03-08T02:01:23.754126Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mkdir: cannot create directory ‘../amt/assets’: File exists\r\n" + ] + } + ], + "source": [ + "! mkdir '../amt/assets'\n", + "! mkdir '../amt/assets/impulse'\n", + "! mkdir '../amt/assets/noise'" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T03:39:12.094229Z", + "start_time": "2024-03-08T03:39:07.414566Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "import sys\n", + "import subprocess\n", + "\n", + "MODEL_NAME = \"medium\"\n", + "CHECKPOINT_NAME = f\"med-81.safetensors\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T03:51:28.491243Z", + "start_time": "2024-03-08T03:51:28.477962Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "if not os.path.isfile(f\"{CHECKPOINT_NAME}\"):\n", + " ! wget https://storage.googleapis.com/aria-checkpoints/amt/{CHECKPOINT_NAME} {CHECKPOINT_NAME}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T03:39:12.177865Z", + "start_time": "2024-03-08T03:39:12.133447Z" + }, + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import amt.run\n", + "from importlib import reload\n", + "reload(amt.run)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T03:50:35.603392Z", + "start_time": "2024-03-08T03:39:13.021352Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "838861: Getting wav segments\n", + "838861: Finished file 1 - test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav\n", + "838861: 0 file(s) remaining in queue\n", + "839109: GPU task timeout\n", + "839109: Finished GPU tasks\n" + ] + } + ], + "source": [ + "# model_name, checkpoint_path, save_dir, load_path=None, load_dir=None, batch_size=16, multi_gpu=False\n", + "amt.run.transcribe(\n", + " model_name=MODEL_NAME,\n", + " checkpoint_path=CHECKPOINT_NAME,\n", + " load_path='test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav',\n", + " save_dir=\"test-transcription/aria-amt-tests\",\n", + " batch_size=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# Evaluate Output" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T04:42:50.244198Z", + "start_time": "2024-03-08T04:42:46.774805Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "import amt.evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T04:04:40.989918Z", + "start_time": "2024-03-08T04:04:40.944807Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "t = 'test-transcription/aria-amt-tests/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.mid'" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T04:47:00.700219Z", + "start_time": "2024-03-08T04:46:51.855649Z" + }, + "collapsed": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f0b22ca267334ff99664db71cf5bf507", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reload(amt.evaluate)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-08T04:46:00.278645Z", + "start_time": "2024-03-08T04:46:00.086183Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi\r\n", + "test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.wav\r\n" + ] + } + ], + "source": [ + "ls test-transcription/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/requirements.txt b/requirements.txt index ebf748f..a74c9cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ accelerate mido tqdm orjson +mir_eval \ No newline at end of file