diff --git a/amt/infer.py b/amt/infer.py
index 8437de8..382fb65 100644
--- a/amt/infer.py
+++ b/amt/infer.py
@@ -7,19 +7,22 @@ from torch.multiprocessing import Queue
 from tqdm import tqdm
 
+from functools import wraps
+from torch.cuda import is_bf16_supported
 from amt.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
-from amt.audio import AudioTransform
+from amt.audio import AudioTransform, pad_or_trim
 from amt.data import get_wav_mid_segments
 
+
 MAX_SEQ_LEN = 4096
 LEN_MS = 30000
 STRIDE_FACTOR = 3
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
-BEAM = 3
-ONSET_TOLERANCE = 50
-VEL_TOLERANCE = 50
+BEAM = 5
+ONSET_TOLERANCE = 61
+VEL_TOLERANCE = 100
 
 
 def _setup_logger():
@@ -105,10 +108,6 @@ def calculate_onset(
     return tokenizer.tok_to_id[("onset", new_onset)]
 
 
-from functools import wraps
-from torch.cuda import is_bf16_supported
-
-
 def optional_bf16_autocast(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -145,7 +144,7 @@ def process_segments(
         tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes
     ]
     seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
-    eos_seen = [False for _ in prefixes]
+    end_idxs = [MAX_SEQ_LEN for _ in prefixes]
 
     kv_cache = model.get_empty_cache()
 
@@ -173,7 +172,7 @@ def process_segments(
             next_tok_ids = torch.argmax(logits[:, -1], dim=-1)
 
             for batch_idx in range(logits.shape[0]):
-                if eos_seen[batch_idx] is not False:
+                if idx > end_idxs[batch_idx]:
                     # End already seen, add pad token
                     tok_id = tokenizer.pad_id
                 elif idx >= prefix_lens[batch_idx]:
@@ -192,20 +191,24 @@ def process_segments(
                     tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]]
 
                 seq[batch_idx, idx] = tok_id
-                if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok:
-                    eos_seen[batch_idx] = idx
-
-        if all(eos_seen):
+                tok = tokenizer.id_to_tok[tok_id]
+                if tok == tokenizer.eos_tok:
+                    end_idxs[batch_idx] = idx
+                elif (
+                    type(tok) is tuple
+                    and tok[0] == "onset"
+                    and tok[1] >= LEN_MS - CHUNK_LEN_MS
+                ):
+                    end_idxs[batch_idx] = idx - 2
+
+        if all(_idx <= idx for _idx in end_idxs):
             break
 
-    if not all(eos_seen):
+    if not all(_idx <= idx for _idx in end_idxs):
         logger.warning("Context length overflow when transcribing segment")
-        for _idx in range(seq.shape[0]):
-            if eos_seen[_idx] == False:
-                eos_seen[_idx] = MAX_SEQ_LEN
 
     results = [
-        tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
+        tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1])
         for _idx in range(seq.shape[0])
    ]
 
@@ -218,7 +221,7 @@ def gpu_manager(
     model: AmtEncoderDecoder,
     batch_size: int,
 ):
-    # model.compile()
+    model.compile()
     logger = _setup_logger()
     audio_transform = AudioTransform().cuda()
     tokenizer = AmtTokenizer(return_tensors=True)
