Fix tokenizer and update tests #8

Merged
merged 1 commit into from
Feb 22, 2024
6 changes: 2 additions & 4 deletions amt/data.py
@@ -105,10 +105,8 @@ def _format(tok):

self.file_mmap.seek(self.index[idx])

# This isn't going to load properly
spec, _seq = json.loads(
self.file_mmap.readline()
) # Load data from line
# Load data from line
spec, _seq = json.loads(self.file_mmap.readline())

spec = torch.tensor(spec) # Format spectrogram into tensor
_seq = [_format(tok) for tok in _seq] # Format seq
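
For context, the fixed loader reads one JSON line per example and expects it to decode into a [spec, seq] pair. Below is a minimal, self-contained sketch of that layout using a toy file; the file name, shapes, and token values are illustrative only and are not taken from the repository.

import json
import mmap

# Write one toy example as a single JSONL line: a JSON array [spec, seq].
with open("example.jsonl", "w") as f:
    spec = [[0.0, 0.1], [0.2, 0.3]]        # toy spectrogram frames
    seq = [["piano", 0], ["onset", 100]]   # toy token sequence
    f.write(json.dumps([spec, seq]) + "\n")

# Read it back the way the dataset does: seek to the stored offset, then
# readline() and json.loads() on the memory-mapped file.
with open("example.jsonl", "r+b") as f:
    file_mmap = mmap.mmap(f.fileno(), 0)
    file_mmap.seek(0)  # the dataset seeks to self.index[idx] instead of 0
    spec, _seq = json.loads(file_mmap.readline())
    print(len(spec), len(_seq))  # -> 2 2
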
108 changes: 44 additions & 64 deletions amt/tokenizer.py
@@ -81,32 +81,19 @@ def _tokenize_midi_dict(
start_ms: int,
end_ms: int,
):

assert (
end_ms - start_ms <= self.max_onset
), "Invalid values for start_ms, end_ms"

channel_to_pedal_intervals = self._build_pedal_intervals(midi_dict)
on_off_notes = defaultdict(
list
) # pitch: [(onset, offset, velocity), ...]
midi_dict.resolve_pedal() # Important !!
on_off_notes = []
prev_notes = []
for msg in midi_dict.note_msgs:
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
_velocity = msg["data"]["velocity"]
_start_tick = msg["data"]["start"]
_end_tick = msg["data"]["end"]

# Update end tick if affected by pedal
for pedal_interval in channel_to_pedal_intervals[_channel]:
pedal_start, pedal_end = (
pedal_interval[0],
pedal_interval[1],
)
if pedal_start < _end_tick < pedal_end:
_end_tick = pedal_end
break

note_start_ms = get_duration_ms(
start_tick=0,
end_tick=_start_tick,
@@ -128,68 +115,60 @@ def _tokenize_midi_dict(
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
continue
elif (
note_start_ms < start_ms # and _pitch not in _prev_notes
note_start_ms < start_ms and _pitch not in prev_notes
): # Add to prev notes
if note_end_ms >= end_ms:
# We do this so we can detect it later (don't add off tok)
rel_note_end_ms_q += 1
on_off_notes[_pitch].append(
(-1, rel_note_end_ms_q, self.default_velocity)
)
prev_notes.append(_pitch)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)
else: # Add to on_off_msgs
# Skip notes with no duration or duplicate notes
if rel_note_start_ms_q == rel_note_end_ms_q:
continue
elif rel_note_start_ms_q in [
t[0] for t in on_off_notes[_pitch]
]:
if (
"on",
_pitch,
rel_note_start_ms_q,
velocity_q,
) in on_off_notes:
continue

if note_end_ms >= end_ms:
# Same idea as before
rel_note_end_ms_q += 1

on_off_notes[_pitch].append(
(rel_note_start_ms_q, rel_note_end_ms_q, velocity_q)
on_off_notes.append(
("on", _pitch, rel_note_start_ms_q, velocity_q)
)
if note_end_ms < end_ms:
on_off_notes.append(
("off", _pitch, rel_note_end_ms_q, None)
)

on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
random.shuffle(prev_notes)

# Resolve note overlaps
for k, v in on_off_notes.items():
if k == 64:
pass
v.sort(key=lambda x: x[0])
on_ms_buff, off_ms_buff, vel_buff = -3, -2, 0
for idx, (on_ms, off_ms, vel) in enumerate(v):
if off_ms_buff > on_ms:
# Adjust previous off so that it doesn't interrupt
v[idx - 1] = (on_ms_buff, on_ms, vel_buff)
on_ms_buff, off_ms_buff, vel_buff = on_ms, off_ms, vel

_note_msgs = []
for k, v in on_off_notes.items():
for on_ms, off_ms, vel in v:
_note_msgs.append(("on", k, on_ms, vel))

if off_ms <= self.max_onset:
_note_msgs.append(("off", k, off_ms, None))

_note_msgs.sort(key=lambda x: (x[2], x[0] == "on")) # Check
tokenized_seq = []
prefix = []
for note in _note_msgs:
note_status = {}
for pitch in prev_notes:
note_status[pitch] = True
for note in on_off_notes:
_type, _pitch, _onset, _velocity = note
if _type == "on":
if _onset < 0:
prefix.append(("prev", _pitch))
if note_status.get(_pitch) == True:
# Place holder - we can remove note_status logic now
raise Exception

tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("onset", _onset))
tokenized_seq.append(("vel", _velocity))
note_status[_pitch] = True
elif _type == "off":
if note_status.get(_pitch) == False:
# Place holder - we can remove note_status logic now
raise Exception
else:
tokenized_seq.append(("on", _pitch))
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))
tokenized_seq.append(("vel", _velocity))
elif _type == "off":
tokenized_seq.append(("off", _pitch))
tokenized_seq.append(("onset", _onset))
note_status[_pitch] = False

random.shuffle(prefix)
prefix = [("prev", p) for p in prev_notes]
return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]

def _detokenize_midi_dict(
@@ -200,6 +179,7 @@ def _detokenize_midi_dict(
):
# NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to
# skip converting between ticks and ms
assert len_ms > 0, "len_ms must be positive"
TICKS_PER_BEAT = 500
TEMPO = 500000

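As a rough illustration of the new event representation in _tokenize_midi_dict: note events are collected as flat tuples, sorted by quantized onset with releases ordered before attacks at the same time step, and then flattened into ("on"/"off", pitch), ("onset", t), ("vel", v) tokens. The snippet below is a simplified sketch with made-up values, not the tokenizer itself.

events = [
    ("on", 60, 120, 3),      # (type, pitch, quantized onset ms, quantized velocity)
    ("off", 62, 120, None),  # "off" events carry no velocity
    ("on", 62, 0, 2),
]

# Sort by onset; at equal onsets False ("off") sorts before True ("on"), so a
# release is emitted before a re-attack at the same time step.
events.sort(key=lambda x: (x[2], x[0] == "on"))

tokenized_seq = []
for _type, _pitch, _onset, _velocity in events:
    if _type == "on":
        tokenized_seq += [("on", _pitch), ("onset", _onset), ("vel", _velocity)]
    else:
        tokenized_seq += [("off", _pitch), ("onset", _onset)]

print(tokenized_seq)
# [('on', 62), ('onset', 0), ('vel', 2),
#  ('off', 62), ('onset', 120),
#  ('on', 60), ('onset', 120), ('vel', 3)]
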
10 changes: 7 additions & 3 deletions tests/test_data.py
@@ -55,15 +55,19 @@ def test_maestro(self):
src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)
if (idx + 1) % 200 == 0:
break
if idx % 25 == 0:
if idx % 7 == 0:
src_mid_dict = tokenizer._detokenize_midi_dict(
src_dec, len_ms=30000
src_dec,
len_ms=30000,
)

src_mid = src_mid_dict.to_midi()
if idx % 10 == 0:
src_mid.save(f"tests/test_results/dataset_{idx}.mid")

for src_tok, tgt_tok in enumerate(zip(src_dec[1:], tgt_dec)):
self.assertTrue(tokenizer.unk_tok not in src_dec)
self.assertTrue(tokenizer.unk_tok not in tgt_dec)
for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
self.assertEqual(src_tok, tgt_tok)


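The updated assertions in test_maestro check that neither decoded sequence contains the unknown token and that src and tgt line up as a next-token pair. A tiny sketch of that alignment check, assuming tgt is simply src shifted left by one; the token values are placeholders, not the tokenizer's actual vocabulary.

src_dec = ["<S>", ("on", 60), ("onset", 0), ("vel", 3), "<E>"]
tgt_dec = [("on", 60), ("onset", 0), ("vel", 3), "<E>"]

# Dropping the first src token should line the two sequences up exactly.
for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
    assert src_tok == tgt_tok
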
47 changes: 22 additions & 25 deletions tests/test_tokenizer.py
@@ -11,37 +11,34 @@
os.mkdir("tests/test_results")


# Add test for unk tok


class TestAmtTokenizer(unittest.TestCase):
def test_tokenize(self):
def _tokenize_detokenize(mid_name: str):
START = 0
END = 30000

def _tokenize_detokenize(mid_name: str, start: int, end: int):
length = end - start
tokenizer = AmtTokenizer()
midi_dict = MidiDict.from_midi(f"tests/test_data/{mid_name}")
tokenized_seq = tokenizer._tokenize_midi_dict(
midi_dict=midi_dict,
start_ms=START,
end_ms=END,
)
logging.info(f"{mid_name} tokenized:")
logging.info(tokenized_seq)

_midi_dict = tokenizer._detokenize_midi_dict(
tokenized_seq, END - START
)
logging.info(f"tokenizing {mid_name} in range ({start}, {end})...")
tokenized_seq = tokenizer._tokenize_midi_dict(midi_dict, start, end)
tokenized_seq = tokenizer.decode(tokenizer.encode(tokenized_seq))
self.assertTrue(tokenizer.unk_tok not in tokenized_seq)
_midi_dict = tokenizer._detokenize_midi_dict(tokenized_seq, length)
_mid = _midi_dict.to_midi()
_mid.save(f"tests/test_results/{mid_name}")
logging.info(f"{mid_name} note_msgs:")
for msg in _midi_dict.note_msgs:
logging.info(msg)

_tokenize_detokenize(mid_name="basic.mid")
_tokenize_detokenize(mid_name="147.mid")
_tokenize_detokenize(mid_name="beethoven_moonlight.mid")
_tokenize_detokenize(mid_name="maestro1.mid")
_tokenize_detokenize(mid_name="maestro2.mid")
_tokenize_detokenize(mid_name="maestro3.mid")
_mid.save(f"tests/test_results/{start}_{end}_{mid_name}")

_tokenize_detokenize("basic.mid", start=0, end=30000)
_tokenize_detokenize("147.mid", start=0, end=30000)
_tokenize_detokenize("beethoven_moonlight.mid", start=0, end=30000)

for _idx in range(5):
START = _idx * 25000
END = (_idx + 1) * 25000
_tokenize_detokenize("maestro1.mid", start=START, end=END)
_tokenize_detokenize("maestro2.mid", start=START, end=END)
_tokenize_detokenize("maestro3.mid", start=START, end=END)

def test_aug(self):
def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
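
The reworked test boils down to the following round trip: tokenize a window of a MIDI file, pass the sequence through encode/decode, check for unknown tokens, then detokenize and write the result back to disk. A condensed sketch of that flow is below; the import locations (in particular the MidiDict module path) are assumptions rather than something shown in the diff.

from amt.tokenizer import AmtTokenizer
from aria.data.midi import MidiDict  # assumed import path for MidiDict

tokenizer = AmtTokenizer()
midi_dict = MidiDict.from_midi("tests/test_data/basic.mid")  # path used by the test
start, end = 0, 30000

seq = tokenizer._tokenize_midi_dict(midi_dict, start, end)
seq = tokenizer.decode(tokenizer.encode(seq))  # encode/decode round trip
assert tokenizer.unk_tok not in seq            # no unknown tokens introduced

reconstructed = tokenizer._detokenize_midi_dict(seq, end - start)
reconstructed.to_midi().save(f"tests/test_results/{start}_{end}_basic.mid")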