diff --git a/docs/tokenizations.rst b/docs/tokenizations.rst index 077afbd1..5a800314 100644 --- a/docs/tokenizations.rst +++ b/docs/tokenizations.rst @@ -91,6 +91,12 @@ MMM .. autoclass:: miditok.MMM :show-inheritance: +PerTok +------------------------ + +.. autoclass:: miditok.PerTok + :show-inheritance: + Create yours ------------------------ diff --git a/miditok/__init__.py b/miditok/__init__.py index 3de0180f..05013d2e 100644 --- a/miditok/__init__.py +++ b/miditok/__init__.py @@ -16,6 +16,7 @@ MIDILike, MuMIDI, Octuple, + PerTok, Structured, ) from .tokenizer_training_iterator import TokTrainingIterator @@ -34,6 +35,7 @@ "CPWord", "MuMIDI", "MMM", + "PerTok", "utils", "data_augmentation", ] diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py index 82c96255..34c6a2ee 100644 --- a/miditok/midi_tokenizer.py +++ b/miditok/midi_tokenizer.py @@ -491,23 +491,7 @@ def preprocess_score(self, score: Score) -> Score: new_tpq = self.config.max_num_pos_per_beat # Resample time if needed (not inplace) and attribute preprocessed time sig. - if score.ticks_per_quarter != new_tpq: - # Times of time signatures copy need to be resampled too - time_signatures_soa = time_signatures_copy.numpy() - time_signatures_soa["time"] = ( - time_signatures_soa["time"] * (new_tpq / score.ticks_per_quarter) - ).astype(np.int32) - - score = score.resample(new_tpq, min_dur=1) - score.time_signatures = TimeSignature.from_numpy( - **time_signatures_soa, - ) - # Otherwise we do a copy in order to make sure no inplace operation is performed - # on the provided Score object. - # We make a copy here instead of at beginning as resample also makes a copy. - else: - score = score.copy() - score.time_signatures = time_signatures_copy + score = self._resample_score(score, new_tpq, time_signatures_copy) # Merge instruments of the same program / inst before preprocessing them. # This allows to avoid potential duplicated notes in some multitrack settings @@ -531,7 +515,7 @@ def preprocess_score(self, score: Score) -> Score: if not self._note_on_off or ( self.config.use_sustain_pedals and self.config.sustain_pedal_duration ): - if self.config.use_time_signatures: + if self.config.use_time_signatures and len(score.time_signatures) > 0: ticks_per_beat = get_score_ticks_per_beat(score) else: ticks_per_beat = np.array([[score.end(), score.ticks_per_quarter]]) @@ -595,6 +579,29 @@ def preprocess_score(self, score: Score) -> Score: return score + def _resample_score( + self, score: Score, _new_tpq: int, _time_signatures_copy: TimeSignatureTickList + ) -> Score: + if score.ticks_per_quarter != _new_tpq: + # Times of time signatures copy need to be resampled too + time_signatures_soa = _time_signatures_copy.numpy() + time_signatures_soa["time"] = ( + time_signatures_soa["time"] * (_new_tpq / score.ticks_per_quarter) + ).astype(np.int32) + + score = score.resample(_new_tpq, min_dur=1) + score.time_signatures = TimeSignature.from_numpy( + **time_signatures_soa, + ) + # Otherwise we do a copy in order to make sure no inplace operation is performed + # on the provided Score object. + # We make a copy here instead of at beginning as resample also makes a copy. + else: + score = score.copy() + score.time_signatures = _time_signatures_copy + + return score + def _filter_unsupported_time_signatures( self, time_signatures: TimeSignatureTickList ) -> None: @@ -1158,7 +1165,7 @@ def _score_to_tokens( or self.config.use_chords or self.config.use_pitch_intervals ): - if self.config.use_time_signatures: + if self.config.use_time_signatures and len(score.time_signatures) > 0: ticks_per_beat = get_score_ticks_per_beat(score) else: ticks_per_beat = np.array([[score.end(), self.time_division]]) @@ -1483,24 +1490,31 @@ def _create_track_events( ) ) else: - # `while` as there might not be any note in the next section - while note.time >= ticks_per_beat[tpb_idx, 0]: - tpb_idx += 1 - dur = self._tpb_ticks_to_tokens[ticks_per_beat[tpb_idx, 1]][ - note.duration - ] events.append( - Event( - type_="Duration", - value=dur, - time=note.start, - program=program, - desc=f"{note.duration} ticks", + self._create_duration_event( + note=note, + _program=program, + _ticks_per_beat=ticks_per_beat, + _tpb_idx=tpb_idx, ) ) return events + def _create_duration_event( + self, note: Note, _program: int, _ticks_per_beat: np.ndarray, _tpb_idx: int + ) -> Event: + while note.time >= _ticks_per_beat[_tpb_idx, 0]: + _tpb_idx += 1 + dur = self._tpb_ticks_to_tokens[_ticks_per_beat[_tpb_idx, 1]][note.duration] + return Event( + type_="Duration", + value=dur, + time=note.start, + program=_program, + desc=f"{note.duration} ticks", + ) + @staticmethod def _insert_program_change_events(events: list[Event]) -> None: """ diff --git a/miditok/tokenizations/__init__.py b/miditok/tokenizations/__init__.py index fcb51145..5b998d38 100644 --- a/miditok/tokenizations/__init__.py +++ b/miditok/tokenizations/__init__.py @@ -11,6 +11,7 @@ from .mmm import MMM from .mumidi import MuMIDI from .octuple import Octuple +from .pertok import PerTok from .remi import REMI from .structured import Structured from .tsd import TSD @@ -24,4 +25,5 @@ "CPWord", "MuMIDI", "MMM", + "PerTok", ] diff --git a/miditok/tokenizations/pertok.py b/miditok/tokenizations/pertok.py new file mode 100644 index 00000000..bb3f9e93 --- /dev/null +++ b/miditok/tokenizations/pertok.py @@ -0,0 +1,574 @@ +"""PerTok tokenizer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from symusic import Note, Pedal, PitchBend, Score, Tempo, TimeSignature, Track + +from miditok.classes import Event, TokenizerConfig, TokSequence +from miditok.constants import DEFAULT_VELOCITY, MIDI_INSTRUMENTS, TIME_SIGNATURE +from miditok.midi_tokenizer import MusicTokenizer + +if TYPE_CHECKING: + from pathlib import Path + + from numpy.typing import NDArray + from symusic.core import TimeSignatureTickList + + +class PerTok(MusicTokenizer): + r""" + PerTok: Performance Tokenizer. + + Created by Lemonaide + https://www.lemonaide.ai/ + + Designed to capture the full spectrum of rhythmic values + (16ths, 32nds, various denominations of triplets/etc.) + in addition to velocity and microtiming performance characteristics. + It aims to achieve this while minimizing both vocabulary size and sequence length. + + Notes are encoded by 2-5 tokens: + + * TimeShift; + * Pitch; + * Velocity (optional); + * MicroTiming (optional); + * Duration (optional). + + *Timeshift* tokens are expressed as the nearest quantized value + based upon *beat_res* parameters. + The microtiming shift is then characterized as the remainder from + this quantized value. Timeshift and MicroTiming are represented + in the full ticks-per-quarter (tpq) resolution, e.g. 480 tpq. + + Additionally, *Bar* tokens are inserted at the start of each new measure. + This helps further reduce seq. length and potentially reduces the timing drift + models can develop at longer seq. lengths. + + New TokenizerConfig Options: + + * beat_res: now allows multiple, overlapping values; + * ticks_per_quarter: resolution of the MIDI timing data; + * use_microtiming: inclusion of MicroTiming tokens; + * max_microtiming_shift: float value of the farthest distance of MicroTiming shifts; + * num_microtiming_bins: total number of MicroTiming tokens. + + Example Tokenizer Config: + + .. code-block:: python + + TOKENIZER_PARAMS = { + "pitch_range": (21, 109), + "beat_res": {(0, 4): 4, (0, 4): 3}, + "special_tokens": ["PAD", "BOS", "EOS", "MASK"], + "use_chords": False, + "use_rests": False, + "use_tempos": False, + "use_time_signatures": True, + "use_programs": False, + "use_microtiming": True, + "ticks_per_quarter": 320, + "max_microtiming_shift": 0.125, + "num_microtiming_bins": 30, + } + config = TokenizerConfig(**TOKENIZER_PARAMS) + """ + + def __init__( + self, + tokenizer_config: TokenizerConfig = None, + params: str or Path or None = None, + ) -> None: + super().__init__(tokenizer_config, params) + if "ticks_per_quarter" not in self.config.additional_params: + msg = "Tokenizer config must have a value for ticks_per_quarter" + raise ValueError(msg) + + # Events which will use a "MicroTiming" token + self.microtime_events = [ + "Pitch", + "Pedal", + "PedalOff", + "PitchIntervalChord", + "PitchBend", + "Chord", + "PitchDrum", + "Program", + ] + + def _tweak_config_before_creating_voc(self) -> None: + self.tpq = self.config.additional_params["ticks_per_quarter"] + self.use_microtiming = self.config.additional_params["use_microtiming"] + if self.use_microtiming: + mt_keys = ["max_microtiming_shift", "num_microtiming_bins"] + if missing := set(mt_keys) - set(self.config.additional_params.keys()): + msg = f"TokenizerConfig is missing required keys: {', '.join(missing)}" + raise ValueError(msg) + + self.max_mt_shift = ( + self.config.additional_params["max_microtiming_shift"] * self.tpq + ) + + def _create_base_vocabulary(self) -> list[str]: + vocab = ["Bar_None"] + + # NoteOn/NoteOff/Velocity + self.timeshift_tick_values = self.create_timeshift_tick_values() + self._add_note_tokens_to_vocab_list(vocab) + + # TimeShift + vocab += [ + f"TimeShift_{self._duration_tuple_to_str(duration)}" + for duration in self.durations + ] + + # Duration + if any(self.config.use_note_duration_programs): + vocab += [ + f"Duration_{self._duration_tuple_to_str(duration)}" + for duration in self.durations + ] + + # Microtiming + if self.config.additional_params["use_microtiming"]: + mt_bins = self.config.additional_params["num_microtiming_bins"] + self.microtiming_tick_values = np.linspace( + -self.max_mt_shift, self.max_mt_shift, mt_bins + 1, dtype=np.intc + ) + + vocab += [ + f"MicroTiming_{microtiming!s}" + for microtiming in self.microtiming_tick_values + ] + + # Add additional tokens + self._add_additional_tokens_to_vocab_list(vocab) + + return list(dict.fromkeys(vocab)) + + # Methods to override base MusicTokenizer versions + # To handle MicroTiming and multiple beat_res resolutions + # This is accomplished by removing the downsampling methods + # As a result, many time-based methods need to be redesigned + + def _resample_score( + self, score: Score, _new_tpq: int, _time_signatures_copy: TimeSignatureTickList + ) -> Score: + if score.ticks_per_quarter != self.tpq: + score = score.resample(self.tpq, min_dur=1) + + return score + + def _adjust_durations( + self, notes_pedals_soa: dict[str, np.ndarray], ticks_per_beat: np.ndarray + ) -> None: + pass + + def _create_duration_event( + self, note: Note, _program: int, _ticks_per_beat: np.ndarray, _tpb_idx: int + ) -> Event: + duration_tuple = self._get_closest_duration_tuple(note.duration) + duration = ".".join(str(x) for x in duration_tuple) + + return Event( + type_="Duration", + value=duration, + time=note.start, + program=_program, + desc=f"duration {note.duration}", + ) + + def create_timeshift_tick_values(self) -> NDArray: + """ + Generate tick-based timeshift tokens. + + Returns + ------- + NDArray: Array of available timeshift values + + """ + tick_values = [0] + + for value in self.durations: + beat, subdiv, resolution = value + tick_value = int((beat + (subdiv / resolution)) * self.tpq) + tick_values.append(tick_value) + + return np.array(sorted(set(tick_values))) + + def _create_durations_tuples(self) -> list[tuple[int, int, int]]: + durations = [] + + for beat_range, resolution in self.config.beat_res.items(): + start, end = beat_range + for beat in range(start, end): + for subdiv in range(resolution): + if not (beat == 0 and subdiv == 0): + subres = (self.tpq // resolution * subdiv) if subdiv != 0 else 0 + durations.append((beat, subres, self.tpq)) + + self.min_timeshift = int( + min([(beat * res + subres) for beat, subres, res in durations]) * 0.5 + ) + + return durations + + # Utility Methods + def _get_closest_array_value( + self, value: int | float, array: NDArray + ) -> int | float: + return array[np.abs(array - value).argmin()] + + def _get_closest_duration_tuple(self, target: int) -> tuple[int, int, int]: + return min(self.durations, key=lambda x: abs((x[0] * x[-1] + x[1]) - target)) + + def _convert_durations_to_ticks(self, duration: str) -> int: + beats, subdiv, tpq = map(int, duration.split(".")) + return beats * tpq + subdiv + + def _duration_tuple_to_str(self, duration_tuple: tuple[int, int, int]) -> str: + return ".".join(str(x) for x in duration_tuple) + + def _add_time_events(self, events: list[Event], _time_division: int) -> list[Event]: + # Add time events + all_events = [] + previous_tick = 0 + ticks_per_bar = self.tpq * TIME_SIGNATURE[0] + curr_bar = 0 + + for event in events: + # Bar + bar_time = previous_tick + while event.time > ((curr_bar + 1) * ticks_per_bar - self.min_timeshift): + bar_time += ticks_per_bar - ( + bar_time % ticks_per_bar + ) # tpq=220, time=20, so add 200 to get to next bar + + all_events.append( + Event( + type_="Bar", value=None, time=bar_time, desc=f"Bar {bar_time}" + ) + ) + + curr_bar += 1 + previous_tick = curr_bar * ticks_per_bar + + # Time Signature + if event.type_ == "TimeSig": + num, den = self._parse_token_time_signature(event.value) + ticks_per_bar = den / 4 * num * self.tpq + + time_delta = event.time - previous_tick + timeshift = 0 + + # Time Shift + # Only should be placed before 'Pitch' events + if ( + time_delta >= self.min_timeshift + and event.type_ in self.microtime_events + ): + ts_tuple = self._get_closest_duration_tuple(time_delta) + ts = ".".join(str(x) for x in ts_tuple) + + all_events.append( + Event( + type_="TimeShift", + value=ts, + time=event.time, + desc=f"timeshift {ts}", + ) + ) + timeshift = ts_tuple[0] * ts_tuple[-1] + ts_tuple[1] + previous_tick += timeshift + + all_events.append(event) + + # Microtiming + # Right now hard-coded to come only after 'Pitch' tokens + # TODO: Check with Nathan on the PitchInterval logic below + if self.use_microtiming and event.type_ in self.microtime_events: + microtiming = time_delta - timeshift + closest_microtiming = int( + self._get_closest_array_value( + value=microtiming, array=self.microtiming_tick_values + ) + ) + all_events.append( + Event( + type_="MicroTiming", + value=closest_microtiming, + time=event.time, + desc=f"{closest_microtiming} microtiming", + ) + ) + return all_events + + def _tokens_to_score( + self, + tokens: TokSequence | list[TokSequence], + programs: list[tuple[int, bool]] | None = None, + ) -> Score: + r""" + Convert tokens (:class:`miditok.TokSequence`) into a ``symusic.Score``. + + This is an internal method called by ``self.decode``, intended to be + implemented by classes inheriting :class:`miditok.MusicTokenizer`. + + :param tokens: tokens to convert. Can be either a list of + :class:`miditok.TokSequence` or a list of :class:`miditok.TokSequence`s. + :param programs: programs of the tracks. If none is given, will default to + piano, program 0. (default: ``None``) + :return: the ``symusic.Score`` object. + """ + # Unsqueeze tokens in case of one_token_stream + if self.config.one_token_stream_for_programs: # ie single token seq + tokens = [tokens] + for i in range(len(tokens)): + tokens[i] = tokens[i].tokens + score = Score(self.tpq) + + mt_offset = 1 if self.use_microtiming else 0 + vel_offset = (mt_offset + 1) if self.config.use_velocities else mt_offset + dur_offset = vel_offset + 1 + + # RESULTS + tracks: dict[int, Track] = {} + tempo_changes, time_signature_changes = [], [] + + def check_inst(prog: int) -> None: + if prog not in tracks: + tracks[prog] = Track( + program=0 if prog == -1 else prog, + is_drum=prog == -1, + name="Drums" if prog == -1 else MIDI_INSTRUMENTS[prog]["name"], + ) + + def is_track_empty(track: Track) -> bool: + return ( + len(track.notes) == len(track.controls) == len(track.pitch_bends) == 0 + ) + + current_track = None # used only when one_token_stream is False + ticks_per_beat = score.ticks_per_quarter + for si, seq in enumerate(tokens): + # Set tracking variables + current_tick = 0 + curr_bar = 0 + current_program = 0 + previous_note_end = 0 + previous_pitch_onset = {prog: -128 for prog in self.config.programs} + previous_pitch_chord = {prog: -128 for prog in self.config.programs} + active_pedals = {} + ticks_per_bar = ticks_per_beat * TIME_SIGNATURE[0] + + # Set track / sequence program if needed + if not self.config.one_token_stream_for_programs: + is_drum = False + if programs is not None: + current_program, is_drum = programs[si] + elif self.config.use_programs: + for token in seq: + tok_type, tok_val = token.split("_") + if tok_type.startswith("Program"): + current_program = int(tok_val) + if current_program == -1: + is_drum, current_program = True, 0 + break + current_track = Track( + program=current_program, + is_drum=is_drum, + name="Drums" + if current_program == -1 + else MIDI_INSTRUMENTS[current_program]["name"], + ) + current_track_use_duration = ( + current_program in self.config.use_note_duration_programs + ) + + # Decode tokens + for ti, token in enumerate(seq): + tok_type, tok_val = token.split("_") + + if tok_type == "Bar": + curr_bar += 1 + current_tick += (ticks_per_bar * curr_bar) - current_tick + elif tok_type == "TimeShift": + current_tick += self._convert_durations_to_ticks(tok_val) + elif tok_type in [ + "Pitch", + "PitchDrum", + "PitchIntervalTime", + "PitchIntervalChord", + ]: + if tok_type in {"Pitch", "PitchDrum"}: + pitch = int(tok_val) + elif tok_type == "PitchIntervalTime": + pitch = previous_pitch_onset[current_program] + int(tok_val) + else: # PitchIntervalChord + pitch = previous_pitch_chord[current_program] + int(tok_val) + if ( + not self.config.pitch_range[0] + <= pitch + <= self.config.pitch_range[1] + ): + continue + + # We update previous_pitch_onset and previous_pitch_chord even if + # the try fails. + if tok_type != "PitchIntervalChord": + previous_pitch_onset[current_program] = pitch + previous_pitch_chord[current_program] = pitch + + try: + if self.use_microtiming: + mt_type, mt = seq[ti + mt_offset].split("_") + mt = int(mt) + else: + mt_type, mt = "MicroTiming", 0 + if self.config.use_velocities: + vel_type, vel = seq[ti + vel_offset].split("_") + else: + vel_type, vel = "Velocity", DEFAULT_VELOCITY + if current_track_use_duration: + dur_type, dur = seq[ti + dur_offset].split("_") + else: + dur_type = "Duration" + dur = int( + self.config.default_note_duration * ticks_per_beat + ) + if ( + mt_type == "MicroTiming" + and vel_type == "Velocity" + and dur_type == "Duration" + ): + if isinstance(dur, str): + dur = self._convert_durations_to_ticks(dur) + # dur = self._tpb_tokens_to_ticks[ticks_per_beat][dur] + mt += current_tick + new_note = Note(int(mt), dur, pitch, int(vel)) + if self.config.one_token_stream_for_programs: + check_inst(current_program) + tracks[current_program].notes.append(new_note) + else: + current_track.notes.append(new_note) + previous_note_end = max(previous_note_end, mt + dur) + except IndexError: + # A well constituted sequence should not raise an exception + # However with generated sequences this can happen, or if the + # sequence isn't finished + pass + elif tok_type == "Program": + current_program = int(tok_val) + current_track_use_duration = ( + current_program in self.config.use_note_duration_programs + ) + if ( + not self.config.one_token_stream_for_programs + and self.config.program_changes + ): + if current_program != -1: + current_track.program = current_program + else: + current_track.program = 0 + current_track.is_drum = True + elif tok_type == "Tempo" and si == 0: + tempo_changes.append(Tempo(current_tick, float(tok_val))) + elif tok_type == "TimeSig": + num, den = self._parse_token_time_signature(tok_val) + ticks_per_bar = den / 4 * num * ticks_per_beat + if si == 0: + time_signature_changes.append( + TimeSignature(int(current_tick), num, den) + ) + + elif tok_type == "Pedal": + pedal_prog = ( + int(tok_val) if self.config.use_programs else current_program + ) + if self.config.sustain_pedal_duration and ti + 1 < len(seq): + if seq[ti + 1].split("_")[0] == "Duration": + duration = self._tpb_tokens_to_ticks[ticks_per_beat][ + seq[ti + 1].split("_")[1] + ] + # Add instrument if it doesn't exist, can happen for the + # first tokens + new_pedal = Pedal(current_tick, duration) + if self.config.one_token_stream_for_programs: + check_inst(pedal_prog) + tracks[pedal_prog].pedals.append(new_pedal) + else: + current_track.pedals.append(new_pedal) + elif pedal_prog not in active_pedals: + active_pedals[pedal_prog] = current_tick + elif tok_type == "PedalOff": + pedal_prog = ( + int(tok_val) if self.config.use_programs else current_program + ) + if pedal_prog in active_pedals: + new_pedal = Pedal( + active_pedals[pedal_prog], + current_tick - active_pedals[pedal_prog], + ) + if self.config.one_token_stream_for_programs: + check_inst(pedal_prog) + tracks[pedal_prog].pedals.append( + Pedal( + active_pedals[pedal_prog], + current_tick - active_pedals[pedal_prog], + ) + ) + else: + current_track.pedals.append(new_pedal) + del active_pedals[pedal_prog] + elif tok_type == "PitchBend": + new_pitch_bend = PitchBend(current_tick, int(tok_val)) + if self.config.one_token_stream_for_programs: + check_inst(current_program) + tracks[current_program].pitch_bends.append(new_pitch_bend) + else: + current_track.pitch_bends.append(new_pitch_bend) + + if tok_type in [ + "Program", + "Tempo", + "TimeSig", + "Pedal", + "PedalOff", + "PitchBend", + "Chord", + ]: + previous_note_end = max(previous_note_end, current_tick) + + # Add current_inst to the score and handle notes still active + if not self.config.one_token_stream_for_programs and not is_track_empty( + current_track + ): + score.tracks.append(current_track) + + # Add global events to the score + if self.config.one_token_stream_for_programs: + score.tracks = list(tracks.values()) + score.tempos = tempo_changes + if time_signature_changes is None: + num, den = TIME_SIGNATURE + time_signature_changes.append(TimeSignature(0, num, den)) + score.time_signatures = time_signature_changes + + return score + + def _tokens_errors(self, _tokens: list[str | list[str]]) -> int: + return 0 + + def _create_token_types_graph(self) -> dict[str, set[str]]: + r""" + Return a graph/dictionary of the possible token types successions. + + :return: the token types transitions dictionary. + """ + # Bar, TimeSig, TimeShift, Pitch, MT, Velocity, Duration + + dic: dict[str, set[str]] = {} + return dic diff --git a/tests/test_saving_loading_config.py b/tests/test_saving_loading_config.py index 3d069d4b..8be74119 100644 --- a/tests/test_saving_loading_config.py +++ b/tests/test_saving_loading_config.py @@ -24,6 +24,10 @@ "num_tempos": 32, "tempo_range": (40, 250), "base_tokenizer": "TSD", + "use_microtiming": True, + "ticks_per_quarter": 480, + "max_microtiming_shift": 0.25, + "num_microtiming_bins": 110, } TOK_PARAMS_MULTITRACK = [] @@ -41,6 +45,11 @@ params_["base_tokenizer"] = "TSD" elif tokenization_ in ["Octuple", "MuMIDI"]: params_["max_bar_embedding"] = MAX_BAR_EMBEDDING + elif tokenization_ in ["PerTok"]: + params_["use_microtiming"] = True + params_["ticks_per_quarter"] = 220 + params_["max_microtiming_shift"] = 0.25 + params_["num_microtiming_bins"] = 110 TOK_PARAMS_MULTITRACK.append((tokenization_, params_)) if tokenization_ in tokenizations_non_one_stream: diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 9a4e8f87..3ea840d9 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from symusic.core import TempoTickList + from symusic.core import NoteTickList, TempoTickList from miditok import MusicTokenizer, TokSequence @@ -125,6 +125,26 @@ def adjust_tok_params_for_tests(tokenization: str, params: dict[str, Any]) -> No and params.get("use_rests", False) ): params["use_rests"] = False + # PerTok needs a number of separate args + elif tokenization == "PerTok": + params["beat_res"] = {(0, 128): 4, (0, 32): 3} + params["use_microtiming"] = True + params["ticks_per_quarter"] = 220 + + # params["beat_res"] = {(0, 16): 4} + # params["use_microtiming"] = False + # params["ticks_per_quarter"] = 16 + + params["max_microtiming_shift"] = 0.25 + params["num_microtiming_bins"] = 110 + params["use_rests"] = False + params["use_sustain_pedals"] = False + params["use_bar_end_tokens"] = False + params["use_tempos"] = False + params["use_time_signatures"] = True + params["use_pitch_bends"] = False + params["use_pitchdrum_tokens"] = False + params["use_pitch_intervals"] = False def sort_score(score: Score, sort_tracks: bool = True) -> None: @@ -271,6 +291,7 @@ def scores_notes_equals( score2: Score, check_velocities: bool, use_note_duration_programs: Sequence[int], + use_time_range: bool = False, ) -> list[tuple[int, str, list[tuple[str, Note | int, int]]]]: """ Check that the notes from two Scores are all equal. @@ -289,7 +310,7 @@ def scores_notes_equals( errors = [] for track1, track2 in zip(score1.tracks, score2.tracks): if track1.program != track2.program or track1.is_drum != track2.is_drum: - errors.append((0, "program", [])) + errors.append((0, "program", [track1.program, track2.program])) continue if len(track1.notes) != len(track2.notes): errors.append((0, "num notes", [])) @@ -305,7 +326,12 @@ def scores_notes_equals( track1.notes = Note.from_numpy(**notes2) track1.notes.sort(key=lambda n: (n.time, n.pitch, n.duration, n.velocity)) track_errors = tracks_notes_equals( - track1, track2, check_velocities, using_note_durations + track1, + track2, + check_velocities, + using_note_durations, + use_time_range, + max_time_range=int(score1.ticks_per_quarter * 0.5), ) if len(track_errors) > 0: errors.append((track1.program, track1.name, track_errors)) @@ -317,17 +343,84 @@ def tracks_notes_equals( track2: Track, check_velocities: bool = True, check_durations: bool = True, + use_time_range: bool = False, + max_time_range: int = 220, +) -> list[tuple[str, Note | int, int]]: + if not use_time_range: + errors = [] + for note1, note2 in zip(track1.notes, track2.notes): + err = notes_equals( + note1, + note2, + check_velocities, + check_durations, + ) + if err != "": + errors.append((err, note2, getattr(note1, err))) + return errors + # Sliding window search of nearby notes for hires tokenizers + return notes_in_sliding_window_equals( + track1.notes, + track2.notes, + check_velocities=check_velocities, + check_durations=check_durations, + max_time_range=max_time_range, + ) + + +def notes_in_sliding_window_equals( + notes_1: NoteTickList, + notes_2: NoteTickList, + check_velocities: bool = True, + check_durations: bool = True, + max_time_range: int = 120, ) -> list[tuple[str, Note | int, int]]: errors = [] - for note1, note2 in zip(track1.notes, track2.notes): - err = notes_equals(note1, note2, check_velocities, check_durations) - if err != "": - errors.append((err, note2, getattr(note1, err))) + for idx, note_1 in enumerate(notes_1): + potential_notes = get_notes_in_range(idx=idx, note_list=notes_2, window_size=25) + potential_notes = [ + note for note in potential_notes if note.pitch == note_1.pitch + ] + if potential_notes is None: + errors.append(("pitch", notes_2[idx], note_1.pitch)) + continue + + if not any( + (abs(note_1.start - note_2.start) < max_time_range) + for note_2 in potential_notes + ): + errors.append(("start", notes_2[idx], note_1.start)) + continue + + if check_durations and not any( + (abs(note_1.end - note_2.end) < max_time_range) + for note_2 in potential_notes + ): + errors.append(("end", notes_2[idx], note_1.end)) + continue + + if check_velocities and not any( + (note_1.velocity == note_2.velocity) for note_2 in potential_notes + ): + errors.append(("velocity", notes_2[idx], note_1.velocity)) + continue + return errors +def get_notes_in_range( + idx: int, note_list: NoteTickList, window_size: int = 5 +) -> NoteTickList: + start = max(0, idx - window_size) + end = min(len(note_list) - 1, idx + window_size) + return note_list[start : end + 1] + + def notes_equals( - note1: Note, note2: Note, check_velocity: bool = True, check_duration: bool = True + note1: Note, + note2: Note, + check_velocity: bool = True, + check_duration: bool = True, ) -> str: if note1.start != note2.start: return "start" @@ -363,14 +456,21 @@ def check_scores_equals( check_pedals: bool = True, check_pitch_bends: bool = True, log_prefix: str = "", + use_time_ranges: bool = False, + max_time_range: int = 120, ) -> bool: has_errors = False types_of_errors = [] # Checks notes and add markers if errors errors = scores_notes_equals( - score1, score2, check_velocities, use_note_duration_programs + score1, + score2, + check_velocities, + use_note_duration_programs, + use_time_ranges, ) + if len(errors) > 0: has_errors = True for e, track_err in enumerate(errors): @@ -389,9 +489,16 @@ def check_scores_equals( # Check pedals if check_pedals: for inst1, inst2 in zip(score1.tracks, score2.tracks): - if inst1.pedals != inst2.pedals: + if not use_time_ranges and inst1.pedals != inst2.pedals: types_of_errors.append("PEDALS") break + inst1_pedals, inst2_pedals = inst1.pedals, inst2.pedals + for pedal_0, pedal_1 in zip(inst1_pedals, inst2_pedals): + if (pedal_0.time - pedal_1.time) > max_time_range or ( + pedal_0.duration - pedal_1.duration + ) > max_time_range: + types_of_errors.append("PEDALS") + break # Check pitch bends if check_pitch_bends: @@ -412,8 +519,15 @@ def check_scores_equals( types_of_errors.append("TEMPOS") # Checks time signatures - if check_time_signatures and score1.time_signatures != score2.time_signatures: - types_of_errors.append("TIME SIGNATURES") + if check_time_signatures: + if not use_time_ranges and score1.time_signatures != score2.time_signatures: + types_of_errors.append("TIME SIGNATURES") + elif use_time_ranges: + time_sigs1, time_sigs2 = score1.time_signatures, score2.time_signatures + for time_sig1, time_sig2 in zip(time_sigs1, time_sigs2): + if abs(time_sig1.time - time_sig2.time) > max_time_range: + types_of_errors.append("TIME SIGNATURES") + break # Prints types of errors has_errors = has_errors or len(types_of_errors) > 0 @@ -430,6 +544,7 @@ def tokenize_and_check_equals( ) -> tuple[Score, Score, bool]: tokenization = type(tokenizer).__name__ log_prefix = f"{file_name} / {tokenization}" + use_time_ranges = bool(tokenization in ["PerTok"]) # Tokenize and detokenize adapt_ref_score_before_tokenize(score, tokenizer) @@ -443,6 +558,9 @@ def tokenize_and_check_equals( score = adapt_ref_score_for_tests_assertion(score, tokenizer) if score.ticks_per_quarter != score_decoded.ticks_per_quarter: score = score.resample(tpq=score_decoded.ticks_per_quarter) + # if not use_time_ranges: + # sort_score(score) + # sort_score(score_decoded) sort_score(score) sort_score(score_decoded) @@ -457,6 +575,7 @@ def tokenize_and_check_equals( check_pedals=tokenizer.config.use_sustain_pedals, check_pitch_bends=tokenizer.config.use_pitch_bends, log_prefix=log_prefix, + use_time_ranges=use_time_ranges, ) # Checks types and values conformity following the rules