@@ -283,7 +286,7 @@ def _truncate_seq(
     except:
         return [""]
     else:
-        return res[: res.index(tokenizer.eos_tok)]
+        return res[: res.index(tokenizer.eos_tok)]  # Needs to change
 
 
 def process_file(
@@ -302,8 +305,15 @@ def process_file(
             audio_path=file_path, stride_factor=STRIDE_FACTOR
         )
     ]
-    seq = [""]
-    res = [""]
+
+    # Add additional (padded) final audio segment
+    _last_seg = audio_segments[-1]
+    audio_segments.append(
+        pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
+    )
+
+    seq = [tokenizer.bos_tok]
+    res = [tokenizer.bos_tok]
 
     for idx, audio_seg in enumerate(audio_segments):
         init_idx = len(seq)
@@ -318,15 +328,18 @@ def process_file(
         result_queue.put(gpu_result)
 
         res += _shift_onset(
-            seq[init_idx : seq.index(tokenizer.eos_tok)],
+            seq[init_idx:],
             idx * CHUNK_LEN_MS,
         )
 
         if idx == len(audio_segments) - 1:
             break
+        elif res[-1] == tokenizer.eos_tok:
logger.info(f"Exiting early") + break else: - seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS) - if len(seq) == 1: + seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS) + if len(seq) <= 2: logger.info(f"Exiting early") return res @@ -441,3 +454,15 @@ def batch_transcribe( p.join() gpu_manager_process.join() + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + + return next_token diff --git a/amt/tokenizer.py b/amt/tokenizer.py index 67d6072..762a10a 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False): self.prev_tokens = [("prev", i) for i in range(128)] self.note_on_tokens = [("on", i) for i in range(128)] self.note_off_tokens = [("off", i) for i in range(128)] + self.pedal_tokens = [("pedal", 0), (("pedal", 1))] self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations] self.onset_tokens = [ ("onset", i) for i in self.onset_time_quantizations @@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False): + self.prev_tokens + self.note_on_tokens + self.note_off_tokens + + self.pedal_tokens + self.velocity_tokens + self.onset_tokens ) @@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int): else: return velocity_quantized - # This method needs to be cleaned up completely, variables renamed + # TODO: + # - I need to make this method more robust, as it will have to handle + # an arbitrary MIDI file + # - Decide whether to put pedal messages as prev tokens def _tokenize_midi_dict( self, midi_dict: MidiDict, @@ -88,6 +93,12 @@ def _tokenize_midi_dict( ), "Invalid values for start_ms, end_ms" midi_dict.resolve_pedal() # Important !! 
+        pedal_intervals = midi_dict._build_pedal_intervals()
+        if len(pedal_intervals.keys()) > 1:
+            print("Warning: midi_dict has more than one pedal channel")
+        pedal_intervals = pedal_intervals[0]
+
+        last_msg_ms = -1
         on_off_notes = []
         prev_notes = []
         for msg in midi_dict.note_msgs:
@@ -109,6 +120,9 @@ def _tokenize_midi_dict(
                 ticks_per_beat=midi_dict.ticks_per_beat,
             )
 
+            if note_end_ms > last_msg_ms:
+                last_msg_ms = note_end_ms
+
             rel_note_start_ms_q = self._quantize_onset(note_start_ms - start_ms)
             rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms)
             velocity_q = self._quantize_velocity(_velocity)
@@ -149,35 +163,70 @@ def _tokenize_midi_dict(
                     ("off", _pitch, rel_note_end_ms_q, None)
                 )
 
-        on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
+        on_off_pedal = []
+        for pedal_on_tick, pedal_off_tick in pedal_intervals:
+            pedal_on_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_on_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+            pedal_off_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_off_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+
+            rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms)
+            rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)
+
+            # On message
+            if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))
+
+            # Off message
+            if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 0, rel_off_ms_q, None))
+
+        on_off_combined = on_off_notes + on_off_pedal
+        on_off_combined.sort(
+            key=lambda x: (
+                x[2],
+                (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
+            )
+        )
         random.shuffle(prev_notes)
 
         tokenized_seq = []
-        note_status = {}
-        for pitch in prev_notes:
-            note_status[pitch] = True
-        for note in on_off_notes:
-            _type, _pitch, _onset, _velocity = note
+        for tok in on_off_combined:
+            _type, _val, _onset, _velocity = tok
             if _type == "on":
-                if note_status.get(_pitch) == True:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-
-                tokenized_seq.append(("on", _pitch))
+                tokenized_seq.append(("on", _val))
                 tokenized_seq.append(("onset", _onset))
                 tokenized_seq.append(("vel", _velocity))
-                note_status[_pitch] = True
             elif _type == "off":
-                if note_status.get(_pitch) == False:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-                else:
-                    tokenized_seq.append(("off", _pitch))
+                tokenized_seq.append(("off", _val))
+                tokenized_seq.append(("onset", _onset))
+            elif _type == "pedal":
+                if _val == 0:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))
+                elif _val:
+                    tokenized_seq.append(("pedal", _val))
                     tokenized_seq.append(("onset", _onset))
