Skip to content

Commit

Permalink
Fix inference, pedal, and add EQ aug (#20)
Browse files Browse the repository at this point in the history
* fix inference and add prev pedal token

* add bandpass eq
  • Loading branch information
loubbrad authored Mar 12, 2024
1 parent d56e8e5 commit d6fea7f
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 63 deletions.
20 changes: 17 additions & 3 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.1,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
spec_aug_ratio: float = 0.5,
Expand All @@ -214,6 +215,7 @@ def __init__(
self.noise_ratio = noise_ratio
self.reverb_ratio = reverb_ratio
self.applause_ratio = applause_ratio
self.bandpass_ratio = bandpass_ratio
self.distort_ratio = distort_ratio
self.reduce_ratio = reduce_ratio
self.spec_aug_ratio = spec_aug_ratio
Expand Down Expand Up @@ -350,6 +352,14 @@ def apply_applause(self, wav: torch.tensor):

return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)

def apply_bandpass(self, wav: torch.Tensor) -> torch.Tensor:
    """Apply a randomly parameterised band-pass (EQ) filter to ``wav``.

    The central frequency is drawn uniformly from the integers in
    [1000, 3500] and the Q factor uniformly from [0.707, 1.41]. The
    waveform is then filtered with torchaudio's biquad band-pass at
    ``self.sample_rate``.

    Args:
        wav: Waveform tensor to filter.

    Returns:
        The filtered waveform, as returned by
        ``torchaudio.functional.bandpass_biquad``.
    """
    # Draw the frequency first, then Q, so runs with a seeded ``random``
    # module consume random numbers in a stable order.
    central_freq = random.randint(1000, 3500)
    q_factor = random.uniform(0.707, 1.41)

    return torchaudio.functional.bandpass_biquad(
        waveform=wav,
        sample_rate=self.sample_rate,
        central_freq=central_freq,
        Q=q_factor,
    )

def apply_reduction(self, wav: torch.tensor):
"""
Limit the high-band pass filter, the low-band pass filter and the sample rate
Expand Down Expand Up @@ -424,9 +434,13 @@ def aug_wav(self, wav: torch.Tensor):

# Reverb
if random.random() < self.reverb_ratio:
return self.apply_reverb(wav)
else:
return wav
wav = self.apply_reverb(wav)

# EQ
if random.random() < self.bandpass_ratio:
wav = self.apply_bandpass(wav)

return wav

def norm_mel(self, mel_spec: torch.Tensor):
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
Expand Down
91 changes: 49 additions & 42 deletions amt/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,13 @@ def _truncate_seq(
_mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS)
try:
res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)
except:
except Exception:
print("Truncate failed")
return ["<S>"]
else:
return res[: res.index(tokenizer.eos_tok)] # Needs to change
if res[-1] == tokenizer.eos_tok:
res.pop()
return res


def process_file(
Expand All @@ -306,14 +309,9 @@ def process_file(
)
]

# Add additional (padded) final audio segment
_last_seg = audio_segments[-1]
audio_segments.append(
pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
)

res = []
seq = [tokenizer.bos_tok]
res = [tokenizer.bos_tok]
concat_seq = [tokenizer.bos_tok]
for idx, audio_seg in enumerate(audio_segments):
init_idx = len(seq)

Expand All @@ -327,21 +325,25 @@ def process_file(
else:
result_queue.put(gpu_result)

res += _shift_onset(
concat_seq += _shift_onset(
seq[init_idx:],
idx * CHUNK_LEN_MS,
)

if idx == len(audio_segments) - 1:
break
elif res[-1] == tokenizer.eos_tok:
logger.info(f"Exiting early")
break
res.append(concat_seq)
elif concat_seq[-1] == tokenizer.eos_tok:
res.append(concat_seq)
seq = [tokenizer.bos_tok]
concat_seq = [tokenizer.bos_tok]
logger.info(f"Finished segment - eos_tok seen")
else:
seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
if len(seq) <= 2:
logger.info(f"Exiting early")
return res
if len(seq) == 1:
res.append(concat_seq)
seq = [tokenizer.bos_tok]
concat_seq = [tokenizer.bos_tok]
logger.info(f"Exiting early - silence")

return res

Expand All @@ -353,16 +355,35 @@ def worker(
save_dir: str,
input_dir: str | None = None,
):
def _get_save_path(_file_path: str):
def _save_seq(_seq: list, _save_path: str):
    """Detokenize ``_seq`` and write the result to ``_save_path`` as MIDI.

    The value of the last ``("onset", value)`` token in ``_seq`` is passed
    to the tokenizer as ``len_ms``. Failures are logged, not raised, so one
    bad sequence does not stop the caller from saving the others.

    Args:
        _seq: Tokenized sequence to convert.
        _save_path: Destination path; an existing file is overwritten.
    """
    if os.path.exists(_save_path):
        logger.info(f"Already exists {_save_path} - overwriting")

    # Scan from the end: the first onset token found is the last one in
    # the sequence.
    last_onset = next(
        (
            tok[1]
            for tok in reversed(_seq)
            if isinstance(tok, tuple) and tok[0] == "onset"
        ),
        None,
    )
    if last_onset is None:
        # Without an onset there is no length to hand to the detokenizer;
        # report it explicitly instead of relying on an unbound-variable
        # error being swallowed below.
        logger.error(
            f"Failed to save {_save_path}: no onset token in sequence"
        )
        return

    try:
        mid_dict = tokenizer._detokenize_midi_dict(
            tokenized_seq=_seq, len_ms=last_onset
        )
        mid = mid_dict.to_midi()
        mid.save(_save_path)
    except Exception as e:
        # Best-effort save: keep going, but record why it failed.
        logger.error(f"Failed to save {_save_path}: {e}")

def _get_save_path(_file_path: str, _idx: int | str = ""):
if input_dir is None:
save_path = os.path.join(
save_dir,
os.path.splitext(os.path.basename(file_path))[0] + ".mid",
os.path.splitext(os.path.basename(file_path))[0]
+ f"{_idx}.mid",
)
else:
input_rel_path = os.path.relpath(_file_path, input_dir)
save_path = os.path.join(
save_dir, os.path.splitext(input_rel_path)[0] + ".mid"
save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid"
)
if not os.path.isdir(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
Expand All @@ -374,34 +395,20 @@ def _get_save_path(_file_path: str):
files_processed = 0
while not file_queue.empty():
file_path = file_queue.get()
save_path = _get_save_path(file_path)
if os.path.exists(save_path):
logger.info(f"{save_path} already exists, overwriting")

try:
res = process_file(file_path, gpu_task_queue, result_queue)
seqs = process_file(file_path, gpu_task_queue, result_queue)
except Exception as e:
logger.error(f"Failed to transcribe {file_path}")
logger.error(f"Failed to process {file_path}")
continue

files_processed += 1

for tok in res[::-1]:
if type(tok) is tuple and tok[0] == "onset":
last_onset = tok[1]
break
logger.info(f"Transcribed into {len(seqs)} segment(s)")
for _idx, seq in enumerate(seqs):
_save_seq(seq, _get_save_path(file_path, _idx))

try:
mid_dict = tokenizer._detokenize_midi_dict(
tokenized_seq=res, len_ms=last_onset
)
mid = mid_dict.to_midi()
mid.save(save_path)
except Exception as e:
logger.error(f"Failed to detokenize with error {e}")
else:
logger.info(f"Finished file {files_processed} - {file_path}")
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
files_processed += 1
logger.info(f"Finished file {files_processed} - {file_path}")
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")


def batch_transcribe(
Expand Down
49 changes: 31 additions & 18 deletions amt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, return_tensors: bool = False):
self.prev_tokens = [("prev", i) for i in range(128)]
self.note_on_tokens = [("on", i) for i in range(128)]
self.note_off_tokens = [("off", i) for i in range(128)]
self.pedal_tokens = [("pedal", 0), (("pedal", 1))]
self.pedal_tokens = [("pedal", 0), ("pedal", 1), ("prev", "pedal")]
self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations]
self.onset_tokens = [
("onset", i) for i in self.onset_time_quantizations
Expand Down Expand Up @@ -81,7 +81,6 @@ def _quantize_velocity(self, velocity: int):
# TODO:
# - I need to make this method more robust, as it will have to handle
# an arbitrary MIDI file
# - Decide whether to put pedal messages as prev tokens
def _tokenize_midi_dict(
self,
midi_dict: MidiDict,
Expand All @@ -96,11 +95,13 @@ def _tokenize_midi_dict(
pedal_intervals = midi_dict._build_pedal_intervals()
if len(pedal_intervals.keys()) > 1:
print("Warning: midi_dict has more than one pedal channel")
if len(midi_dict.instrument_msgs) > 1:
print("Warning: midi_dict has more than one instrument msg")
pedal_intervals = pedal_intervals[0]

last_msg_ms = -1
on_off_notes = []
prev_notes = []
prev_toks = []
for msg in midi_dict.note_msgs:
_pitch = msg["data"]["pitch"]
_velocity = msg["data"]["velocity"]
Expand Down Expand Up @@ -137,9 +138,9 @@ def _tokenize_midi_dict(
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
continue
elif (
note_start_ms < start_ms and _pitch not in prev_notes
note_start_ms < start_ms and _pitch not in prev_toks
): # Add to prev notes
prev_notes.append(_pitch)
prev_toks.append(_pitch)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
Expand Down Expand Up @@ -182,8 +183,10 @@ def _tokenize_midi_dict(
rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)

# On message
if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
if pedal_off_ms <= start_ms or pedal_on_ms >= end_ms:
continue
elif pedal_on_ms < start_ms and pedal_off_ms >= start_ms:
prev_toks.append("pedal")
else:
on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))

Expand All @@ -200,7 +203,7 @@ def _tokenize_midi_dict(
(0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
)
)
random.shuffle(prev_notes)
random.shuffle(prev_toks)

tokenized_seq = []
for tok in on_off_combined:
Expand All @@ -220,7 +223,7 @@ def _tokenize_midi_dict(
tokenized_seq.append(("pedal", _val))
tokenized_seq.append(("onset", _onset))

prefix = [("prev", p) for p in prev_notes]
prefix = [("prev", p) for p in prev_toks]

# Add eos_tok only if segment includes end of midi_dict
if last_msg_ms < end_ms:
Expand Down Expand Up @@ -271,7 +274,21 @@ def _detokenize_midi_dict(
if DEBUG:
raise Exception

notes_to_close[tok[1]] = (0, self.default_velocity)
if tok[1] == "pedal":
pedal_msgs.append(
{
"type": "pedal",
"data": 1,
"tick": 0,
"channel": 0,
}
)
elif isinstance(tok[1], int):
notes_to_close[tok[1]] = (0, self.default_velocity)
else:
print(f"Invalid 'prev' token: {tok}")
if DEBUG:
raise Exception
else:
raise Exception(
f"Invalid note sequence at position {idx}: {tok, tokenized_seq[:idx]}"
Expand All @@ -293,11 +310,9 @@ def _detokenize_midi_dict(
if DEBUG:
raise Exception
elif tok_1_type == "pedal":
# Pedal information is contained in note-off messages, so we don't
# need to manually process them
_pedal_data = tok_1_data
_tick = tok_2_data
note_msgs.append(
pedal_msgs.append(
{
"type": "pedal",
"data": _pedal_data,
Expand Down Expand Up @@ -454,13 +469,11 @@ def msg_mixup(src: list):

# Shuffle order and re-append to result
for k, v in sorted(buffer.items()):
off_pedal_combined = v["off"] + v["pedal"]
random.shuffle(off_pedal_combined)
random.shuffle(v["on"])
random.shuffle(v["off"])
for item in v["pedal"]:
res.append(item[0]) # Pedal
res.append(item[1]) # Onset
for item in v["off"]:
res.append(item[0]) # Pitch
for item in off_pedal_combined:
res.append(item[0]) # Off or pedal
res.append(item[1]) # Onset
for item in v["on"]:
res.append(item[0]) # Pitch
Expand Down
12 changes: 12 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,18 @@ def test_distortion(self):
res = audio_transform.apply_distortion(wav)
torchaudio.save("tests/test_results/dist.wav", res, SAMPLE_RATE)

def test_bandpass(self):
    """Smoke-test ``apply_bandpass`` by writing before/after audio files.

    Loads the sample clip, resamples it to ``SAMPLE_RATE``, averages over
    dim 0 and keeps the first ``SAMPLE_RATE * CHUNK_LEN`` samples. Both the
    unfiltered and the band-passed clips are saved under
    ``tests/test_results`` for manual listening; nothing is asserted.
    """
    SAMPLE_RATE = 16000
    CHUNK_LEN = 30
    num_samples = SAMPLE_RATE * CHUNK_LEN
    audio_transform = AudioTransform()

    raw_wav, orig_sr = torchaudio.load("tests/test_data/147.wav")
    resampled = torchaudio.functional.resample(raw_wav, orig_sr, SAMPLE_RATE)
    wav = resampled.mean(0, keepdim=True)[:, :num_samples]

    # Save the reference clip first so it can be compared with the output.
    torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
    filtered = audio_transform.apply_bandpass(wav)
    torchaudio.save("tests/test_results/bandpass.wav", filtered, SAMPLE_RATE)

def test_applause(self):
SAMPLE_RATE, CHUNK_LEN = 16000, 30
audio_transform = AudioTransform()
Expand Down

0 comments on commit d6fea7f

Please sign in to comment.