diff --git a/amt/data.py b/amt/data.py
index 2b5c468..ee87c33 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -105,10 +105,8 @@ def _format(tok):
 
         self.file_mmap.seek(self.index[idx])
 
-        # This isn't going to load properly
-        spec, _seq = json.loads(
-            self.file_mmap.readline()
-        )  # Load data from line
+        # Load data from line
+        spec, _seq = json.loads(self.file_mmap.readline())
 
         spec = torch.tensor(spec)  # Format spectrogram into tensor
         _seq = [_format(tok) for tok in _seq]  # Format seq
diff --git a/amt/tokenizer.py b/amt/tokenizer.py
index 1c6cc07..180ab6c 100644
--- a/amt/tokenizer.py
+++ b/amt/tokenizer.py
@@ -81,32 +81,19 @@ def _tokenize_midi_dict(
         start_ms: int,
         end_ms: int,
     ):
-
         assert (
             end_ms - start_ms <= self.max_onset
         ), "Invalid values for start_ms, end_ms"
 
-        channel_to_pedal_intervals = self._build_pedal_intervals(midi_dict)
-        on_off_notes = defaultdict(
-            list
-        )  # pitch: [(onset, offset, velocity), ...]
+        midi_dict.resolve_pedal()  # Important !!
+        on_off_notes = []
+        prev_notes = []
         for msg in midi_dict.note_msgs:
-            _channel = msg["channel"]
             _pitch = msg["data"]["pitch"]
             _velocity = msg["data"]["velocity"]
             _start_tick = msg["data"]["start"]
             _end_tick = msg["data"]["end"]
 
