Evaluation scripts (#14)
* updated gitignore

* added evaluation

* updated

* added signal reduction that makes it sound like a phone

* fixed bug in spec_aug that was making it happen in 80% of cases

* updated

* added notebooks

---------

Co-authored-by: Alex Spangher <[email protected]>
alex2awesome and Alex Spangher authored Mar 8, 2024
1 parent 5d6f630 commit e00b374
Showing 8 changed files with 1,502 additions and 90 deletions.
25 changes: 25 additions & 0 deletions .gitignore
@@ -1,3 +1,28 @@
# data files
*.csv
*.json
*.xls
*.xlsx
*.pkl
*.h5
*.sqlite
*.db
*.dbf
*.shp
*.shx
*.prj
*.cpg
*.xml
*.html
*.htm
*.mid
*.midi
*.wav
*.mp3

.idea/


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
45 changes: 35 additions & 10 deletions amt/audio.py
@@ -190,6 +190,10 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
# ratios for the reduction of the audio quality
distort_ratio: float = 0.2,
reduce_ratio: float = 0.2,
spec_aug_ratio: float = 0.2,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -204,6 +208,10 @@ def __init__(
self.chunk_len = self.config["chunk_len"]
self.num_samples = self.sample_rate * self.chunk_len

self.dist_ratio = distort_ratio
self.reduce_ratio = reduce_ratio
self.spec_aug_ratio = spec_aug_ratio

# Audio aug
impulse_paths = self._get_paths(
os.path.join(os.path.dirname(__file__), "assets", "impulse")
@@ -313,6 +321,16 @@ def apply_noise(self, wav: torch.tensor):

return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

def apply_reduction(self, wav: torch.tensor):
"""
Apply a high-pass filter, a low-pass filter, and a sample-rate reduction.
Designed to mimic the effect of recording on a low-quality microphone or phone.
"""
wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=1200)
wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=1400)
resample_rate = 6000
return AF.resample(wav, orig_freq=self.sample_rate, new_freq=resample_rate, lowpass_filter_width=3)
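
A minimal standalone sketch of the same degradation chain, for trying it outside the class (the 16 kHz source rate and the dummy waveform are assumptions for illustration, not repo values):

import torch
import torchaudio.functional as AF

sample_rate = 16_000  # assumed source rate for this sketch
wav = torch.randn(1, sample_rate)  # one second of dummy mono audio

wav = AF.highpass_biquad(wav, sample_rate, cutoff_freq=1200)  # cut the lows
wav = AF.lowpass_biquad(wav, sample_rate, cutoff_freq=1400)  # cut the highs
# Resampling to 6 kHz caps the representable bandwidth at 3 kHz (Nyquist),
# on top of the narrow band the two biquads leave behind.
wav = AF.resample(wav, orig_freq=sample_rate, new_freq=6000, lowpass_filter_width=3)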

def apply_distortion(self, wav: torch.tensor):
gain = random.randint(self.min_dist_gain, self.max_dist_gain)
colour = random.randint(5, 95)
@@ -345,13 +363,20 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
return shifted_specs

def aug_wav(self, wav: torch.Tensor):
# Only apply distortion in 20% of cases
if random.random() > 0.20:
return self.apply_reverb(self.apply_noise(wav))
else:
return self.apply_reverb(
self.apply_distortion(self.apply_noise(wav))
)
"""
pipeline for audio augmentation:
1. apply noise
2. apply distortion (x% of the time)
3. apply reduction (x% of the time)
4. apply reverb
"""

wav = self.apply_noise(wav)
if random.random() < self.dist_ratio:
wav = self.apply_distortion(wav)
if random.random() < self.reduce_ratio:
wav = self.apply_reduction(wav)
return self.apply_reverb(wav)
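
Since the two gates are sampled independently, the expected mix at the default 0.2 ratios is: 64% of chunks get neither distortion nor reduction, 16% get only distortion, 16% only reduction, and 4% get both. A throwaway check (not repo code):

import random

random.seed(0)
N = 100_000
both = neither = 0
for _ in range(N):
    d = random.random() < 0.2  # distortion gate
    r = random.random() < 0.2  # reduction gate
    both += d and r
    neither += not d and not r
print(both / N, neither / N)  # ~0.04 and ~0.64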

def norm_mel(self, mel_spec: torch.Tensor):
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -374,14 +399,14 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
return log_spec

def forward(self, wav: torch.Tensor, shift: int = 0):
# Reverb & noise
# noise, distortion, reduction and reverb
wav = self.aug_wav(wav)

# Spec & pitch shift
log_mel = self.log_mel(wav, shift)

# Spec aug in 20% of cases
if random.random() > 0.20:
# Spec aug with probability spec_aug_ratio (default 20%)
if random.random() < self.spec_aug_ratio:
log_mel = self.spec_aug(log_mel)

return log_mel
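
A usage sketch for the whole transform (the enclosing class name, AudioTransform, and its nn.Module call protocol are assumed from the surrounding __init__, which this diff only shows in part):

import torch

transform = AudioTransform(distort_ratio=0.2, reduce_ratio=0.2, spec_aug_ratio=0.2)
wav = torch.randn(1, transform.num_samples)  # one chunk of raw audio
log_mel = transform(wav)  # noise -> (distortion) -> (reduction) -> reverb -> mel -> (spec aug)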
117 changes: 117 additions & 0 deletions amt/evaluate.py
@@ -0,0 +1,117 @@
import glob
from tqdm.auto import tqdm
import pretty_midi
import numpy as np
import mir_eval
import json
import os

def midi_to_intervals_and_pitches(midi_file_path):
"""
This function reads a MIDI file and extracts note intervals and pitches
suitable for use with mir_eval's transcription evaluation functions.
"""
# Load the MIDI file
midi_data = pretty_midi.PrettyMIDI(midi_file_path)

# Prepare lists to collect note intervals and pitches
notes = []
for instrument in midi_data.instruments:
# Skip drum instruments
if not instrument.is_drum:
for note in instrument.notes:
notes.append([note.start, note.end, note.pitch])
notes = sorted(notes, key=lambda x: x[0])
notes = np.array(notes)
intervals, pitches = notes[:, :2], notes[:, 2]
intervals -= intervals[0][0]
return intervals, pitches


def midi_to_hz(note, shift=0):
"""
Convert MIDI note numbers to Hz.
Shift, if != 0, is subtracted from the MIDI note.
Use "2" for the hFT augmented model transcriptions, else pitches won't match.
"""
# the one used in hFT transformer
return 440.0 * (2.0 ** ((note.astype(int) - shift - 69) / 12))
# a = 440 # frequency of A (common value is 440Hz)
# return (a / 32) * (2 ** ((note - 9) / 12))
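
A quick sanity check of the formula with standard equal-temperament values (not repo code): A4 (MIDI 69) should map to 440 Hz, C4 (MIDI 60) to about 261.63 Hz, and A5 (MIDI 81) to 880 Hz:

import numpy as np

note = np.array([69, 60, 81])
print((440.0 * (2.0 ** ((note.astype(int) - 69) / 12))).round(2))
# [440.   261.63 880.  ]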


def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
"""
Evaluate the estimated pitches against the reference pitches using mir_eval.
"""
# Evaluate the estimated pitches against the reference pitches
ref_midi_files = glob.glob(f"{ref_dir}/*.mid*")
est_midi_files = glob.glob(f"{est_dir}/*.mid*")

est_ref_pairs = []
for est_fpath in est_midi_files:
ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath))
if ref_fpath in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath))
elif ref_fpath.replace(".mid", ".midi") in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath.replace(".mid", ".midi")))
else:
print(f"Reference file not found for {est_fpath} (ref file: {ref_fpath})")

output_fhandle = open(output_stats_file, "w") if output_stats_file is not None else None

for est_file, ref_file in tqdm(est_ref_pairs):
ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file)
est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file)
ref_pitches_hz = midi_to_hz(ref_pitches)
est_pitches_hz = midi_to_hz(est_pitches, est_shift)
scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)
if output_fhandle is not None:
output_fhandle.write(json.dumps(scores))
output_fhandle.write("\n")
else:
print(json.dumps(scores, indent=4))

# after the loop, close the stats file if we opened one
if output_fhandle is not None:
output_fhandle.close()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(usage="evaluate <command> [<args>]")
parser.add_argument(
"--est-dir",
type=str,
help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed."
)
parser.add_argument(
"--ref-dir",
type=str,
help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw)."
)
parser.add_argument(
'--output-stats-file',
default=None,
type=str, help="Path to the file to save the evaluation stats"
)

# add mir_eval and dtw subparsers
subparsers = parser.add_subparsers(dest="command", help="sub-command help")
mir_eval_parse = subparsers.add_parser("run_mir_eval", help="Run standard mir_eval evaluation on MAESTRO test set.")
mir_eval_parse.add_argument('--shift', type=int, default=0, help="Shift to apply to the estimated pitches.")

# to come
dtw_eval_parse = subparsers.add_parser("run_dtw", help="Run dynamic time warping evaluation on a specified dataset.")

args = parser.parse_args()
if args.command is None:
parser.print_help()
print("Unrecognized command")
exit(1)

# todo: should we add an option to run transcription again every time we wish to evaluate?
# that way, we can run both tests with a range of different audio augmentations right here.
# -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

if args.command == "run_mir_eval":
evaluate_mir_eval(args.est_dir, args.ref_dir, args.output_stats_file, args.shift)
elif args.command == "run_dtw":
pass
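
A hypothetical invocation (directory and file names are illustrative; --shift 2 matches the hFT-augmented transcriptions mentioned in midi_to_hz):

python amt/evaluate.py \
    --est-dir transcriptions/ \
    --ref-dir maestro_test/ \
    --output-stats-file scores.jsonl \
    run_mir_eval --shift 2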
19 changes: 17 additions & 2 deletions amt/infer.py
@@ -23,7 +23,6 @@

# TODO: Profile and fix gpu util


def calculate_vel(
logits: torch.Tensor,
init_vel: int,
@@ -88,7 +87,23 @@ def calculate_onset(
return tokenizer.tok_to_id[("onset", new_onset)]


@torch.autocast("cuda", dtype=torch.bfloat16)
from functools import wraps
from torch.cuda import is_bf16_supported
def optional_bf16_autocast(func):
@wraps(func)
def wrapper(*args, **kwargs):
# is_bf16_supported() returns True if the current GPU supports bfloat16
if is_bf16_supported():
with torch.autocast("cuda", dtype=torch.bfloat16):
return func(*args, **kwargs)
else:
# Call the function with float16 if bfloat16 is not supported
with torch.autocast("cuda", dtype=torch.float16):
return func(*args, **kwargs)
return wrapper
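
A toy illustration of the decorator (matmul_demo is hypothetical, not part of the repo; process_segments below is the real consumer):

@optional_bf16_autocast
def matmul_demo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # bfloat16 autocast on GPUs that support it (e.g. Ampere),
    # float16 autocast otherwise
    return a @ b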


@optional_bf16_autocast
def process_segments(
tasks: list,
model: AmtEncoderDecoder,