Showing 9 changed files with 455 additions and 0 deletions.
@@ -0,0 +1,72 @@
import os
import hashlib
import argparse
import multiprocessing

from pydub import AudioSegment


def hash_audio_file(file_path):
    """Hash the decoded audio content of an MP3 file."""
    try:
        audio = AudioSegment.from_mp3(file_path)
        raw_data = audio.raw_data
    except Exception as e:
        print(e)
        return file_path, None
    else:
        return file_path, hashlib.sha256(raw_data).hexdigest()


def find_duplicates(root_dir):
    """Find duplicate MP3 files in the directory and its subdirectories."""
    duplicates = []
    mp3_paths = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith(".mp3"):
                mp3_paths.append(os.path.join(root, file))

    # Hash in parallel; decoding the MP3s is the expensive step
    with multiprocessing.Pool() as pool:
        hashes = pool.map(hash_audio_file, mp3_paths)

    seen_hash = set()
    for p, h in hashes:
        if h is None:
            # Skip files that failed to decode rather than treating
            # them all as duplicates of one another
            continue
        if h in seen_hash:
            print("Seen dupe")
            duplicates.append(p)
        else:
            print("Seen orig")
            seen_hash.add(h)

    return duplicates


def remove_duplicates(duplicate_files):
    """Remove the duplicate files."""
    for file in duplicate_files:
        os.remove(file)
        print(f"Removed duplicate file: {file}")


def main():
    parser = argparse.ArgumentParser(
        description="Remove duplicate MP3 files based on audio content."
    )
    parser.add_argument(
        "dir", type=str, help="Directory to scan for duplicate MP3 files."
    )
    args = parser.parse_args()

    root_directory = args.dir
    duplicates = find_duplicates(root_directory)

    if duplicates:
        print(f"Found {len(duplicates)} duplicates. Removing...")
        remove_duplicates(duplicates)
    else:
        print("No duplicates found.")


if __name__ == "__main__":
    main()
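A minimal invocation sketch, in the style of the run scripts later in this commit; the saved filename and the music directory are hypothetical:

python dedupe_mp3.py /path/to/mp3-library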
@@ -0,0 +1,211 @@
# pip install git+https://github.com/alex2awesome/djitw.git

import argparse
import csv
import librosa
import djitw
import pretty_midi
import scipy
import random
import multiprocessing
import os
import warnings
import functools
import glob
import numpy as np

from multiprocessing.dummy import Pool as ThreadPool

# Audio/CQT parameters
FS = 22050.0
NOTE_START = 36
N_NOTES = 48
HOP_LENGTH = 1024

# DTW parameters
GULLY = 0.96
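# NOTE_START = 36 is MIDI C2; with librosa's default 12 bins per octave,
# N_NOTES = 48 spans the four octaves starting at C2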

def compute_cqt(audio_data):
    """Compute the CQT and frame times for some audio data"""
    # Compute CQT
    cqt = librosa.cqt(
        audio_data,
        sr=FS,
        fmin=librosa.midi_to_hz(NOTE_START),
        n_bins=N_NOTES,
        hop_length=HOP_LENGTH,
        tuning=0.0,
    )
    # Compute the time of each frame
    times = librosa.frames_to_time(
        np.arange(cqt.shape[1]), sr=FS, hop_length=HOP_LENGTH
    )
    # Compute log-amplitude, taking the magnitude explicitly so
    # amplitude_to_db sees a real-valued spectrogram
    cqt = np.abs(cqt)
    cqt = librosa.amplitude_to_db(cqt, ref=cqt.max())
    # Normalize each frame to unit L2 norm and return (frames x bins)
    return librosa.util.normalize(cqt, norm=2).T, times

# Long files are scored in fixed-size chunks and the chunk scores
# averaged, to keep CPU and memory use bounded
def load_and_run_dtw(args):
    def calc_score(_midi_cqt, _audio_cqt):
        # Nearly all high-performing systems used cosine distance
        distance_matrix = scipy.spatial.distance.cdist(
            _midi_cqt, _audio_cqt, "cosine"
        )

        # Get lowest cost path
        p, q, score = djitw.dtw(
            distance_matrix,
            GULLY,  # The gully for all high-performing systems was near 1
            np.median(
                distance_matrix
            ),  # The penalty was also near 1.0*median(distance_matrix)
            inplace=False,
        )
        # Normalize by path length, then by the mean of the distance-matrix
        # submatrix spanned by the path; lower scores mean better alignment
        score = score / len(p)
        score = (
            score / distance_matrix[p.min() : p.max(), q.min() : q.max()].mean()
        )

        return score

    audio_file, midi_file = args
    # Load in the audio data
    audio_data, _ = librosa.load(audio_file, sr=FS)
    audio_cqt, audio_times = compute_cqt(audio_data)

    # Synthesize the MIDI and compute the CQT of the rendered audio
    midi_object = pretty_midi.PrettyMIDI(midi_file)
    midi_audio = midi_object.fluidsynth(fs=FS)
    midi_cqt, midi_times = compute_cqt(midi_audio)

    # Chunk long tracks to save on compute time
    MAX_LEN = 10000
    total_len = midi_cqt.shape[0]
    if total_len > MAX_LEN:
        idx = 0
        scores = []
        while idx < total_len:
            scores.append(
                calc_score(
                    _midi_cqt=midi_cqt[idx : idx + MAX_LEN, :],
                    _audio_cqt=audio_cqt[idx : idx + MAX_LEN, :],
                )
            )
            idx += MAX_LEN

        max_score = max(scores)
        avg_score = sum(scores) / len(scores) if scores else 1.0
    else:
        avg_score = calc_score(_midi_cqt=midi_cqt, _audio_cqt=audio_cqt)
        max_score = avg_score

    return midi_file, avg_score, max_score

# Note: this matches .mp3 audio (originally written for .wav)
def get_matched_files(audio_dir: str, mid_dir: str):
    # We assume that the files have the same path relative to their directory
    res = []
    mp3_paths = glob.glob(os.path.join(audio_dir, "**/*.mp3"), recursive=True)
    print(f"found {len(mp3_paths)} mp3 files")

    for mp3_path in mp3_paths:
        input_rel_path = os.path.relpath(mp3_path, audio_dir)
        mid_path = os.path.join(
            mid_dir, os.path.splitext(input_rel_path)[0] + ".mid"
        )
        if os.path.isfile(mid_path):
            res.append((mp3_path, mid_path))

    print(f"found {len(res)} matched mp3-midi pairs")

    return res

def abortable_worker(func, *args, **kwargs):
    # Run func in a single-thread pool so a per-task timeout can be enforced
    timeout = kwargs.get("timeout", None)
    p = ThreadPool(1)
    res = p.apply_async(func, args=args)
    try:
        out = res.get(timeout)
        return out
    except multiprocessing.TimeoutError:
        return None, None, None
    except Exception as e:
        print(e)
        return None, None, None
    finally:
        p.close()
        p.join()

if __name__ == "__main__":
    multiprocessing.set_start_method("fork")
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="amplitude_to_db was called on complex input",
    )
    parser = argparse.ArgumentParser()
    parser.add_argument("-audio_dir", help="dir containing .mp3 files")
    parser.add_argument(
        "-mid_dir", help="dir containing .mid files", default=None
    )
    parser.add_argument(
        "-output_file", help="path to output csv file", default=None
    )
    args = parser.parse_args()

    matched_files = get_matched_files(
        audio_dir=args.audio_dir, mid_dir=args.mid_dir
    )

    # Resume support: skip pairs already scored in the output csv
    results = {}
    write_header = True
    if os.path.exists(args.output_file):
        write_header = False
        with open(args.output_file, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                results[row["mid_path"]] = {
                    "avg_score": row["avg_score"],
                    "max_score": row["max_score"],
                }

    matched_files = [
        (audio_path, mid_path)
        for audio_path, mid_path in matched_files
        if mid_path not in results
    ]
    random.shuffle(matched_files)
    print(f"loaded {len(results)} results")
    print(f"calculating scores for {len(matched_files)} pairs")

    score_csv = open(args.output_file, "a")
    csv_writer = csv.writer(score_csv)
    if write_header:
        # Only write the header for a fresh file; appending it again on
        # resume would insert a bogus data row
        csv_writer.writerow(["mid_path", "avg_score", "max_score"])

    with multiprocessing.Pool() as pool:
        abortable_func = functools.partial(
            abortable_worker, load_and_run_dtw, timeout=15000
        )
        scores = pool.imap_unordered(abortable_func, matched_files)

        skipped = 0
        processed = 0
        for mid_path, avg_score, max_score in scores:
            if avg_score is not None and max_score is not None:
                csv_writer.writerow([mid_path, avg_score, max_score])
                score_csv.flush()
            else:
                print("timeout or error")
                skipped += 1

            processed += 1
            if processed % 10 == 0:
                print(f"PROCESSED: {processed}/{len(matched_files)}")
                print("***")

    score_csv.close()
    print(f"skipped: {skipped}")
@@ -0,0 +1,5 @@
python /home/loubb/work/aria-amt/scripts/eval/dtw.py \
    -audio_dir /mnt/ssd1/data/mp3/raw/aria-mp3 \
    -mid_dir /mnt/ssd1/amt/transcribed_data/0/aria-mid \
    -output_file /mnt/ssd1/amt/transcribed_data/0/aria-mid.csv
@@ -0,0 +1,5 @@
python /home/loubb/work/aria-amt/amt/evaluate.py \
    --est-dir /home/loubb/work/aria-amt/maestro-ft \
    --ref-dir /mnt/ssd1/data/mp3/raw/maestro-mp3 \
    --output-stats-file out.json
@@ -0,0 +1,91 @@
import argparse
import csv
import os
import shutil


# Calculate percentiles without using numpy
def calculate_percentiles(data, percentiles):
    data_sorted = sorted(data)
    n = len(data_sorted)
    results = []
    for percentile in percentiles:
        # Linear interpolation between the two nearest ranks
        k = (n - 1) * percentile / 100
        f = int(k)
        c = k - f
        if f + 1 < n:
            result = data_sorted[f] + c * (data_sorted[f + 1] - data_sorted[f])
        else:
            result = data_sorted[f]
        results.append(result)
    return results
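# A worked example of the interpolation above, on hypothetical data:
# for data = [1, 2, 3, 4] and the 50th percentile, k = (4 - 1) * 50 / 100 = 1.5,
# so f = 1, c = 0.5, and the result is 2 + 0.5 * (3 - 2) = 2.5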

def main(mid_dir, output_dir, score_file, max_score, dry):
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    scores = {}
    with open(score_file, "r") as f:
        reader = csv.DictReader(f)
        failures = 0
        for row in reader:
            try:
                if 0.0 < float(row["avg_score"]) < 1.0:
                    scores[row["mid_path"]] = float(row["avg_score"])
            except Exception:
                failures += 1

    print(f"{failures} failures")
    print(f"found {len(scores)} mid-score pairs")

    print("top 50 by score:")
    for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)[
        :50
    ]:
        print(f"{v}: {k}")
    print("bottom 50 by score:")
    for k, v in sorted(scores.items(), key=lambda item: item[1])[:50]:
        print(f"{v}: {k}")

    # Define the percentiles to calculate
    percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    floats = [v for k, v in scores.items()]

    # Calculate the percentiles
    print(f"percentiles: {calculate_percentiles(floats, percentiles)}")

    # Copy files scoring under the threshold (lower DTW score = better
    # alignment); count the rest as excluded
    cnt = 0
    for mid_path, score in scores.items():
        mid_rel_path = os.path.relpath(mid_path, mid_dir)
        output_path = os.path.join(output_dir, mid_rel_path)
        if not os.path.exists(os.path.dirname(output_path)):
            os.makedirs(os.path.dirname(output_path))

        if score < max_score:
            if dry is not True:
                shutil.copyfile(mid_path, output_path)
        else:
            cnt += 1

    print(f"excluded {cnt}/{len(scores)} files")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-mid_dir", help="dir containing .mid files", default=None
    )
    parser.add_argument(
        "-output_dir", help="dir to copy the retained .mid files into", default=None
    )
    parser.add_argument(
        "-score_file", help="path to the csv of dtw scores", default=None
    )
    parser.add_argument(
        "-max_score",
        type=float,
        help="exclude files with avg_score at or above this value",
        default=None,
    )
    parser.add_argument(
        "-dry", action="store_true", help="report without copying files"
    )
    args = parser.parse_args()

    main(
        args.mid_dir, args.output_dir, args.score_file, args.max_score, args.dry
    )
@@ -0,0 +1,6 @@
python /home/loubb/work/aria-amt/scripts/eval/prune.py \
    -mid_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid \
    -output_dir /mnt/ssd1/amt/transcribed_data/0/pijama-mid-pruned \
    -score_file /mnt/ssd1/amt/transcribed_data/0/pijama-mid.csv \
    -max_score 0.42 \
    # -dry