-                note_status[_pitch] = False
 
         prefix = [("prev", p) for p in prev_notes]
-        return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+
+        # Add eos_tok only if segment includes end of midi_dict
+        if last_msg_ms < end_ms:
+            return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+        else:
+            return prefix + [self.bos_tok] + tokenized_seq
 
     def _detokenize_midi_dict(
         self,
@@ -243,16 +292,29 @@ def _detokenize_midi_dict(
                 print("Unexpected token order: 'prev' seen after '<S>'")
                 if DEBUG:
                     raise Exception
+            elif tok_1_type == "pedal":
+                # Pedal information is contained in note-off messages, so we
+                # don't need to manually process them
+                _pedal_data = tok_1_data
+                _tick = tok_2_data
+                note_msgs.append(
+                    {
+                        "type": "pedal",
+                        "data": _pedal_data,
"tick": _tick, + "channel": 0, + } + ) elif tok_1_type == "on": if (tok_2_type, tok_3_type) != ("onset", "vel"): - print("Unexpected token order") + print("Unexpected token order:", tok_1, tok_2, tok_3) if DEBUG: raise Exception else: notes_to_close[tok_1_data] = (tok_2_data, tok_3_data) elif tok_1_type == "off": if tok_2_type != "onset": - print("Unexpected token order") + print("Unexpected token order:", tok_1, tok_2, tok_3) if DEBUG: raise Exception else: @@ -336,9 +398,6 @@ def export_data_aug(self): def export_msg_mixup(self): def msg_mixup(src: list): - def round_to_base(n, base=150): - return base * round(n / base) - # Process bos, eos, and pad tokens orig_len = len(src) seen_pad_tok = False @@ -387,6 +446,9 @@ def round_to_base(n, base=150): elif tok_1_type == "off": _onset = tok_2_data buffer[_onset]["off"].append((tok_1, tok_2)) + elif tok_1_type == "pedal": + _onset = tok_2_data + buffer[_onset]["pedal"].append((tok_1, tok_2)) else: pass @@ -394,6 +456,9 @@ def round_to_base(n, base=150): for k, v in sorted(buffer.items()): random.shuffle(v["on"]) random.shuffle(v["off"]) + for item in v["pedal"]: + res.append(item[0]) # Pedal + res.append(item[1]) # Onset for item in v["off"]: res.append(item[0]) # Pitch res.append(item[1]) # Onset diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 64c1a36..1148c0c 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -37,8 +37,24 @@ def _tokenize_detokenize(mid_name: str, start: int, end: int): _tokenize_detokenize("maestro2.mid", start=START, end=END) _tokenize_detokenize("maestro3.mid", start=START, end=END) + def test_eos_tok(self): + tokenizer = AmtTokenizer() + midi_dict = MidiDict.from_midi(f"tests/test_data/maestro1.mid") + + cnt = 0 + while True: + seq = tokenizer._tokenize_midi_dict( + midi_dict, start_ms=cnt * 10000, end_ms=(cnt * 10000) + 30000 + ) + if len(seq) <= 2: + self.assertEqual(seq[-1], tokenizer.eos_tok) + break + else: + cnt += 1 + def test_pitch_aug(self): tokenizer = AmtTokenizer(return_tensors=True) + tensor_pitch_aug = tokenizer.export_tensor_pitch_aug() midi_dict_1 = MidiDict.from_midi("tests/test_data/maestro1.mid") midi_dict_2 = MidiDict.from_midi("tests/test_data/maestro2.mid") @@ -61,7 +77,7 @@ def test_pitch_aug(self): tokenizer.encode(seq_3), ) ) - aug_seqs = tokenizer.pitch_aug(seqs, shift=2) + aug_seqs = tensor_pitch_aug(seqs, shift=2) midi_dict_1_aug = tokenizer._detokenize_midi_dict( tokenizer.decode(aug_seqs[0]), 30000