Commit
add scripts
loubbrad committed Apr 9, 2024
1 parent b68bda3 commit afe5ad2
Showing 9 changed files with 455 additions and 0 deletions.
72 changes: 72 additions & 0 deletions scripts/eval/dedupe.py
@@ -0,0 +1,72 @@
import os
import hashlib
import argparse
import multiprocessing

from pydub import AudioSegment
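# Note: pydub requires ffmpeg (or libav) on the system to decode MP3 files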


def hash_audio_file(file_path):
"""Hash the audio content of an MP3 file."""
try:
audio = AudioSegment.from_mp3(file_path)
raw_data = audio.raw_data
    except Exception as e:
        print(f"failed to decode {file_path}: {e}")
return file_path, -1
else:
return file_path, hashlib.sha256(raw_data).hexdigest()


def find_duplicates(root_dir):
"""Find and remove duplicate MP3 files in the directory and its subdirectories."""
duplicates = []
mp3_paths = []
for root, _, files in os.walk(root_dir):
for file in files:
if file.endswith(".mp3"):
mp3_paths.append(os.path.join(root, file))

with multiprocessing.Pool() as pool:
hashes = pool.map(hash_audio_file, mp3_paths)

    seen_hashes = set()
    for path, file_hash in hashes:
        if file_hash == -1:
            # Don't treat decode failures as duplicates of each other
            continue
        if file_hash in seen_hashes:
            duplicates.append(path)
        else:
            seen_hashes.add(file_hash)

return duplicates


def remove_duplicates(duplicate_files):
"""Remove the duplicate files."""
for file in duplicate_files:
os.remove(file)
print(f"Removed duplicate file: {file}")


def main():
parser = argparse.ArgumentParser(
description="Remove duplicate MP3 files based on audio content."
)
parser.add_argument(
"dir", type=str, help="Directory to scan for duplicate MP3 files."
)
args = parser.parse_args()

root_directory = args.dir
duplicates = find_duplicates(root_directory)

if duplicates:
print(f"Found {len(duplicates)} duplicates. Removing...")
remove_duplicates(duplicates)
else:
print("No duplicates found.")


if __name__ == "__main__":
main()
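Example invocation (the directory path below is illustrative):

python scripts/eval/dedupe.py /path/to/mp3-dir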
211 changes: 211 additions & 0 deletions scripts/eval/dtw.py
@@ -0,0 +1,211 @@
# pip install git+https://github.com/alex2awesome/djitw.git

import argparse
import csv
import librosa
import djitw
import pretty_midi
import scipy.spatial.distance  # scipy.spatial is not loaded by a bare "import scipy"
import random
import multiprocessing
import os
import warnings
import functools
import glob
import numpy as np

from multiprocessing.dummy import Pool as ThreadPool

# Audio/CQT parameters
FS = 22050.0
NOTE_START = 36
N_NOTES = 48
HOP_LENGTH = 1024

# DTW parameters
GULLY = 0.96
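# The gully bounds what fraction of each sequence the lowest-cost path must
# span; values near 1 permit only small skips at the ends of either sequence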


def compute_cqt(audio_data):
"""Compute the CQT and frame times for some audio data"""
# Compute CQT
cqt = librosa.cqt(
audio_data,
sr=FS,
fmin=librosa.midi_to_hz(NOTE_START),
n_bins=N_NOTES,
hop_length=HOP_LENGTH,
tuning=0.0,
)
# Compute the time of each frame
times = librosa.frames_to_time(
np.arange(cqt.shape[1]), sr=FS, hop_length=HOP_LENGTH
)
# Compute log-amplitude
cqt = librosa.amplitude_to_db(cqt, ref=cqt.max())
# Normalize and return
return librosa.util.normalize(cqt, norm=2).T, times


# For long tracks, the CQT is split into fixed-size chunks and the per-chunk
# scores are combined, to keep DTW compute and memory costs manageable
def load_and_run_dtw(args):
def calc_score(_midi_cqt, _audio_cqt):
# Nearly all high-performing systems used cosine distance
distance_matrix = scipy.spatial.distance.cdist(
_midi_cqt, _audio_cqt, "cosine"
)

# Get lowest cost path
p, q, score = djitw.dtw(
distance_matrix,
GULLY, # The gully for all high-performing systems was near 1
np.median(
distance_matrix
), # The penalty was also near 1.0*median(distance_matrix)
inplace=False,
)
        # Normalize by path length and by the mean distance within the
        # submatrix spanned by the path; lower scores indicate a closer match
score = score / len(p)
score = (
score / distance_matrix[p.min() : p.max(), q.min() : q.max()].mean()
)

return score

audio_file, midi_file = args
# Load in the audio data
audio_data, _ = librosa.load(audio_file, sr=FS)
audio_cqt, audio_times = compute_cqt(audio_data)

midi_object = pretty_midi.PrettyMIDI(midi_file)
midi_audio = midi_object.fluidsynth(fs=FS)
midi_cqt, midi_times = compute_cqt(midi_audio)

# Truncate to save on compute time for long tracks
MAX_LEN = 10000
total_len = midi_cqt.shape[0]
if total_len > MAX_LEN:
idx = 0
scores = []
while idx < total_len:
scores.append(
calc_score(
_midi_cqt=midi_cqt[idx : idx + MAX_LEN, :],
_audio_cqt=audio_cqt[idx : idx + MAX_LEN, :],
)
)
idx += MAX_LEN

max_score = max(scores)
avg_score = sum(scores) / len(scores) if scores else 1.0

else:
avg_score = calc_score(_midi_cqt=midi_cqt, _audio_cqt=audio_cqt)
max_score = avg_score

return midi_file, avg_score, max_score


# Note: this matches .mp3 audio files (it previously matched .wav)
def get_matched_files(audio_dir: str, mid_dir: str):
    # We assume audio and MIDI files share the same path relative to their directory
    res = []
    mp3_paths = glob.glob(os.path.join(audio_dir, "**/*.mp3"), recursive=True)
    print(f"found {len(mp3_paths)} mp3 files")

    for mp3_path in mp3_paths:
        input_rel_path = os.path.relpath(mp3_path, audio_dir)
        mid_path = os.path.join(
            mid_dir, os.path.splitext(input_rel_path)[0] + ".mid"
        )
        if os.path.isfile(mid_path):
            res.append((mp3_path, mid_path))

print(f"found {len(res)} matched mp3-midi pairs")

return res


def abortable_worker(func, *args, **kwargs):
    """Run func in a one-thread pool so a hung call can be abandoned on timeout."""
    timeout = kwargs.get("timeout", None)
p = ThreadPool(1)
res = p.apply_async(func, args=args)
try:
out = res.get(timeout)
return out
except multiprocessing.TimeoutError:
return None, None, None
except Exception as e:
print(e)
return None, None, None
finally:
p.close()
p.join()


if __name__ == "__main__":
multiprocessing.set_start_method("fork")
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="amplitude_to_db was called on complex input",
)
parser = argparse.ArgumentParser()
parser.add_argument("-audio_dir", help="dir containing .wav files")
parser.add_argument(
"-mid_dir", help="dir containing .mid files", default=None
)
parser.add_argument(
"-output_file", help="path to output file", default=None
)
args = parser.parse_args()

matched_files = get_matched_files(
audio_dir=args.audio_dir, mid_dir=args.mid_dir
)

results = {}
    if args.output_file and os.path.exists(args.output_file):
with open(args.output_file, "r") as f:
reader = csv.DictReader(f)
for row in reader:
results[row["mid_path"]] = {
"avg_score": row["avg_score"],
"max_score": row["max_score"],
}

matched_files = [
(audio_path, mid_path)
for audio_path, mid_path in matched_files
if mid_path not in results.keys()
]
random.shuffle(matched_files)
print(f"loaded {len(results)} results")
print(f"calculating scores for {len(matched_files)}")

    score_csv = open(args.output_file, "a")
    csv_writer = csv.writer(score_csv)
    if len(results) == 0:
        # Only write the csv header when starting a fresh results file
        csv_writer.writerow(["mid_path", "avg_score", "max_score"])

with multiprocessing.Pool() as pool:
abortable_func = functools.partial(
abortable_worker, load_and_run_dtw, timeout=15000
)
scores = pool.imap_unordered(abortable_func, matched_files)

skipped = 0
processed = 0
for mid_path, avg_score, max_score in scores:
if avg_score is not None and max_score is not None:
csv_writer.writerow([mid_path, avg_score, max_score])
score_csv.flush()
else:
print(f"timeout")
skipped += 1

processed += 1
if processed % 10 == 0:
print(f"PROCESSED: {processed}/{len(matched_files)}")
print(f"***")

print(f"skipped: {skipped}")
5 changes: 5 additions & 0 deletions scripts/eval/dtw.sh
@@ -0,0 +1,5 @@
python /home/loubb/work/aria-amt/scripts/eval/dtw.py \
-audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \
-mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid \
-output_file /mnt/ssd1/amt/transcribed_data/0/aria-mid.csv

5 changes: 5 additions & 0 deletions scripts/eval/mir.sh
@@ -0,0 +1,5 @@
python /home/loubb/work/aria-amt/amt/evaluate.py \
--est-dir /home/loubb/work/aria-amt/maestro-ft \
--ref-dir /mnt/ssd1/data/mp3/raw/maestro-mp3 \
--output-stats-file out.json

91 changes: 91 additions & 0 deletions scripts/eval/prune.py
@@ -0,0 +1,91 @@
import argparse
import csv
import os
import shutil


# Calculate percentiles without using numpy
def calculate_percentiles(data, percentiles):
data_sorted = sorted(data)
n = len(data_sorted)
results = []
for percentile in percentiles:
k = (n - 1) * percentile / 100
f = int(k)
c = k - f
if f + 1 < n:
result = data_sorted[f] + c * (data_sorted[f + 1] - data_sorted[f])
else:
result = data_sorted[f]
results.append(result)
return results
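

# Worked example (illustrative): for data [10, 20, 30, 40] and percentile 50,
# k = (4 - 1) * 50 / 100 = 1.5, f = 1, c = 0.5, so the result is
# 20 + 0.5 * (30 - 20) = 25.0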


def main(mid_dir, output_dir, score_file, max_score, dry):
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    scores = {}
    with open(score_file, "r") as f:
        reader = csv.DictReader(f)
        failures = 0
        for row in reader:
            try:
                if 0.0 < float(row["avg_score"]) < 1.0:
                    scores[row["mid_path"]] = float(row["avg_score"])
            except Exception:
                failures += 1

print(f"{failures} failures")
print(f"found {len(scores.items())} mid-score pairs")

print("top 50 by score:")
for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)[
:50
]:
print(f"{v}: {k}")
print("bottom 50 by score:")
for k, v in sorted(scores.items(), key=lambda item: item[1])[:50]:
print(f"{v}: {k}")

# Define the percentiles to calculate
percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90]
floats = [v for k, v in scores.items()]

# Calculate the percentiles
print(f"percentiles: {calculate_percentiles(floats, percentiles)}")

cnt = 0
for mid_path, score in scores.items():
mid_rel_path = os.path.relpath(mid_path, mid_dir)
output_path = os.path.join(output_dir, mid_rel_path)
if not os.path.exists(os.path.dirname(output_path)):
os.makedirs(os.path.dirname(output_path))

if score < max_score:
            if not dry:
shutil.copyfile(mid_path, output_path)
else:
cnt += 1

print(f"excluded {cnt}/{len(scores.items())} files")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument(
        "-mid_dir", help="dir containing .mid files", default=None
    )
    parser.add_argument(
        "-output_dir", help="dir to copy retained .mid files into", default=None
    )
    parser.add_argument(
        "-score_file", help="path to csv of scores", default=None
    )
    parser.add_argument(
        "-max_score",
        type=float,
        help="exclude files with avg_score >= this value",
        default=None,
    )
    parser.add_argument(
        "-dry", action="store_true", help="dry run: don't copy any files"
    )
args = parser.parse_args()

main(
args.mid_dir, args.output_dir, args.score_file, args.max_score, args.dry
)
6 changes: 6 additions & 0 deletions scripts/eval/prune.sh
@@ -0,0 +1,6 @@
python /home/loubb/work/aria-amt/scripts/eval/prune.py \
-mid_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid \
-output_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid-pruned \
-score_file /mnt/ssd1/amt/transcribed_data/0/pijama-mid.csv \
-max_score 0.42 \
# -dry