fix tokenizer
loubbrad committed Feb 22, 2024
1 parent f5c8d81 commit 97286a2
Showing 4 changed files with 75 additions and 96 deletions.
6 changes: 2 additions & 4 deletions amt/data.py
@@ -105,10 +105,8 @@ def _format(tok):

self.file_mmap.seek(self.index[idx])

# This isn't going to load properly
spec, _seq = json.loads(
self.file_mmap.readline()
) # Load data from line
# Load data from line
spec, _seq = json.loads(self.file_mmap.readline())

spec = torch.tensor(spec) # Format spectrogram into tensor
_seq = [_format(tok) for tok in _seq] # Format seq
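The data.py change collapses the readline/json.loads pair into one call and moves the comment. A minimal sketch of the per-line record format this implies, not code from the commit (the sample values and the tuple conversion are hypothetical stand-ins):

import json
import torch

# Assumed record layout: each line of the mmap-backed file holds a JSON array
# of [spectrogram, token_sequence], matching the unpacking in __getitem__.
line = json.dumps([[[0.0, 0.1], [0.2, 0.3]], [["on", 60], ["onset", 0]]])
spec, _seq = json.loads(line)          # same call as the new __getitem__ line
spec = torch.tensor(spec)              # spectrogram -> tensor, as in the dataset
_seq = [tuple(tok) for tok in _seq]    # rough stand-in for the _format step
print(spec.shape, _seq)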
108 changes: 44 additions & 64 deletions amt/tokenizer.py
@@ -81,32 +81,19 @@ def _tokenize_midi_dict(
start_ms: int,
end_ms: int,
):

assert (
end_ms - start_ms <= self.max_onset
), "Invalid values for start_ms, end_ms"

channel_to_pedal_intervals = self._build_pedal_intervals(midi_dict)
on_off_notes = defaultdict(
list
) # pitch: [(onset, offset, velocity), ...]
midi_dict.resolve_pedal() # Important !!
on_off_notes = []
prev_notes = []
for msg in midi_dict.note_msgs:
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
_velocity = msg["data"]["velocity"]
_start_tick = msg["data"]["start"]
_end_tick = msg["data"]["end"]

# Update end tick if affected by pedal
for pedal_interval in channel_to_pedal_intervals[_channel]:
pedal_start, pedal_end = (
pedal_interval[0],
pedal_interval[1],
)
if pedal_start < _end_tick < pedal_end:
_end_tick = pedal_end
break

note_start_ms = get_duration_ms(
start_tick=0,
end_tick=_start_tick,
@@ -128,68 +115,60 @@ def _tokenize_midi_dict(
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
continue
elif (
note_start_ms < start_ms # and _pitch not in _prev_notes
note_start_ms < start_ms and _pitch not in prev_notes
): # Add to prev notes
if note_end_ms >= end_ms:
# We do this so we can detect it later (don't add off tok)
rel_note_end_ms_q += 1
on_off_notes[_pitch].append(
(-1, rel_note_end_ms_q, self.default_velocity)
)
prev_notes.append(_pitch)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)
else: # Add to on_off_msgs
# Skip notes with no duration or duplicate notes
if rel_note_start_ms_q == rel_note_end_ms_q:
continue
elif rel_note_start_ms_q in [
t[0] for t in on_off_notes[_pitch]
]:
if (
"on",
_pitch,
rel_note_start_ms_q,
velocity_q,
) in on_off_notes:
continue

if note_end_ms >= end_ms:
# Same idea as before
rel_note_end_ms_q += 1

on_off_notes[_pitch].append(
(rel_note_start_ms_q, rel_note_end_ms_q, velocity_q)
on_off_notes.append(
("on", _pitch, rel_note_start_ms_q, velocity_q)
)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)

on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
random.shuffle(prev_notes)

# Resolve note overlaps
for k, v in on_off_notes.items():
if k == 64:
pass
v.sort(key=lambda x: x[0])
on_ms_buff, off_ms_buff, vel_buff = -3, -2, 0
for idx, (on_ms, off_ms, vel) in enumerate(v):
if off_ms_buff > on_ms:
# Adjust previous off so that it doesn't interupt
v[idx - 1] = (on_ms_buff, on_ms, vel_buff)
on_ms_buff, off_ms_buff, vel_buff = on_ms, off_ms, vel

_note_msgs = []
for k, v in on_off_notes.items():
for on_ms, off_ms, vel in v:
_note_msgs.append(("on", k, on_ms, vel))

if off_ms <= self.max_onset:
_note_msgs.append(("off", k, off_ms, None))

_note_msgs.sort(key=lambda x: (x[2], x[0] == "on")) # Check
tokenized_seq = []
prefix = []
for note in _note_msgs:
note_status = {}
for pitch in prev_notes:
note_status[pitch] = True
for note in on_off_notes:
_type, _pitch, _onset, _velocity = note
if _type == "on":
if _onset < 0:
prefix.append(("prev", _pitch))
if note_status.get(_pitch) == True:
# Place holder - we can remove note_status logic now
raise Exception

tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("onset", _onset))
tokenized_seq.append(("vel", _velocity))
note_status[_pitch] = True
elif _type == "off":
if note_status.get(_pitch) == False:
# Place holder - we can remove note_status logic now
raise Exception
else:
tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))
tokenized_seq.append(("vel", _velocity))
elif _type == "off":
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))
note_status[_pitch] = False

random.shuffle(prefix)
prefix = [("prev", p) for p in prev_notes]
return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
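The rewrite replaces the per-pitch overlap-resolution pass with a flat list of ("on"/"off", pitch, onset, velocity) tuples that is sorted once. A standalone sketch of how that sort key behaves, with made-up event values (not code from the commit):

# Illustration of on_off_notes.sort(key=lambda x: (x[2], x[0] == "on")):
# events are ordered by onset (x[2]); at equal onsets, "off" sorts before "on"
# because False < True, so a re-struck pitch is released before it is re-pressed.
events = [
    ("on", 60, 100, 90),
    ("off", 60, 100, None),
    ("on", 64, 50, 80),
]
events.sort(key=lambda x: (x[2], x[0] == "on"))
print(events)
# [('on', 64, 50, 80), ('off', 60, 100, None), ('on', 60, 100, 90)]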

def _detokenize_midi_dict(
@@ -200,6 +179,7 @@ def _detokenize_midi_dict(
):
# NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to
# skip converting between ticks and ms
assert len_ms > 0, "len_ms must be positive"
TICKS_PER_BEAT = 500
TEMPO = 500000

10 changes: 7 additions & 3 deletions tests/test_data.py
@@ -55,15 +55,19 @@ def test_maestro(self):
src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)
if (idx + 1) % 200 == 0:
break
if idx % 25 == 0:
if idx % 7 == 0:
src_mid_dict = tokenizer._detokenize_midi_dict(
src_dec, len_ms=30000
src_dec,
len_ms=30000,
)

src_mid = src_mid_dict.to_midi()
if idx % 10 == 0:
src_mid.save(f"tests/test_results/dataset_{idx}.mid")

for src_tok, tgt_tok in enumerate(zip(src_dec[1:], tgt_dec)):
self.assertTrue(tokenizer.unk_tok not in src_dec)
self.assertTrue(tokenizer.unk_tok not in tgt_dec)
for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
self.assertEqual(src_tok, tgt_tok)
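The test now also asserts that no unk_tok survives decoding, and the comparison loop drops a stray enumerate. A small sketch of what the old pairing produced versus the fixed one (the token values are invented; the one-token shift between src and tgt is an assumption based on the slicing above):

src_dec = ["<S>", "a", "b", "c"]  # hypothetical decoded sequences; src appears
tgt_dec = ["a", "b", "c"]         # to be shifted by one leading token vs. tgt

# Old loop: enumerate() yields (index, (src, tgt)), so the unpacked "tokens"
# were actually an int and a tuple.
print(list(enumerate(zip(src_dec[1:], tgt_dec))))
# [(0, ('a', 'a')), (1, ('b', 'b')), (2, ('c', 'c'))]

# Fixed loop: compare aligned tokens directly.
for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
    assert src_tok == tgt_tok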


47 changes: 22 additions & 25 deletions tests/test_tokenizer.py
@@ -11,37 +11,34 @@
os.mkdir("tests/test_results")


# Add test for unk tok


class TestAmtTokenizer(unittest.TestCase):
def test_tokenize(self):
def _tokenize_detokenize(mid_name: str):
START = 0
END = 30000

def _tokenize_detokenize(mid_name: str, start: int, end: int):
length = end - start
tokenizer = AmtTokenizer()
midi_dict = MidiDict.from_midi(f"tests/test_data/{mid_name}")
tokenized_seq = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=START,
end_ms=END,
)
logging.info(f"{mid_name} tokenized:")
logging.info(tokenized_seq)

_midi_dict = tokenizer._detokenize_midi_dict(
tokenized_seq, END - START
)
logging.info(f"tokenizing {mid_name} in range ({start}, {end})...")
tokenized_seq = tokenizer._tokenize_midi_dict(midi_dict, start, end)
tokenized_seq = tokenizer.decode(tokenizer.encode(tokenized_seq))
self.assertTrue(tokenizer.unk_tok not in tokenized_seq)
_midi_dict = tokenizer._detokenize_midi_dict(tokenized_seq, length)
_mid = _midi_dict.to_midi()
_mid.save(f"tests/test_results/{mid_name}")
logging.info(f"{mid_name} note_msgs:")
for msg in _midi_dict.note_msgs:
logging.info(msg)

_tokenize_detokenize(mid_name="basic.mid")
_tokenize_detokenize(mid_name="147.mid")
_tokenize_detokenize(mid_name="beethoven_moonlight.mid")
_tokenize_detokenize(mid_name="maestro1.mid")
_tokenize_detokenize(mid_name="maestro2.mid")
_tokenize_detokenize(mid_name="maestro3.mid")
_mid.save(f"tests/test_results/{start}_{end}_{mid_name}")

_tokenize_detokenize("basic.mid", start=0, end=30000)
_tokenize_detokenize("147.mid", start=0, end=30000)
_tokenize_detokenize("beethoven_moonlight.mid", start=0, end=30000)

for _idx in range(5):
START = _idx * 25000
END = (_idx + 1) * 25000
_tokenize_detokenize("maestro1.mid", start=START, end=END)
_tokenize_detokenize("maestro2.mid", start=START, end=END)
_tokenize_detokenize("maestro3.mid", start=START, end=END)

def test_aug(self):
def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
