Skip to content

Commit

Permalink
local changes
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad committed Apr 9, 2024
1 parent 1f423d8 commit b68bda3
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 71 deletions.
16 changes: 8 additions & 8 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
noise_ratio: float = 0.75,
reverb_ratio: float = 0.75,
noise_ratio: float = 0.5,
reverb_ratio: float = 0.5,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
bandpass_ratio: float = 0.05,
distort_ratio: float = 0.05,
reduce_ratio: float = 0.0,
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
spec_aug_ratio: float = 0.0,
):
super().__init__()
self.tokenizer = AmtTokenizer()
Expand All @@ -224,8 +224,8 @@ def __init__(
self.detune_max_shift = detune_max_shift
self.spec_aug_ratio = spec_aug_ratio

self.time_mask_param = 2500
self.freq_mask_param = 15
self.time_mask_param = 500
self.freq_mask_param = 10
self.reduction_resample_rate = 6000

# Audio aug
Expand Down
78 changes: 27 additions & 51 deletions amt/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,32 @@ def midi_to_hz(note, shift=0):
# return (a / 32) * (2 ** ((note - 9) / 12))


def get_matched_files(est_dir: str, ref_dir: str):
    """Pair every estimated MIDI file with its reference MIDI file.

    Files are matched by their path relative to the respective root
    directory, ignoring the extension: ``est_dir/a/b.mid`` is paired with
    ``ref_dir/a/b.midi`` or, if that does not exist, ``ref_dir/a/b.mid``.

    Args:
        est_dir: Root directory searched recursively for ``*.mid`` estimates.
        ref_dir: Root directory holding the reference files.

    Returns:
        A list of ``(est_path, ref_path)`` tuples, ordered by estimate path.
        Estimates without a reference file are left out.
    """
    # Sort so the pairing order (and anything derived from it, such as a
    # stats file) does not depend on filesystem enumeration order.
    est_paths = sorted(
        glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True)
    )
    print(f"found {len(est_paths)} est files")

    res = []
    for est_path in est_paths:
        rel_stem = os.path.splitext(os.path.relpath(est_path, est_dir))[0]
        # Prefer ".midi" (the original behaviour); fall back to ".mid" so
        # references sharing the estimate's extension are not dropped.
        for ref_ext in (".midi", ".mid"):
            ref_path = os.path.join(ref_dir, rel_stem + ref_ext)
            if os.path.isfile(ref_path):
                res.append((est_path, ref_path))
                break

    print(f"found {len(res)} matched est-ref pairs")

    return res


def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
"""
Evaluate the estimated pitches against the reference pitches using mir_eval.
"""
# Evaluate the estimated pitches against the reference pitches
ref_midi_files = glob.glob(f"{ref_dir}/*.mid*")
est_midi_files = glob.glob(f"{est_dir}/*.mid*")

est_ref_pairs = []
for est_fpath in est_midi_files:
ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath))
if ref_fpath in ref_midi_files:
est_ref_pairs.append((est_fpath, ref_fpath))
if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
est_ref_pairs.append(
(est_fpath, ref_fpath.replace(".mid", ".midi"))
)
else:
print(
f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
)

est_ref_pairs = get_matched_files(est_dir, ref_dir)

output_fhandle = (
open(output_stats_file, "w") if output_stats_file is not None else None
Expand Down Expand Up @@ -104,38 +109,9 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
help="Path to the file to save the evaluation stats",
)

# add mir_eval and dtw subparsers
subparsers = parser.add_subparsers(help="sub-command help")
mir_eval_parse = subparsers.add_parser(
"run_mir_eval",
help="Run standard mir_eval evaluation on MAESTRO test set.",
)
mir_eval_parse.add_argument(
"--shift",
type=int,
default=0,
help="Shift to apply to the estimated pitches.",
)

# to come
dtw_eval_parse = subparsers.add_parser(
"run_dtw",
help="Run dynamic time warping evaluation on a specified dataset.",
)

args = parser.parse_args()
if not hasattr(args, "command"):
parser.print_help()
print("Unrecognized command")
exit(1)

# todo: should we add an option to run transcription again every time we wish to evaluate?
# that way, we can run both tests with a range of different audio augmentations right here.
# -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

if args.command == "run_mir_eval":
evaluate_mir_eval(
args.est_dir, args.ref_dir, args.output_stats_file, args.shift
)
elif args.command == "run_dtw":
pass
evaluate_mir_eval(
args.est_dir,
args.ref_dir,
args.output_stats_file,
)
8 changes: 7 additions & 1 deletion amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ def gpu_manager(
compile: bool = False,
gpu_id: int | None = None,
):
logger = _setup_logger(name="GPU")
if gpu_id:
logger = _setup_logger(name=f"GPU-{gpu_id}")
else:
logger = _setup_logger(name=f"GPU")

logger.info("Started GPU manager")

if gpu_id is not None:
Expand Down Expand Up @@ -682,6 +686,8 @@ def batch_transcribe(
is False
):
file_queue.put(file_path)
elif len(file_paths) == 1:
file_queue.put(file_path)

logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}")

Expand Down
18 changes: 9 additions & 9 deletions amt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,16 @@ def build_maestro(maestro_dir, train_file, val_file, test_file, num_procs):

print(f"Building {train_file}")
AmtDataset.build(
load_paths=matched_paths_train + matched_paths_val,
load_paths=matched_paths_train,
save_path=train_file,
num_processes=num_procs,
)
# print(f"Building {val_file}")
# AmtDataset.build(
# load_paths=matched_paths_val,
# save_path=val_file,
# num_processes=num_procs,
# )
print(f"Building {val_file}")
AmtDataset.build(
load_paths=matched_paths_val,
save_path=val_file,
num_processes=num_procs,
)
print(f"Building {test_file}")
AmtDataset.build(
load_paths=matched_paths_test,
Expand Down Expand Up @@ -296,8 +296,8 @@ def transcribe(
)
val_mp3_paths = [ap for ap, mp in matched_val_paths]
test_mp3_paths = [ap for ap, mp in matched_test_paths]
file_paths = val_mp3_paths + test_mp3_paths
assert len(file_paths) == 314, "Invalid maestro files"
file_paths = test_mp3_paths # val_mp3_paths + test_mp3_paths
assert len(file_paths) == 177, "Invalid maestro files"
else:
file_paths = [load_path]
batch_size = 1
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if os.path.isdir("tests/test_results") is False:
os.mkdir("tests/test_results")

MAESTRO_PATH = "/home/mchorse/amt/data/maestro_train/train.txt"
MAESTRO_PATH = "/mnt/ssd1/amt/training_data/train.txt"


def plot_spec(mel: torch.Tensor, name: str | int):
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_maestro(self):
audio_transform = AudioTransform()
dataset = AmtDataset(load_path=MAESTRO_PATH)
print(f"Dataset length: {len(dataset)}")
for idx, (wav, src, tgt, idx) in enumerate(dataset):
for idx, (wav, src, tgt, __idx) in enumerate(dataset):
src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)

if idx % 7 == 0 and idx < 100:
Expand Down

0 comments on commit b68bda3

Please sign in to comment.