
Add pedal msgs to tokenizer (#18)
* add pedal msgs to tokenizer

* fix eos token

* format

* improve inference
loubbrad authored Mar 11, 2024
1 parent 12d249b commit d56e8e5
Showing 3 changed files with 158 additions and 52 deletions.
77 changes: 51 additions & 26 deletions amt/infer.py
@@ -7,19 +7,22 @@

 from torch.multiprocessing import Queue
 from tqdm import tqdm
+from functools import wraps
+from torch.cuda import is_bf16_supported

 from amt.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
-from amt.audio import AudioTransform
+from amt.audio import AudioTransform, pad_or_trim
 from amt.data import get_wav_mid_segments


 MAX_SEQ_LEN = 4096
 LEN_MS = 30000
 STRIDE_FACTOR = 3
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
-BEAM = 3
-ONSET_TOLERANCE = 50
-VEL_TOLERANCE = 50
+BEAM = 5
+ONSET_TOLERANCE = 61
+VEL_TOLERANCE = 100


 def _setup_logger():
@@ -105,10 +108,6 @@ def calculate_onset(
     return tokenizer.tok_to_id[("onset", new_onset)]


-from functools import wraps
-from torch.cuda import is_bf16_supported
-
-
 def optional_bf16_autocast(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
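The wrapper body is collapsed in this view. Assuming the decorator simply runs the wrapped function under CUDA autocast, picking bf16 when the hardware supports it and falling back to fp16 otherwise (consistent with the is_bf16_supported import moved to the top of the file), a minimal sketch could look like:

    import torch
    from functools import wraps
    from torch.cuda import is_bf16_supported

    def optional_bf16_autocast(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Prefer bf16 where supported; older GPUs fall back to fp16.
            dtype = torch.bfloat16 if is_bf16_supported() else torch.float16
            with torch.autocast("cuda", dtype=dtype):
                return func(*args, **kwargs)

        return wrapper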
Expand Down Expand Up @@ -145,7 +144,7 @@ def process_segments(
tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes
]
seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
eos_seen = [False for _ in prefixes]
end_idxs = [MAX_SEQ_LEN for _ in prefixes]

kv_cache = model.get_empty_cache()

@@ -173,7 +172,7 @@ def process_segments(
         next_tok_ids = torch.argmax(logits[:, -1], dim=-1)

         for batch_idx in range(logits.shape[0]):
-            if eos_seen[batch_idx] is not False:
+            if idx > end_idxs[batch_idx]:
                 # End already seen, add pad token
                 tok_id = tokenizer.pad_id
             elif idx >= prefix_lens[batch_idx]:
@@ -192,20 +191,24 @@ def process_segments(
                 tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]]

             seq[batch_idx, idx] = tok_id
-            if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok:
-                eos_seen[batch_idx] = idx
-
-        if all(eos_seen):
+            tok = tokenizer.id_to_tok[tok_id]
+            if tok == tokenizer.eos_tok:
+                end_idxs[batch_idx] = idx
+            elif (
+                type(tok) is tuple
+                and tok[0] == "onset"
+                and tok[1] >= LEN_MS - CHUNK_LEN_MS
+            ):
+                end_idxs[batch_idx] = idx - 2
+
+        if all(_idx <= idx for _idx in end_idxs):
             break

-    if not all(eos_seen):
+    if not all(_idx <= idx for _idx in end_idxs):
         logger.warning("Context length overflow when transcribing segment")
-        for _idx in range(seq.shape[0]):
-            if eos_seen[_idx] == False:
-                eos_seen[_idx] = MAX_SEQ_LEN

     results = [
-        tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
+        tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1])
         for _idx in range(seq.shape[0])
     ]

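For intuition, the decode slice seq[_idx, : end_idxs[_idx] + 1] keeps everything up to and including the stopping position, and because end_idxs defaults to MAX_SEQ_LEN, an entry that never produced an end condition decodes in full (Python slices clamp to the sequence length). A toy illustration with made-up tokens:

    MAX_SEQ_LEN = 8
    seq = ["<S>", "a", "b", "<E>", "<P>", "<P>", "<P>", "<P>"]

    end_idx = 3                # position where "<E>" was produced
    print(seq[: end_idx + 1])  # ['<S>', 'a', 'b', '<E>']

    end_idx = MAX_SEQ_LEN      # default: no end condition was ever seen
    print(seq[: end_idx + 1])  # full sequence; the slice is clamped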
@@ -218,7 +221,7 @@ def gpu_manager(
     model: AmtEncoderDecoder,
     batch_size: int,
 ):
-    # model.compile()
+    model.compile()
     logger = _setup_logger()
     audio_transform = AudioTransform().cuda()
     tokenizer = AmtTokenizer(return_tensors=True)
@@ -283,7 +286,7 @@ def _truncate_seq(
     except:
         return ["<S>"]
     else:
-        return res[: res.index(tokenizer.eos_tok)]
+        return res[: res.index(tokenizer.eos_tok)]  # Needs to change


 def process_file(
@@ -302,8 +305,15 @@ def process_file(
             audio_path=file_path, stride_factor=STRIDE_FACTOR
         )
     ]
-    seq = ["<S>"]
-    res = ["<S>"]
+
+    # Add additional (padded) final audio segment
+    _last_seg = audio_segments[-1]
+    audio_segments.append(
+        pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
+    )
+
+    seq = [tokenizer.bos_tok]
+    res = [tokenizer.bos_tok]
     for idx, audio_seg in enumerate(audio_segments):
         init_idx = len(seq)

@@ -318,15 +328,18 @@ def process_file(
         result_queue.put(gpu_result)

         res += _shift_onset(
-            seq[init_idx : seq.index(tokenizer.eos_tok)],
+            seq[init_idx:],
             idx * CHUNK_LEN_MS,
         )

-        seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
-        if len(seq) == 1:
-            logger.info(f"Exiting early")
-            return res
+        if idx == len(audio_segments) - 1:
+            break
+        elif res[-1] == tokenizer.eos_tok:
+            logger.info(f"Exiting early")
+            break
+        else:
+            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
+            if len(seq) <= 2:
+                logger.info(f"Exiting early")
+                return res

@@ -441,3 +454,15 @@ def batch_transcribe(
         p.join()

     gpu_manager_process.join()
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+
+    return next_token
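The new sample_top_p helper is standard nucleus sampling: sort the probabilities, zero out everything outside the smallest set whose cumulative mass exceeds p, renormalize, and sample from what remains. A hypothetical usage sketch (dummy shapes and values):

    import torch

    logits = torch.randn(2, 512)            # (batch, vocab), dummy values
    probs = torch.softmax(logits, dim=-1)
    next_tok = sample_top_p(probs, p=0.95)  # one sampled id per batch element
    print(next_tok.shape)                   # torch.Size([2, 1])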
115 changes: 90 additions & 25 deletions amt/tokenizer.py
@@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False):
         self.prev_tokens = [("prev", i) for i in range(128)]
         self.note_on_tokens = [("on", i) for i in range(128)]
         self.note_off_tokens = [("off", i) for i in range(128)]
+        self.pedal_tokens = [("pedal", 0), ("pedal", 1)]
         self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations]
         self.onset_tokens = [
             ("onset", i) for i in self.onset_time_quantizations
@@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False):
             + self.prev_tokens
             + self.note_on_tokens
             + self.note_off_tokens
+            + self.pedal_tokens
             + self.velocity_tokens
             + self.onset_tokens
         )
@@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int):
         else:
             return velocity_quantized

-    # This method needs to be cleaned up completely, variables renamed
+    # TODO:
+    # - I need to make this method more robust, as it will have to handle
+    #   an arbitrary MIDI file
+    # - Decide whether to put pedal messages as prev tokens
     def _tokenize_midi_dict(
         self,
         midi_dict: MidiDict,
@@ -88,6 +93,12 @@ def _tokenize_midi_dict(
         ), "Invalid values for start_ms, end_ms"

         midi_dict.resolve_pedal()  # Important !!
+        pedal_intervals = midi_dict._build_pedal_intervals()
+        if len(pedal_intervals.keys()) > 1:
+            print("Warning: midi_dict has more than one pedal channel")
+        pedal_intervals = pedal_intervals[0]
+
+        last_msg_ms = -1
         on_off_notes = []
         prev_notes = []
         for msg in midi_dict.note_msgs:
@@ -109,6 +120,9 @@ def _tokenize_midi_dict(
                 ticks_per_beat=midi_dict.ticks_per_beat,
             )

+            if note_end_ms > last_msg_ms:
+                last_msg_ms = note_end_ms
+
             rel_note_start_ms_q = self._quantize_onset(note_start_ms - start_ms)
             rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms)
             velocity_q = self._quantize_velocity(_velocity)
@@ -149,35 +163,70 @@ def _tokenize_midi_dict(
                     ("off", _pitch, rel_note_end_ms_q, None)
                 )

-        on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
+        on_off_pedal = []
+        for pedal_on_tick, pedal_off_tick in pedal_intervals:
+            pedal_on_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_on_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+            pedal_off_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_off_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+
+            rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms)
+            rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)
+
+            # On message
+            if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))
+
+            # Off message
+            if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 0, rel_off_ms_q, None))
+
+        on_off_combined = on_off_notes + on_off_pedal
+        on_off_combined.sort(
+            key=lambda x: (
+                x[2],
+                (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
+            )
+        )
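At a given onset time the new sort key orders pedal events before note-offs before note-ons; a quick standalone check of that ordering:

    events = [
        ("on", 60, 100, 80),
        ("pedal", 0, 100, None),
        ("off", 62, 100, None),
    ]
    events.sort(
        key=lambda x: (x[2], (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2))
    )
    # -> [("pedal", 0, 100, None), ("off", 62, 100, None), ("on", 60, 100, 80)]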
         random.shuffle(prev_notes)

         tokenized_seq = []
         note_status = {}
         for pitch in prev_notes:
             note_status[pitch] = True
-        for note in on_off_notes:
-            _type, _pitch, _onset, _velocity = note
+        for tok in on_off_combined:
+            _type, _val, _onset, _velocity = tok
             if _type == "on":
-                if note_status.get(_pitch) == True:
+                if note_status.get(_val) == True:
+                    # Place holder - we can remove note_status logic now
                     raise Exception
-                tokenized_seq.append(("on", _pitch))
+
+                tokenized_seq.append(("on", _val))
                 tokenized_seq.append(("onset", _onset))
                 tokenized_seq.append(("vel", _velocity))
-                note_status[_pitch] = True
+                note_status[_val] = True
             elif _type == "off":
-                if note_status.get(_pitch) == False:
+                if note_status.get(_val) == False:
+                    # Place holder - we can remove note_status logic now
                     raise Exception
                 else:
-                    tokenized_seq.append(("off", _pitch))
+                    tokenized_seq.append(("off", _val))
                     tokenized_seq.append(("onset", _onset))
-                    note_status[_pitch] = False
+                    note_status[_val] = False
+            elif _type == "pedal":
+                if _val == 0:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))
+                elif _val:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))

         prefix = [("prev", p) for p in prev_notes]
-        return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+
+        # Add eos_tok only if segment includes end of midi_dict
+        if last_msg_ms < end_ms:
+            return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+        else:
+            return prefix + [self.bos_tok] + tokenized_seq
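Putting the pieces together, a tokenized segment now interleaves pedal and note events, and the eos token is appended only when the segment reaches the end of the file. A hypothetical example (pitches, onsets, and velocities made up; "<S>" matches the bos literal visible in infer.py, and "<E>" is assumed to be the eos literal):

    tokens = [
        ("prev", 60),                  # note already sounding at segment start
        "<S>",
        ("pedal", 1), ("onset", 150),  # sustain pedal down
        ("on", 64), ("onset", 200), ("vel", 75),
        ("off", 64), ("onset", 1100),
        ("pedal", 0), ("onset", 1200), # sustain pedal up
        "<E>",                         # only if the segment covers the file's end
    ]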

     def _detokenize_midi_dict(
         self,
@@ -243,16 +292,29 @@ def _detokenize_midi_dict(
print("Unexpected token order: 'prev' seen after '<S>'")
if DEBUG:
raise Exception
elif tok_1_type == "pedal":
# Pedal information contained in note-off messages, so we don't
# need to manually processes them
_pedal_data = tok_1_data
_tick = tok_2_data
note_msgs.append(
{
"type": "pedal",
"data": _pedal_data,
"tick": _tick,
"channel": 0,
}
)
elif tok_1_type == "on":
if (tok_2_type, tok_3_type) != ("onset", "vel"):
print("Unexpected token order")
print("Unexpected token order:", tok_1, tok_2, tok_3)
if DEBUG:
raise Exception
else:
notes_to_close[tok_1_data] = (tok_2_data, tok_3_data)
elif tok_1_type == "off":
if tok_2_type != "onset":
print("Unexpected token order")
print("Unexpected token order:", tok_1, tok_2, tok_3)
if DEBUG:
raise Exception
else:
@@ -336,9 +398,6 @@ def export_data_aug(self):

     def export_msg_mixup(self):
         def msg_mixup(src: list):
-            def round_to_base(n, base=150):
-                return base * round(n / base)
-
             # Process bos, eos, and pad tokens
             orig_len = len(src)
             seen_pad_tok = False
@@ -387,13 +446,19 @@ def round_to_base(n, base=150):
                 elif tok_1_type == "off":
                     _onset = tok_2_data
                     buffer[_onset]["off"].append((tok_1, tok_2))
+                elif tok_1_type == "pedal":
+                    _onset = tok_2_data
+                    buffer[_onset]["pedal"].append((tok_1, tok_2))
                 else:
                     pass

             # Shuffle order and re-append to result
             for k, v in sorted(buffer.items()):
                 random.shuffle(v["on"])
                 random.shuffle(v["off"])
+                for item in v["pedal"]:
+                    res.append(item[0])  # Pedal
+                    res.append(item[1])  # Onset
                 for item in v["off"]:
                     res.append(item[0])  # Pitch
                     res.append(item[1])  # Onset
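The buffer construction is collapsed in this view. Assuming each onset bucket holds separate "on", "off", and "pedal" lists (a hypothetical sketch, not the repository's exact code), the mixup pass groups tokens by quantized onset, shuffles within type, and re-emits pedal events first, then note-offs, then note-ons:

    import random
    from collections import defaultdict

    buffer = defaultdict(lambda: {"on": [], "off": [], "pedal": []})
    buffer[100]["pedal"].append((("pedal", 1), ("onset", 100)))
    buffer[100]["off"].append((("off", 60), ("onset", 100)))
    buffer[100]["on"].append((("on", 64), ("onset", 100), ("vel", 75)))

    res = []
    for k, v in sorted(buffer.items()):
        random.shuffle(v["on"])
        random.shuffle(v["off"])
        for item in v["pedal"] + v["off"] + v["on"]:
            res.extend(item)  # flatten (token, onset, ...) groups in order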