-            # Update end tick if affected by pedal
-            for pedal_interval in channel_to_pedal_intervals[_channel]:
-                pedal_start, pedal_end = (
-                    pedal_interval[0],
-                    pedal_interval[1],
-                )
-                if pedal_start < _end_tick < pedal_end:
-                    _end_tick = pedal_end
-                    break
-
             note_start_ms = get_duration_ms(
                 start_tick=0,
                 end_tick=_start_tick,
@@ -128,68 +115,60 @@ def _tokenize_midi_dict(
             if note_end_ms <= start_ms or note_start_ms >= end_ms:  # Skip
                 continue
             elif (
-                note_start_ms < start_ms  # and _pitch not in _prev_notes
+                note_start_ms < start_ms and _pitch not in prev_notes
             ):  # Add to prev notes
-                if note_end_ms >= end_ms:
-                    # We do this so we can detect it later (don't add off tok)
-                    rel_note_end_ms_q += 1
-                on_off_notes[_pitch].append(
-                    (-1, rel_note_end_ms_q, self.default_velocity)
-                )
+                prev_notes.append(_pitch)
+                if note_end_ms < end_ms:
+                    on_off_notes.append(
+                        ("off", _pitch, rel_note_end_ms_q, None)
+                    )
             else:  # Add to on_off_msgs
                 # Skip notes with no duration or duplicate notes
                 if rel_note_start_ms_q == rel_note_end_ms_q:
                     continue
-                elif rel_note_start_ms_q in [
-                    t[0] for t in on_off_notes[_pitch]
-                ]:
+                if (
+                    "on",
+                    _pitch,
+                    rel_note_start_ms_q,
+                    velocity_q,
+                ) in on_off_notes:
                     continue
-
-                if note_end_ms >= end_ms:
-                    # Same idea as before
-                    rel_note_end_ms_q += 1
-
-                on_off_notes[_pitch].append(
-                    (rel_note_start_ms_q, rel_note_end_ms_q, velocity_q)
+                on_off_notes.append(
+                    ("on", _pitch, rel_note_start_ms_q, velocity_q)
                 )
+                if note_end_ms < end_ms:
+                    on_off_notes.append(
+                        ("off", _pitch, rel_note_end_ms_q, None)
+                    )
+
+        on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
+        random.shuffle(prev_notes)
-        # Resolve note overlaps
-        for k, v in on_off_notes.items():
-            if k == 64:
-                pass
-            v.sort(key=lambda x: x[0])
-            on_ms_buff, off_ms_buff, vel_buff = -3, -2, 0
-            for idx, (on_ms, off_ms, vel) in enumerate(v):
-                if off_ms_buff > on_ms:
-                    # Adjust previous off so that it doesn't interupt
-                    v[idx - 1] = (on_ms_buff, on_ms, vel_buff)
-                on_ms_buff, off_ms_buff, vel_buff = on_ms, off_ms, vel
-
-        _note_msgs = []
-        for k, v in on_off_notes.items():
-            for on_ms, off_ms, vel in v:
-                _note_msgs.append(("on", k, on_ms, vel))
-
-                if off_ms <= self.max_onset:
-                    _note_msgs.append(("off", k, off_ms, None))
-
-        _note_msgs.sort(key=lambda x: (x[2], x[0] == "on"))  # Check
 
         tokenized_seq = []
-        prefix = []
-        for note in _note_msgs:
+        note_status = {}
+        for pitch in prev_notes:
+            note_status[pitch] = True
+        for note in on_off_notes:
             _type, _pitch, _onset, _velocity = note
             if _type == "on":
-                if _onset < 0:
-                    prefix.append(("prev", _pitch))
+                if note_status.get(_pitch) == True:
+                    # Placeholder - we can likely remove the note_status logic now
+                    raise Exception
+
+                tokenized_seq.append(("on", _pitch))
+                tokenized_seq.append(("onset", _onset))
+                tokenized_seq.append(("vel", _velocity))
+                note_status[_pitch] = True
+            elif _type == "off":
+                if note_status.get(_pitch) == False:
+                    # Placeholder - we can likely remove the note_status logic now
+                    raise Exception
                 else:
-                    tokenized_seq.append(("on", _pitch))
+                    tokenized_seq.append(("off", _pitch))
                     tokenized_seq.append(("onset", _onset))
-                    tokenized_seq.append(("vel", _velocity))
-            elif _type == "off":
-                tokenized_seq.append(("off", _pitch))
-                tokenized_seq.append(("onset", _onset))
+                    note_status[_pitch] = False
 
-        random.shuffle(prefix)
+        prefix = [("prev", p) for p in prev_notes]
         return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
 
     def _detokenize_midi_dict(
@@ -200,6 +179,7 @@ def _detokenize_midi_dict(
     ):
         # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to
        # skip converting between ticks and ms
+        assert len_ms > 0, "len_ms must be positive"
         TICKS_PER_BEAT = 500
         TEMPO = 500000
 
diff --git a/tests/test_data.py b/tests/test_data.py
index 58f3ada..7e2d85d 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -55,15 +55,19 @@ def test_maestro(self):
             src_dec, tgt_dec = tokenizer.decode(src), tokenizer.decode(tgt)
 
             if (idx + 1) % 200 == 0:
                 break
 
-            if idx % 25 == 0:
+            if idx % 7 == 0:
                 src_mid_dict = tokenizer._detokenize_midi_dict(
-                    src_dec, len_ms=30000
+                    src_dec,
+                    len_ms=30000,
                 )
+
                 src_mid = src_mid_dict.to_midi()
                 if idx % 10 == 0:
                     src_mid.save(f"tests/test_results/dataset_{idx}.mid")
 
-            for src_tok, tgt_tok in enumerate(zip(src_dec[1:], tgt_dec)):
+            self.assertTrue(tokenizer.unk_tok not in src_dec)
+            self.assertTrue(tokenizer.unk_tok not in tgt_dec)
+            for src_tok, tgt_tok in zip(src_dec[1:], tgt_dec):
                 self.assertEqual(src_tok, tgt_tok)
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index c7fb31d..d6e5d5f 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -11,37 +11,34 @@
     os.mkdir("tests/test_results")
 
 
+# Add test for unk tok
+
+
 class TestAmtTokenizer(unittest.TestCase):
     def test_tokenize(self):
-        def _tokenize_detokenize(mid_name: str):
-            START = 0
-            END = 30000
-
+        def _tokenize_detokenize(mid_name: str, start: int, end: int):
+            length = end - start
             tokenizer = AmtTokenizer()
             midi_dict = MidiDict.from_midi(f"tests/test_data/{mid_name}")
-            tokenized_seq = tokenizer._tokenize_midi_dict(
-                midi_dict=midi_dict,
-                start_ms=START,
-                end_ms=END,
-            )
-            logging.info(f"{mid_name} tokenized:")
-            logging.info(tokenized_seq)
-            _midi_dict = tokenizer._detokenize_midi_dict(
-                tokenized_seq, END - START
-            )
+            logging.info(f"tokenizing {mid_name} in range ({start}, {end})...")
+            tokenized_seq = tokenizer._tokenize_midi_dict(midi_dict, start, end)
+            tokenized_seq = tokenizer.decode(tokenizer.encode(tokenized_seq))
+            self.assertTrue(tokenizer.unk_tok not in tokenized_seq)
+            _midi_dict = tokenizer._detokenize_midi_dict(tokenized_seq, length)
             _mid = _midi_dict.to_midi()
-            _mid.save(f"tests/test_results/{mid_name}")
-            logging.info(f"{mid_name} note_msgs:")
-            for msg in _midi_dict.note_msgs:
-                logging.info(msg)
-
-        _tokenize_detokenize(mid_name="basic.mid")
-        _tokenize_detokenize(mid_name="147.mid")
_tokenize_detokenize(mid_name="beethoven_moonlight.mid") - _tokenize_detokenize(mid_name="maestro1.mid") - _tokenize_detokenize(mid_name="maestro2.mid") - _tokenize_detokenize(mid_name="maestro3.mid") + _mid.save(f"tests/test_results/{start}_{end}_{mid_name}") + + _tokenize_detokenize("basic.mid", start=0, end=30000) + _tokenize_detokenize("147.mid", start=0, end=30000) + _tokenize_detokenize("beethoven_moonlight.mid", start=0, end=30000) + + for _idx in range(5): + START = _idx * 25000 + END = (_idx + 1) * 25000 + _tokenize_detokenize("maestro1.mid", start=START, end=END) + _tokenize_detokenize("maestro2.mid", start=START, end=END) + _tokenize_detokenize("maestro3.mid", start=START, end=END) def test_aug(self): def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):