From d6fea7f6457c5604864cb870142b5c2703f82997 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 12 Mar 2024 17:24:53 +0000 Subject: [PATCH] Fix inference, pedal, and add EQ aug (#20) * fix inference and add prev pedal token * add bandpass eq --- amt/audio.py | 20 ++++++++-- amt/infer.py | 91 +++++++++++++++++++++++++--------------------- amt/tokenizer.py | 49 ++++++++++++++++--------- tests/test_data.py | 12 ++++++ 4 files changed, 109 insertions(+), 63 deletions(-) diff --git a/amt/audio.py b/amt/audio.py index 8b3a08b..18224f5 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -194,6 +194,7 @@ def __init__( noise_ratio: float = 0.95, reverb_ratio: float = 0.95, applause_ratio: float = 0.01, + bandpass_ratio: float = 0.1, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, spec_aug_ratio: float = 0.5, @@ -214,6 +215,7 @@ def __init__( self.noise_ratio = noise_ratio self.reverb_ratio = reverb_ratio self.applause_ratio = applause_ratio + self.bandpass_ratio = bandpass_ratio self.distort_ratio = distort_ratio self.reduce_ratio = reduce_ratio self.spec_aug_ratio = spec_aug_ratio @@ -350,6 +352,14 @@ def apply_applause(self, wav: torch.tensor): return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs) + def apply_bandpass(self, wav: torch.tensor): + central_freq = random.randint(1000, 3500) + Q = random.uniform(0.707, 1.41) + + return torchaudio.functional.bandpass_biquad( + wav, self.sample_rate, central_freq, Q + ) + def apply_reduction(self, wav: torch.tensor): """ Limit the high-band pass filter, the low-band pass filter and the sample rate @@ -424,9 +434,13 @@ def aug_wav(self, wav: torch.Tensor): # Reverb if random.random() < self.reverb_ratio: - return self.apply_reverb(wav) - else: - return wav + wav = self.apply_reverb(wav) + + # EQ + if random.random() < self.bandpass_ratio: + wav = self.apply_bandpass(wav) + + return wav def norm_mel(self, mel_spec: torch.Tensor): log_spec = torch.clamp(mel_spec, min=1e-10).log10() diff --git a/amt/infer.py b/amt/infer.py index 382fb65..289d499 100644 --- a/amt/infer.py +++ b/amt/infer.py @@ -283,10 +283,13 @@ def _truncate_seq( _mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS) try: res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) - except: + except Exception: + print("Truncate failed") return [""] else: - return res[: res.index(tokenizer.eos_tok)] # Needs to change + if res[-1] == tokenizer.eos_tok: + res.pop() + return res def process_file( @@ -306,14 +309,9 @@ def process_file( ) ] - # Add addtional (padded) final audio segment - _last_seg = audio_segments[-1] - audio_segments.append( - pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :]) - ) - + res = [] seq = [tokenizer.bos_tok] - res = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] for idx, audio_seg in enumerate(audio_segments): init_idx = len(seq) @@ -327,21 +325,25 @@ def process_file( else: result_queue.put(gpu_result) - res += _shift_onset( + concat_seq += _shift_onset( seq[init_idx:], idx * CHUNK_LEN_MS, ) if idx == len(audio_segments) - 1: - break - elif res[-1] == tokenizer.eos_tok: - logger.info(f"Exiting early") - break + res.append(concat_seq) + elif concat_seq[-1] == tokenizer.eos_tok: + res.append(concat_seq) + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + logger.info(f"Finished segment - eos_tok seen") else: seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS) - if len(seq) <= 2: - logger.info(f"Exiting early") - return res + if len(seq) == 1: + res.append(concat_seq) + seq = [tokenizer.bos_tok] + concat_seq = [tokenizer.bos_tok] + logger.info(f"Exiting early - silence") return res @@ -353,16 +355,35 @@ def worker( save_dir: str, input_dir: str | None = None, ): - def _get_save_path(_file_path: str): + def _save_seq(_seq: list, _save_path: str): + if os.path.exists(_save_path): + logger.info(f"Already exists {_save_path} - overwriting") + + for tok in _seq[::-1]: + if type(tok) is tuple and tok[0] == "onset": + last_onset = tok[1] + break + + try: + mid_dict = tokenizer._detokenize_midi_dict( + tokenized_seq=_seq, len_ms=last_onset + ) + mid = mid_dict.to_midi() + mid.save(_save_path) + except Exception as e: + logger.error(f"Failed to save {_save_path}") + + def _get_save_path(_file_path: str, _idx: int | str = ""): if input_dir is None: save_path = os.path.join( save_dir, - os.path.splitext(os.path.basename(file_path))[0] + ".mid", + os.path.splitext(os.path.basename(file_path))[0] + + f"{_idx}.mid", ) else: input_rel_path = os.path.relpath(_file_path, input_dir) save_path = os.path.join( - save_dir, os.path.splitext(input_rel_path)[0] + ".mid" + save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid" ) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -374,34 +395,20 @@ def _get_save_path(_file_path: str): files_processed = 0 while not file_queue.empty(): file_path = file_queue.get() - save_path = _get_save_path(file_path) - if os.path.exists(save_path): - logger.info(f"{save_path} already exists, overwriting") try: - res = process_file(file_path, gpu_task_queue, result_queue) + seqs = process_file(file_path, gpu_task_queue, result_queue) except Exception as e: - logger.error(f"Failed to transcribe {file_path}") + logger.error(f"Failed to process {file_path}") continue - files_processed += 1 - - for tok in res[::-1]: - if type(tok) is tuple and tok[0] == "onset": - last_onset = tok[1] - break + logger.info(f"Transcribed into {len(seqs)} segment(s)") + for _idx, seq in enumerate(seqs): + _save_seq(seq, _get_save_path(file_path, _idx)) - try: - mid_dict = tokenizer._detokenize_midi_dict( - tokenized_seq=res, len_ms=last_onset - ) - mid = mid_dict.to_midi() - mid.save(save_path) - except Exception as e: - logger.error(f"Failed to detokenize with error {e}") - else: - logger.info(f"Finished file {files_processed} - {file_path}") - logger.info(f"{file_queue.qsize()} file(s) remaining in queue") + files_processed += 1 + logger.info(f"Finished file {files_processed} - {file_path}") + logger.info(f"{file_queue.qsize()} file(s) remaining in queue") def batch_transcribe( diff --git a/amt/tokenizer.py b/amt/tokenizer.py index 762a10a..d5416a7 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -46,7 +46,7 @@ def __init__(self, return_tensors: bool = False): self.prev_tokens = [("prev", i) for i in range(128)] self.note_on_tokens = [("on", i) for i in range(128)] self.note_off_tokens = [("off", i) for i in range(128)] - self.pedal_tokens = [("pedal", 0), (("pedal", 1))] + self.pedal_tokens = [("pedal", 0), ("pedal", 1), ("prev", "pedal")] self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations] self.onset_tokens = [ ("onset", i) for i in self.onset_time_quantizations @@ -81,7 +81,6 @@ def _quantize_velocity(self, velocity: int): # TODO: # - I need to make this method more robust, as it will have to handle # an arbitrary MIDI file - # - Decide whether to put pedal messages as prev tokens def _tokenize_midi_dict( self, midi_dict: MidiDict, @@ -96,11 +95,13 @@ def _tokenize_midi_dict( pedal_intervals = midi_dict._build_pedal_intervals() if len(pedal_intervals.keys()) > 1: print("Warning: midi_dict has more than one pedal channel") + if len(midi_dict.instrument_msgs) > 1: + print("Warning: midi_dict has more than one instrument msg") pedal_intervals = pedal_intervals[0] last_msg_ms = -1 on_off_notes = [] - prev_notes = [] + prev_toks = [] for msg in midi_dict.note_msgs: _pitch = msg["data"]["pitch"] _velocity = msg["data"]["velocity"] @@ -137,9 +138,9 @@ def _tokenize_midi_dict( if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip continue elif ( - note_start_ms < start_ms and _pitch not in prev_notes + note_start_ms < start_ms and _pitch not in prev_toks ): # Add to prev notes - prev_notes.append(_pitch) + prev_toks.append(_pitch) if note_end_ms < end_ms: on_off_notes.append( ("off", _pitch, rel_note_end_ms_q, None) @@ -182,8 +183,10 @@ def _tokenize_midi_dict( rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms) # On message - if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms: + if pedal_off_ms <= start_ms or pedal_on_ms >= end_ms: continue + elif pedal_on_ms < start_ms and pedal_off_ms >= start_ms: + prev_toks.append("pedal") else: on_off_pedal.append(("pedal", 1, rel_on_ms_q, None)) @@ -200,7 +203,7 @@ def _tokenize_midi_dict( (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2), ) ) - random.shuffle(prev_notes) + random.shuffle(prev_toks) tokenized_seq = [] for tok in on_off_combined: @@ -220,7 +223,7 @@ def _tokenize_midi_dict( tokenized_seq.append(("pedal", _val)) tokenized_seq.append(("onset", _onset)) - prefix = [("prev", p) for p in prev_notes] + prefix = [("prev", p) for p in prev_toks] # Add eos_tok only if segment includes end of midi_dict if last_msg_ms < end_ms: @@ -271,7 +274,21 @@ def _detokenize_midi_dict( if DEBUG: raise Exception - notes_to_close[tok[1]] = (0, self.default_velocity) + if tok[1] == "pedal": + pedal_msgs.append( + { + "type": "pedal", + "data": 1, + "tick": 0, + "channel": 0, + } + ) + elif isinstance(tok[1], int): + notes_to_close[tok[1]] = (0, self.default_velocity) + else: + print(f"Invalid 'prev' token: {tok}") + if DEBUG: + raise Exception else: raise Exception( f"Invalid note sequence at position {idx}: {tok, tokenized_seq[:idx]}" @@ -293,11 +310,9 @@ def _detokenize_midi_dict( if DEBUG: raise Exception elif tok_1_type == "pedal": - # Pedal information contained in note-off messages, so we don't - # need to manually processes them _pedal_data = tok_1_data _tick = tok_2_data - note_msgs.append( + pedal_msgs.append( { "type": "pedal", "data": _pedal_data, @@ -454,13 +469,11 @@ def msg_mixup(src: list): # Shuffle order and re-append to result for k, v in sorted(buffer.items()): + off_pedal_combined = v["off"] + v["pedal"] + random.shuffle(off_pedal_combined) random.shuffle(v["on"]) - random.shuffle(v["off"]) - for item in v["pedal"]: - res.append(item[0]) # Pedal - res.append(item[1]) # Onset - for item in v["off"]: - res.append(item[0]) # Pitch + for item in off_pedal_combined: + res.append(item[0]) # Off or pedal res.append(item[1]) # Onset for item in v["on"]: res.append(item[0]) # Pitch diff --git a/tests/test_data.py b/tests/test_data.py index e1770e4..1437472 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -177,6 +177,18 @@ def test_distortion(self): res = audio_transform.apply_distortion(wav) torchaudio.save("tests/test_results/dist.wav", res, SAMPLE_RATE) + def test_bandpass(self): + SAMPLE_RATE, CHUNK_LEN = 16000, 30 + audio_transform = AudioTransform() + wav, sr = torchaudio.load("tests/test_data/147.wav") + wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( + 0, keepdim=True + )[:, : SAMPLE_RATE * CHUNK_LEN] + + torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) + res = audio_transform.apply_bandpass(wav) + torchaudio.save("tests/test_results/bandpass.wav", res, SAMPLE_RATE) + def test_applause(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform()