Fix tokenizer note overlaps #7

Merged: 25 commits, Feb 21, 2024
12 changes: 6 additions & 6 deletions amt/data.py
@@ -29,7 +29,10 @@ def get_features(audio_path: str, mid_path: str | None = None):

if not os.path.isfile(audio_path):
return None
if (mid_path is not None) and (not os.path.isfile(mid_path)):

if mid_path is not None:
pass
elif not os.path.isfile(mid_path):
return None

try:
@@ -40,7 +43,7 @@ def get_features(audio_path: str, mid_path: str | None = None):
midi_dict = None
except Exception as e:
print("Failed to convert files into features")
return None
raise e

_, total_frames = log_spec.shape
res = []
@@ -62,10 +65,7 @@ def get_features(audio_path: str, mid_path: str | None = None):

def get_features_mp(args):
"""Multiprocessing wrapper for get_features"""
try:
res = get_features(*args)
except Exception as e:
res = None
res = get_features(*args)

if res is None:
return False, None
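
As an aside, a hypothetical driver sketch (not part of this PR) of how a get_features-style function that returns None for missing files is typically fanned out across processes; the load_features stand-in and the file paths are made up.

import os
from multiprocessing import Pool

def load_features(paths):
    # Hypothetical stand-in for get_features: None signals a skipped file
    audio_path, mid_path = paths
    if not os.path.isfile(audio_path):
        return None
    return (audio_path, mid_path)  # placeholder for real feature extraction

if __name__ == "__main__":
    jobs = [("a.wav", "a.mid"), ("missing.wav", None)]  # made-up paths
    with Pool(processes=2) as pool:
        features = [r for r in pool.imap(load_features, jobs) if r is not None]
    print(f"kept {len(features)} of {len(jobs)} jobs")
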
60 changes: 47 additions & 13 deletions amt/model.py
@@ -68,7 +68,10 @@ def sinusoids(length, channels, max_timescale=10000):
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
assert n_state % n_head == 0, "n_head does not evenly divide n_state"

self.n_head = n_head
self.d_head = n_state // n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
@@ -93,20 +96,52 @@ def forward(
k = kv_cache[self.key]
v = kv_cache[self.value]

# Use flash attention here !!
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# debug = True
# if debug is True:
# print(f"q shape: {q.shape}")
# print(f"k shape: {k.shape}")
# print(f"v shape: {v.shape}")
# print(f"mask shape: {mask.shape}")
# Old ---
# wv, qk = self.qkv_attention(q, k, v, mask)
# End ---

# New code ------
debug = False
# Reshape and transpose for attention calculation
batch_size, target_seq_len, _ = q.shape
batch_size, source_seq_len, _ = k.shape
q = q.view(batch_size, target_seq_len, self.n_head, self.d_head)
k = k.view(batch_size, source_seq_len, self.n_head, self.d_head)
v = v.view(batch_size, source_seq_len, self.n_head, self.d_head)

# (bz, L, nh, dh) -> (bz, nh, L, dh)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

if debug is True:
print(f"q shape: {q.shape}")
print(f"k shape: {k.shape}")
print(f"v shape: {v.shape}")
if mask is not None:
print(f"mask shape: {mask.shape}")

if mask is None:
_is_causal = False
else:
_is_causal = True

qk = None # Only used during kv-caching?
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=_is_causal,
)

wv, qk = self.qkv_attention(q, k, v, mask)
# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2)
wv = wv.view(batch_size, target_seq_len, self.n_head * self.d_head)

# if debug is True:
# print(f"att_out shape: {wv.shape}")
# print(f"att_weights shape: {qk.shape}")
if debug is True:
print(f"att_out shape: {wv.shape}")
if qk is not None:
print(f"att_weights shape: {qk.shape}")

# End new code ------

return self.out(wv), qk
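
For reference, a minimal standalone sketch of the reshape-and-attend pattern introduced above, using torch.nn.functional.scaled_dot_product_attention; the batch size, head count, and sequence lengths are illustrative.

import torch
import torch.nn.functional as F

bz, n_head, d_head = 2, 4, 16
tgt_len = src_len = 8
q = torch.randn(bz, tgt_len, n_head * d_head)
k = torch.randn(bz, src_len, n_head * d_head)
v = torch.randn(bz, src_len, n_head * d_head)

# (bz, L, d) -> (bz, L, nh, dh) -> (bz, nh, L, dh)
q, k, v = (t.view(bz, -1, n_head, d_head).transpose(1, 2) for t in (q, k, v))

# is_causal=True builds the lower-triangular mask inside the fused kernel
wv = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=True)

# (bz, nh, L, dh) -> (bz, L, nh * dh)
wv = wv.transpose(1, 2).reshape(bz, tgt_len, n_head * d_head)
print(wv.shape)  # torch.Size([2, 8, 64])

When an explicit mask tensor is available, it could be passed via attn_mask instead of setting is_causal.
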

@@ -209,7 +244,6 @@ def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()

self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

149 changes: 88 additions & 61 deletions amt/tokenizer.py
@@ -1,4 +1,6 @@
import random
import os
import copy

from collections import defaultdict

@@ -7,16 +9,11 @@
from amt.config import load_config


# 3am idea:
# Randomly mixup the order of on/off msgs in small windows (100ms?)
# This way the model can go back and add notes in the past at inference time
# Make this easy codewise by sorting the msgs by onset when converting from
# tokenized_seq to midi_dict

# Instead of doing this, we could calculate beams at inference time, selecting
# the note with the first onset so that we don't miss notes.

# TODO: Move start token to after prev tokens? Instead of at the very start

DEBUG = os.getenv("DEBUG")


class AmtTokenizer(Tokenizer):
@@ -32,6 +29,7 @@ def __init__(self, return_tensors: bool = False):
self.onset_time_quantizations = [
i * self.time_step for i in range(self.num_steps + 1)
]
self.max_onset = self.onset_time_quantizations[-1]

# Calculate velocity quantizations
self.default_velocity = self.config["velocity_quantization"]["default"]
@@ -76,18 +74,22 @@ def _quantize_velocity(self, velocity: int):
else:
return velocity_quantized

# Go back and make sure that the use of < or <= is as we want
# There might also be issues with very rapid notes, make sure this is
# working as intended
# This method needs to be cleaned up completely, variables renamed
def _tokenize_midi_dict(
self,
midi_dict: MidiDict,
start_ms: int,
end_ms: int,
):

assert (
end_ms - start_ms <= self.max_onset
), "Invalid values for start_ms, end_ms"

channel_to_pedal_intervals = self._build_pedal_intervals(midi_dict)
prev_notes = []
on_off_notes = []
on_off_notes = defaultdict(
list
) # pitch: [(onset, offset, velocity), ...]
for msg in midi_dict.note_msgs:
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
@@ -101,11 +103,9 @@ def _tokenize_midi_dict(
pedal_interval[0],
pedal_interval[1],
)
if (
pedal_start <= _start_tick < pedal_end
and _end_tick < pedal_end
):
if pedal_start < _end_tick < pedal_end:
_end_tick = pedal_end
break

note_start_ms = get_duration_ms(
start_tick=0,
@@ -124,65 +124,72 @@
rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms)
velocity_q = self._quantize_velocity(_velocity)

assert note_start_ms < note_end_ms, "Error"
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
continue
elif (
note_start_ms < start_ms and _pitch not in prev_notes
note_start_ms < start_ms # and _pitch not in _prev_notes
): # Add to prev notes
prev_notes.append(_pitch)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)
if note_end_ms >= end_ms:
# We do this so we can detect it later (don't add off tok)
rel_note_end_ms_q += 1
on_off_notes[_pitch].append(
(-1, rel_note_end_ms_q, self.default_velocity)
)
else: # Add to on_off_msgs
# Skip notes with no duration or duplicate notes
if rel_note_start_ms_q == rel_note_end_ms_q:
continue
elif (
"on",
_pitch,
rel_note_start_ms_q,
velocity_q,
) in on_off_notes:
elif rel_note_start_ms_q in [
t[0] for t in on_off_notes[_pitch]
]:
continue

on_off_notes.append(
("on", _pitch, rel_note_start_ms_q, velocity_q)
)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)
if note_end_ms >= end_ms:
# Same idea as before
rel_note_end_ms_q += 1

on_off_notes.sort(key=lambda x: (x[2], x[0] == "on")) # Check
random.shuffle(prev_notes)
on_off_notes[_pitch].append(
(rel_note_start_ms_q, rel_note_end_ms_q, velocity_q)
)

# Resolve note overlaps
for k, v in on_off_notes.items():
if k == 64:
pass
v.sort(key=lambda x: x[0])
on_ms_buff, off_ms_buff, vel_buff = -3, -2, 0
for idx, (on_ms, off_ms, vel) in enumerate(v):
if off_ms_buff > on_ms:
# Adjust previous off so that it doesn't interrupt
v[idx - 1] = (on_ms_buff, on_ms, vel_buff)
on_ms_buff, off_ms_buff, vel_buff = on_ms, off_ms, vel

_note_msgs = []
for k, v in on_off_notes.items():
for on_ms, off_ms, vel in v:
_note_msgs.append(("on", k, on_ms, vel))

if off_ms <= self.max_onset:
_note_msgs.append(("off", k, off_ms, None))

_note_msgs.sort(key=lambda x: (x[2], x[0] == "on")) # Check
tokenized_seq = []
note_status = {}
for pitch in prev_notes:
note_status[pitch] = True
for note in on_off_notes:
prefix = []
for note in _note_msgs:
_type, _pitch, _onset, _velocity = note
if _type == "on":
if note_status.get(_pitch) == True:
# If note already on, turn it off first
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))

tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("onset", _onset))
tokenized_seq.append(("vel", _velocity))
note_status[_pitch] = True
elif _type == "off":
if note_status.get(_pitch) == False:
# If note not on, skip
continue
if _onset < 0:
prefix.append(("prev", _pitch))
else:
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("onset", _onset))
note_status[_pitch] = False
tokenized_seq.append(("vel", _velocity))
elif _type == "off":
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))

prefix = [("prev", p) for p in prev_notes]
random.shuffle(prefix)
return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
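
A small standalone illustration of the per-pitch overlap resolution above (values are made up; this mirrors the loop rather than calling the tokenizer):

# (on_ms, off_ms, velocity) intervals for one pitch; the first note's off
# time (250) overruns the next note's onset (200)
notes = [(0, 250, 60), (200, 400, 70), (500, 600, 80)]

notes.sort(key=lambda x: x[0])
on_buff, off_buff, vel_buff = -3, -2, 0
for idx, (on_ms, off_ms, vel) in enumerate(notes):
    if off_buff > on_ms:
        # Clamp the previous note so it ends where the next one starts
        notes[idx - 1] = (on_buff, on_ms, vel_buff)
    on_buff, off_buff, vel_buff = on_ms, off_ms, vel

print(notes)  # [(0, 200, 60), (200, 400, 70), (500, 600, 80)]
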

def _detokenize_midi_dict(
@@ -196,8 +203,8 @@ def _detokenize_midi_dict(
TICKS_PER_BEAT = 500
TEMPO = 500000

# if tokenized_seq[0] == self.bos_tok:
# tokenized_seq = tokenized_seq[1:]
tokenized_seq = copy.deepcopy(tokenized_seq)

if self.eos_tok in tokenized_seq:
tokenized_seq = tokenized_seq[: tokenized_seq.index(self.eos_tok)]
if self.pad_tok in tokenized_seq:
@@ -222,9 +229,16 @@
if tok == self.bos_tok:
break
elif type(tok) == tuple and tok[0] == "prev":
if tok[1] in notes_to_close.keys():
print(f"Duplicate 'prev' token: {tok[1]}")
if DEBUG:
raise Exception

notes_to_close[tok[1]] = (0, self.default_velocity)
else:
raise Exception(f"Invalid note sequence: {tokenized_seq[:idx]}")
raise Exception(
f"Invalid note sequence at position {idx}: {tok, tokenized_seq[:idx]}"
)

# Process notes
for tok_1, tok_2, tok_3 in zip(
@@ -239,19 +253,30 @@
if tok_1_type == "prev":
notes_to_close[tok_1_data] = (0, self.default_velocity)
print("Unexpected token order: 'prev' seen after '<S>'")
if DEBUG:
raise Exception
elif tok_1_type == "on":
if (tok_2_type, tok_3_type) != ("onset", "vel"):
print("Unexpected token order")
if DEBUG:
raise Exception
else:

notes_to_close[tok_1_data] = (tok_2_data, tok_3_data)
elif tok_1_type == "off":
if tok_2_type != "onset":
print("Unexpected token order")
if DEBUG:
raise Exception
else:
# Process note and add to note msgs
note_to_close = notes_to_close.pop(tok_1_data, None)
if note_to_close is None:
print(f"No 'on' token corresponding to 'off' token")
print(
f"No 'on' token corresponding to 'off' token: {tok_1, tok_2}"
)
if DEBUG:
raise Exception
continue
else:
_pitch = tok_1_data
Expand Down Expand Up @@ -340,6 +365,8 @@ def msg_mixup(src: list):
res.append(tok)
else:
print(f"Unexpected token when processing prefix: {tok}")
if DEBUG:
raise Exception

random.shuffle(res) # Only includes prev toks
res.append(self.bos_tok) # Beginning of sequence
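
To make the sequence grammar produced by _tokenize_midi_dict concrete, a hand-written illustration of the token layout (pitches, onsets, and velocities are made up; '<S>' matches the bos form referenced in the detokenizer's messages, while '<E>' is an assumed eos form):

example_seq = [
    ("prev", 62),                             # pitch already sounding before start_ms
    "<S>",                                    # bos token
    ("on", 60), ("onset", 0), ("vel", 75),    # note on: pitch, onset, velocity
    ("off", 60), ("onset", 200),              # matching note off
    ("on", 64), ("onset", 200), ("vel", 60),  # second pitch starting at the same onset
    "<E>",                                    # eos token (assumed form)
]
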
4 changes: 3 additions & 1 deletion amt/train.py
@@ -329,8 +329,8 @@ def train_loop(
# if (overfit is True) and (of_batch_exists is True):
# pass
# else:
# mel, src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
# of_batch_exists = True
# mel, src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz)

mel, src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
logits = model(mel, src) # (b_sz, s_len, v_sz)
@@ -651,6 +651,8 @@ def train(
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = AmtEncoderDecoder(model_config)
# logger.info("Compiling model...")
# model = torch.compile(model)
logger.info(f"Loaded model with config: {load_model_config(model_name)}")
if mode == "finetune":
try:
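
The compilation step above is left commented out; a minimal sketch of how torch.compile (PyTorch 2.0+) is typically applied, with a toy module standing in for AmtEncoderDecoder:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # toy stand-in for the real model
compiled_model = torch.compile(model)  # returns an optimized callable wrapper
out = compiled_model(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
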