From afe5ad25ed70648c9056ef9f538468f7aeeced90 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 9 Apr 2024 19:58:29 +0000
Subject: [PATCH] add scripts

---
 scripts/eval/dedupe.py    |  72 +++++++++++++
 scripts/eval/dtw.py       | 211 ++++++++++++++++++++++++++++++++++++++
 scripts/eval/dtw.sh       |   5 +
 scripts/eval/mir.sh       |   5 +
 scripts/eval/prune.py     |  91 ++++++++++++++++
 scripts/eval/prune.sh     |   6 ++
 scripts/eval/req-eval.txt |   4 +
 scripts/eval/split.py     |  57 ++++++++++
 scripts/eval/split.sh     |   4 +
 9 files changed, 455 insertions(+)
 create mode 100644 scripts/eval/dedupe.py
 create mode 100644 scripts/eval/dtw.py
 create mode 100644 scripts/eval/dtw.sh
 create mode 100644 scripts/eval/mir.sh
 create mode 100644 scripts/eval/prune.py
 create mode 100644 scripts/eval/prune.sh
 create mode 100644 scripts/eval/req-eval.txt
 create mode 100644 scripts/eval/split.py
 create mode 100644 scripts/eval/split.sh

diff --git a/scripts/eval/dedupe.py b/scripts/eval/dedupe.py
new file mode 100644
index 0000000..8d67d81
--- /dev/null
+++ b/scripts/eval/dedupe.py
@@ -0,0 +1,72 @@
+import os
+import hashlib
+import argparse
+import multiprocessing
+
+from pydub import AudioSegment
+
+
+def hash_audio_file(file_path):
+    """Hash the decoded audio content of an MP3 file."""
+    try:
+        audio = AudioSegment.from_mp3(file_path)
+        raw_data = audio.raw_data
+    except Exception as e:
+        print(f"failed to decode {file_path}: {e}")
+        return file_path, None
+    else:
+        return file_path, hashlib.sha256(raw_data).hexdigest()
+
+
+def find_duplicates(root_dir):
+    """Find duplicate MP3 files in the directory and its subdirectories."""
+    duplicates = []
+    mp3_paths = []
+    for root, _, files in os.walk(root_dir):
+        for file in files:
+            if file.endswith(".mp3"):
+                mp3_paths.append(os.path.join(root, file))
+
+    with multiprocessing.Pool() as pool:
+        hashes = pool.map(hash_audio_file, mp3_paths)
+
+    seen_hashes = set()
+    for path, file_hash in hashes:
+        # Skip files that failed to decode; giving them all a shared
+        # sentinel hash would mark them as duplicates of each other
+        if file_hash is None:
+            continue
+        if file_hash in seen_hashes:
+            duplicates.append(path)
+        else:
+            seen_hashes.add(file_hash)
+
+    return duplicates
+
+
+def remove_duplicates(duplicate_files):
+    """Remove the duplicate files."""
+    for file in duplicate_files:
+        os.remove(file)
+        print(f"Removed duplicate file: {file}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Remove duplicate MP3 files based on audio content."
+    )
+    parser.add_argument(
+        "dir", type=str, help="Directory to scan for duplicate MP3 files."
+    )
+    args = parser.parse_args()
+
+    duplicates = find_duplicates(args.dir)
+
+    if duplicates:
+        print(f"Found {len(duplicates)} duplicates. Removing...")
+        remove_duplicates(duplicates)
+    else:
+        print("No duplicates found.")
+
+
+if __name__ == "__main__":
+    main()
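
To audit what would be removed before running dedupe.py for real, the
duplicate list can be inspected on its own (a sketch, not part of the patch;
assumes it runs next to dedupe.py, and the directory path is illustrative):

    # List duplicates without deleting anything
    from dedupe import find_duplicates

    for path in find_duplicates("/mnt/ssd1/data/mp3/raw/aria-mp3"):
        print(path)
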
diff --git a/scripts/eval/dtw.py b/scripts/eval/dtw.py
new file mode 100644
index 0000000..41e5754
--- /dev/null
+++ b/scripts/eval/dtw.py
@@ -0,0 +1,211 @@
+# pip install git+https://github.com/alex2awesome/djitw.git
+
+import argparse
+import csv
+import librosa
+import djitw
+import pretty_midi
+import scipy.spatial
+import random
+import multiprocessing
+import os
+import functools
+import glob
+import numpy as np
+
+from multiprocessing.dummy import Pool as ThreadPool
+
+# Audio/CQT parameters
+FS = 22050.0
+NOTE_START = 36
+N_NOTES = 48
+HOP_LENGTH = 1024
+
+# DTW parameters
+GULLY = 0.96
+
+
+def compute_cqt(audio_data):
+    """Compute the log-amplitude CQT and frame times for some audio data."""
+    cqt = librosa.cqt(
+        audio_data,
+        sr=FS,
+        fmin=librosa.midi_to_hz(NOTE_START),
+        n_bins=N_NOTES,
+        hop_length=HOP_LENGTH,
+        tuning=0.0,
+    )
+    # Compute the time of each frame
+    times = librosa.frames_to_time(
+        np.arange(cqt.shape[1]), sr=FS, hop_length=HOP_LENGTH
+    )
+    # Take magnitudes before converting to dB; librosa.cqt returns complex
+    # values and amplitude_to_db warns on complex input
+    cqt = np.abs(cqt)
+    cqt = librosa.amplitude_to_db(cqt, ref=cqt.max())
+    # Normalize and return
+    return librosa.util.normalize(cqt, norm=2).T, times
+
+
+# Long tracks are scored in fixed-size chunks and the chunk scores are
+# averaged, to keep the DTW cost matrix manageable on CPU
+def load_and_run_dtw(args):
+    def calc_score(_midi_cqt, _audio_cqt):
+        # Nearly all high-performing systems used cosine distance
+        distance_matrix = scipy.spatial.distance.cdist(
+            _midi_cqt, _audio_cqt, "cosine"
+        )
+
+        # Get lowest-cost path
+        p, q, score = djitw.dtw(
+            distance_matrix,
+            GULLY,  # The gully for all high-performing systems was near 1
+            np.median(
+                distance_matrix
+            ),  # The penalty was also near 1.0*median(distance_matrix)
+            inplace=False,
+        )
+        # Normalize by path length, then by the mean of the distance-matrix
+        # submatrix spanned by the path
+        score = score / len(p)
+        score = (
+            score / distance_matrix[p.min() : p.max(), q.min() : q.max()].mean()
+        )
+
+        return score
+
+    audio_file, midi_file = args
+    # Load in the audio data
+    audio_data, _ = librosa.load(audio_file, sr=FS)
+    audio_cqt, audio_times = compute_cqt(audio_data)
+
+    # Synthesize the MIDI and compute its CQT for comparison
+    midi_object = pretty_midi.PrettyMIDI(midi_file)
+    midi_audio = midi_object.fluidsynth(fs=FS)
+    midi_cqt, midi_times = compute_cqt(midi_audio)
+
+    # Chunk to save on compute time for long tracks
+    MAX_LEN = 10000
+    total_len = midi_cqt.shape[0]
+    if total_len > MAX_LEN:
+        idx = 0
+        scores = []
+        while idx < total_len:
+            scores.append(
+                calc_score(
+                    _midi_cqt=midi_cqt[idx : idx + MAX_LEN, :],
+                    _audio_cqt=audio_cqt[idx : idx + MAX_LEN, :],
+                )
+            )
+            idx += MAX_LEN
+
+        max_score = max(scores)
+        avg_score = sum(scores) / len(scores)
+
+    else:
+        avg_score = calc_score(_midi_cqt=midi_cqt, _audio_cqt=audio_cqt)
+        max_score = avg_score
+
+    return midi_file, avg_score, max_score
+
+
+def get_matched_files(audio_dir: str, mid_dir: str):
+    # We assume that the files have the same path relative to their directory
+    res = []
+    mp3_paths = glob.glob(os.path.join(audio_dir, "**/*.mp3"), recursive=True)
+    print(f"found {len(mp3_paths)} mp3 files")
+
+    for mp3_path in mp3_paths:
+        input_rel_path = os.path.relpath(mp3_path, audio_dir)
+        mid_path = os.path.join(
+            mid_dir, os.path.splitext(input_rel_path)[0] + ".mid"
+        )
+        if os.path.isfile(mid_path):
+            res.append((mp3_path, mid_path))
+
+    print(f"found {len(res)} matched mp3-midi pairs")
+
+    return res
+
+
+def abortable_worker(func, *args, **kwargs):
+    """Run func in a worker thread, giving up after `timeout` seconds."""
+    timeout = kwargs.get("timeout", None)
+    p = ThreadPool(1)
+    res = p.apply_async(func, args=args)
+    try:
+        return res.get(timeout)
+    except multiprocessing.TimeoutError:
+        return None, None, None
+    except Exception as e:
+        print(e)
+        return None, None, None
+    finally:
+        p.close()
+        p.join()
+
+
+if __name__ == "__main__":
+    multiprocessing.set_start_method("fork")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-audio_dir", help="dir containing .mp3 files")
+    parser.add_argument(
+        "-mid_dir", help="dir containing .mid files", default=None
+    )
+    parser.add_argument(
+        "-output_file", help="path to the results csv", default=None
+    )
+    args = parser.parse_args()
+
+    matched_files = get_matched_files(
+        audio_dir=args.audio_dir, mid_dir=args.mid_dir
+    )
+
+    # Resume from an existing results file if present
+    results = {}
+    output_exists = os.path.exists(args.output_file)
+    if output_exists:
+        with open(args.output_file, "r") as f:
+            reader = csv.DictReader(f)
+            for row in reader:
+                results[row["mid_path"]] = {
+                    "avg_score": row["avg_score"],
+                    "max_score": row["max_score"],
+                }
+
+    matched_files = [
+        (audio_path, mid_path)
+        for audio_path, mid_path in matched_files
+        if mid_path not in results
+    ]
+    random.shuffle(matched_files)
+    print(f"loaded {len(results)} results")
+    print(f"calculating scores for {len(matched_files)}")
+
+    score_csv = open(args.output_file, "a")
+    csv_writer = csv.writer(score_csv)
+    if not output_exists:
+        # Only write the header on a fresh file; appending a second header
+        # row on resume would corrupt the csv
+        csv_writer.writerow(["mid_path", "avg_score", "max_score"])
+
+    with multiprocessing.Pool() as pool:
+        abortable_func = functools.partial(
+            abortable_worker, load_and_run_dtw, timeout=15000
+        )
+        scores = pool.imap_unordered(abortable_func, matched_files)
+
+        skipped = 0
+        processed = 0
+        for mid_path, avg_score, max_score in scores:
+            if avg_score is not None and max_score is not None:
+                csv_writer.writerow([mid_path, avg_score, max_score])
+                score_csv.flush()
+            else:
+                print("timeout or error")
+                skipped += 1
+
+            processed += 1
+            if processed % 10 == 0:
+                print(f"PROCESSED: {processed}/{len(matched_files)}")
+                print("***")
+
+    print(f"skipped: {skipped}")
+    score_csv.close()
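
The djitw.dtw call in calc_score is the core of the score; a minimal sketch
of the same call on a synthetic distance matrix (sizes and values here are
illustrative, not part of the patch):

    import numpy as np
    import djitw

    rng = np.random.default_rng(0)
    distance_matrix = rng.random((100, 120))
    p, q, score = djitw.dtw(
        distance_matrix,
        0.96,                        # gully, as GULLY above
        np.median(distance_matrix),  # penalty, as in calc_score
        inplace=False,
    )
    print(score / len(p))            # path-length-normalized cost
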
diff --git a/scripts/eval/dtw.sh b/scripts/eval/dtw.sh
new file mode 100644
index 0000000..1d25b71
--- /dev/null
+++ b/scripts/eval/dtw.sh
@@ -0,0 +1,5 @@
+python /home/loubb/work/aria-amt/scripts/eval/dtw.py \
+    -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \
+    -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid \
+    -output_file /mnt/ssd1/amt/transcribed_data/0/aria-mid.csv
+
diff --git a/scripts/eval/mir.sh b/scripts/eval/mir.sh
new file mode 100644
index 0000000..0364d4c
--- /dev/null
+++ b/scripts/eval/mir.sh
@@ -0,0 +1,5 @@
+python /home/loubb/work/aria-amt/amt/evaluate.py \
+    --est-dir /home/loubb/work/aria-amt/maestro-ft \
+    --ref-dir /mnt/ssd1/data/mp3/raw/maestro-mp3 \
+    --output-stats-file out.json
+
\ No newline at end of file
diff --git a/scripts/eval/prune.py b/scripts/eval/prune.py
new file mode 100644
index 0000000..1b8f919
--- /dev/null
+++ b/scripts/eval/prune.py
@@ -0,0 +1,91 @@
+import argparse
+import csv
+import os
+import shutil
+
+
+# Calculate percentiles without using numpy (linear interpolation, the same
+# rule as numpy.percentile's default)
+def calculate_percentiles(data, percentiles):
+    data_sorted = sorted(data)
+    n = len(data_sorted)
+    results = []
+    for percentile in percentiles:
+        k = (n - 1) * percentile / 100
+        f = int(k)
+        c = k - f
+        if f + 1 < n:
+            result = data_sorted[f] + c * (data_sorted[f + 1] - data_sorted[f])
+        else:
+            result = data_sorted[f]
+        results.append(result)
+    return results
+
+
+def main(mid_dir, output_dir, score_file, max_score, dry):
+    if not os.path.isdir(output_dir):
+        os.makedirs(output_dir)
+
+    scores = {}
+    failures = 0
+    with open(score_file, "r") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            try:
+                if 0.0 < float(row["avg_score"]) < 1.0:
+                    scores[row["mid_path"]] = float(row["avg_score"])
+            except (KeyError, ValueError):
+                failures += 1
+
+    print(f"{failures} failures")
+    print(f"found {len(scores)} mid-score pairs")
+
+    print("top 50 by score:")
+    for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)[
+        :50
+    ]:
+        print(f"{v}: {k}")
+    print("bottom 50 by score:")
+    for k, v in sorted(scores.items(), key=lambda item: item[1])[:50]:
+        print(f"{v}: {k}")
+
+    # Define the percentiles to calculate
+    percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90]
+    floats = list(scores.values())
+
+    # Calculate the percentiles
+    print(f"percentiles: {calculate_percentiles(floats, percentiles)}")
+
+    cnt = 0
+    for mid_path, score in scores.items():
+        mid_rel_path = os.path.relpath(mid_path, mid_dir)
+        output_path = os.path.join(output_dir, mid_rel_path)
+        if not os.path.exists(os.path.dirname(output_path)):
+            os.makedirs(os.path.dirname(output_path))
+
+        # Copy files scoring under the threshold; count the rest as excluded
+        if score < max_score:
+            if not dry:
+                shutil.copyfile(mid_path, output_path)
+        else:
+            cnt += 1
+
+    print(f"excluded {cnt}/{len(scores)} files")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-mid_dir", help="dir containing .mid files", default=None
+    )
+    parser.add_argument(
+        "-output_dir", help="dir to copy retained .mid files to", default=None
+    )
+    parser.add_argument(
+        "-score_file", help="path to the score csv", default=None
+    )
+    parser.add_argument(
+        "-max_score",
+        type=float,
+        help="exclude files with score >= max_score",
+        default=None,
+    )
+    parser.add_argument(
+        "-dry", action="store_true", help="report counts without copying"
+    )
+    args = parser.parse_args()
+
+    main(
+        args.mid_dir, args.output_dir, args.score_file, args.max_score, args.dry
+    )
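
Since calculate_percentiles applies the same linear-interpolation rule as
numpy.percentile's default, it can be spot-checked against numpy where numpy
is available (a sanity-check sketch, not part of the patch):

    import numpy as np
    from prune import calculate_percentiles

    data = [0.31, 0.44, 0.12, 0.57, 0.25]
    ps = [10, 50, 90]
    assert np.allclose(
        calculate_percentiles(data, ps), np.percentile(data, ps)
    )
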
diff --git a/scripts/eval/prune.sh b/scripts/eval/prune.sh
new file mode 100644
index 0000000..9b3f89a
--- /dev/null
+++ b/scripts/eval/prune.sh
@@ -0,0 +1,6 @@
+python /home/loubb/work/aria-amt/scripts/eval/prune.py \
+    -mid_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid \
+    -output_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid-pruned \
+    -score_file /mnt/ssd1/amt/transcribed_data/0/pijama-mid.csv \
+    -max_score 0.42
+# -dry
\ No newline at end of file
diff --git a/scripts/eval/req-eval.txt b/scripts/eval/req-eval.txt
new file mode 100644
index 0000000..d3ed177
--- /dev/null
+++ b/scripts/eval/req-eval.txt
@@ -0,0 +1,4 @@
+djitw @ git+https://github.com/alex2awesome/djitw.git
+librosa
+pretty_midi
+pyfluidsynth
\ No newline at end of file
diff --git a/scripts/eval/split.py b/scripts/eval/split.py
new file mode 100644
index 0000000..c912cbc
--- /dev/null
+++ b/scripts/eval/split.py
@@ -0,0 +1,57 @@
+import csv
+import random
+import glob
+import argparse
+import os
+
+
+def get_matched_paths(audio_dir: str, mid_dir: str):
+    # Assume that the files have the same path relative to their directory
+    res = []
+    mid_paths = glob.glob(os.path.join(mid_dir, "**/*.mid"), recursive=True)
+    print(f"found {len(mid_paths)} mid files")
+
+    audio_dir_last = os.path.basename(audio_dir)
+    mid_dir_last = os.path.basename(mid_dir)
+
+    for mid_path in mid_paths:
+        input_rel_path = os.path.relpath(mid_path, mid_dir)
+        mp3_rel_path = os.path.splitext(input_rel_path)[0] + ".mp3"
+        mp3_path = os.path.join(audio_dir, mp3_rel_path)
+
+        # Check if the corresponding .mp3 file exists
+        if os.path.isfile(mp3_path):
+            # Store paths relative to the parent of each directory
+            matched_mid_path = os.path.join(mid_dir_last, input_rel_path)
+            matched_mp3_path = os.path.join(audio_dir_last, mp3_rel_path)
+
+            res.append((matched_mp3_path, matched_mid_path))
+
+    print(f"found {len(res)} matched mp3-midi pairs")
+    assert len(mid_paths) == len(res), "audio files missing"
+
+    return res
+
+
+def create_csv(matched_paths, csv_path):
+    with open(csv_path, "w") as split_csv:
+        csv_writer = csv.writer(split_csv)
+        csv_writer.writerow(["mid_path", "audio_path", "split"])
+
+        # Hold out roughly 10% of the pairs as the test split
+        for audio_path, mid_path in matched_paths:
+            if random.random() < 0.1:
+                csv_writer.writerow([mid_path, audio_path, "test"])
+            else:
+                csv_writer.writerow([mid_path, audio_path, "train"])
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-mid_dir", type=str)
+    parser.add_argument("-audio_dir", type=str)
+    parser.add_argument("-csv_path", type=str)
+    args = parser.parse_args()
+
+    matched_paths = get_matched_paths(args.audio_dir, args.mid_dir)
+    create_csv(matched_paths, args.csv_path)
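
create_csv draws fresh randomness on every run, so the train/test assignment
changes between invocations; seeding the module RNG first makes it
reproducible (a sketch under that assumption, not part of the patch):

    import random
    from split import get_matched_paths, create_csv

    random.seed(0)  # fix the ~10% test assignment across runs
    matched = get_matched_paths(
        "/mnt/ssd1/data/mp3/raw/aria-mp3",
        "/mnt/ssd1/amt/transcribed_data/0/aria-mid-pruned",
    )
    create_csv(matched, "aria-pruned-split.csv")
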
diff --git a/scripts/eval/split.sh b/scripts/eval/split.sh
new file mode 100644
index 0000000..05faece
--- /dev/null
+++ b/scripts/eval/split.sh
@@ -0,0 +1,4 @@
+python /home/loubb/work/aria-amt/scripts/eval/split.py \
+    -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid-pruned \
+    -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \
+    -csv_path aria-pruned-split.csv