Skip to content

Commit

Permalink
PerTok Tokenizer (#191)
Browse files Browse the repository at this point in the history
* init pertok

* linting

* Added microtiming to TokenizerConfig

* pertok updates

* duration + timeshift values updated

* first working encode/decode

* remove prints and commented out functions

* moved most microtiming configs to additional_params

* Removed res_microtiming and added a _resample_score fn

* linting

* linting

* ruff format

* Documentation + comments

* remove duplicate lines of code

* All tests passing

* Update tests/utils_tests.py

* Update miditok/tokenizations/pertok.py

* docs and tests fixes

---------

Co-authored-by: Nathan Fradet <[email protected]>
  • Loading branch information
JLenzy and Natooz authored Sep 10, 2024
1 parent 2281c4f commit 967209a
Show file tree
Hide file tree
Showing 7 changed files with 769 additions and 43 deletions.
6 changes: 6 additions & 0 deletions docs/tokenizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ MMM
.. autoclass:: miditok.MMM
:show-inheritance:

PerTok
------------------------

.. autoclass:: miditok.PerTok
:show-inheritance:


Create yours
------------------------
Expand Down
2 changes: 2 additions & 0 deletions miditok/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MIDILike,
MuMIDI,
Octuple,
PerTok,
Structured,
)
from .tokenizer_training_iterator import TokTrainingIterator
Expand All @@ -34,6 +35,7 @@
"CPWord",
"MuMIDI",
"MMM",
"PerTok",
"utils",
"data_augmentation",
]
Expand Down
76 changes: 45 additions & 31 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,23 +491,7 @@ def preprocess_score(self, score: Score) -> Score:
new_tpq = self.config.max_num_pos_per_beat

# Resample time if needed (not inplace) and attribute preprocessed time sig.
if score.ticks_per_quarter != new_tpq:
# Times of time signatures copy need to be resampled too
time_signatures_soa = time_signatures_copy.numpy()
time_signatures_soa["time"] = (
time_signatures_soa["time"] * (new_tpq / score.ticks_per_quarter)
).astype(np.int32)

score = score.resample(new_tpq, min_dur=1)
score.time_signatures = TimeSignature.from_numpy(
**time_signatures_soa,
)
# Otherwise we do a copy in order to make sure no inplace operation is performed
# on the provided Score object.
# We make a copy here instead of at beginning as resample also makes a copy.
else:
score = score.copy()
score.time_signatures = time_signatures_copy
score = self._resample_score(score, new_tpq, time_signatures_copy)

# Merge instruments of the same program / inst before preprocessing them.
# This avoids potential duplicated notes in some multitrack settings
Expand All @@ -531,7 +515,7 @@ def preprocess_score(self, score: Score) -> Score:
if not self._note_on_off or (
self.config.use_sustain_pedals and self.config.sustain_pedal_duration
):
if self.config.use_time_signatures:
if self.config.use_time_signatures and len(score.time_signatures) > 0:
ticks_per_beat = get_score_ticks_per_beat(score)
else:
ticks_per_beat = np.array([[score.end(), score.ticks_per_quarter]])
Expand Down Expand Up @@ -595,6 +579,29 @@ def preprocess_score(self, score: Score) -> Score:

return score

def _resample_score(
    self, score: Score, _new_tpq: int, _time_signatures_copy: TimeSignatureTickList
) -> Score:
    """
    Return a new ``Score`` expressed with ``_new_tpq`` ticks per quarter.

    The provided ``score`` is never modified in place: it is either resampled
    (resampling also makes a copy) or explicitly copied. In both cases the
    returned score gets its time signatures from ``_time_signatures_copy``.

    :param score: score to resample.
    :param _new_tpq: target number of ticks per quarter note.
    :param _time_signatures_copy: time signatures to attribute to the returned
        score. Their times are rescaled when the score is resampled.
    :return: the resampled score, or a plain copy if ``score`` already has
        ``_new_tpq`` ticks per quarter.
    """
    # Same resolution: a copy is enough. It is made here rather than earlier
    # because the resampling path below already produces a new object.
    if score.ticks_per_quarter == _new_tpq:
        score_copy = score.copy()
        score_copy.time_signatures = _time_signatures_copy
        return score_copy

    # The times of the time signatures copy have to follow the new resolution
    # too, so they are scaled by the same ratio (truncated to int32).
    ts_arrays = _time_signatures_copy.numpy()
    tpq_ratio = _new_tpq / score.ticks_per_quarter
    ts_arrays["time"] = (ts_arrays["time"] * tpq_ratio).astype(np.int32)

    resampled = score.resample(_new_tpq, min_dur=1)
    resampled.time_signatures = TimeSignature.from_numpy(**ts_arrays)
    return resampled

def _filter_unsupported_time_signatures(
self, time_signatures: TimeSignatureTickList
) -> None:
Expand Down Expand Up @@ -1158,7 +1165,7 @@ def _score_to_tokens(
or self.config.use_chords
or self.config.use_pitch_intervals
):
if self.config.use_time_signatures:
if self.config.use_time_signatures and len(score.time_signatures) > 0:
ticks_per_beat = get_score_ticks_per_beat(score)
else:
ticks_per_beat = np.array([[score.end(), self.time_division]])
Expand Down Expand Up @@ -1483,24 +1490,31 @@ def _create_track_events(
)
)
else:
# `while` as there might not be any note in the next section
while note.time >= ticks_per_beat[tpb_idx, 0]:
tpb_idx += 1
dur = self._tpb_ticks_to_tokens[ticks_per_beat[tpb_idx, 1]][
note.duration
]
events.append(
Event(
type_="Duration",
value=dur,
time=note.start,
program=program,
desc=f"{note.duration} ticks",
self._create_duration_event(
note=note,
_program=program,
_ticks_per_beat=ticks_per_beat,
_tpb_idx=tpb_idx,
)
)

return events

def _create_duration_event(
    self, note: Note, _program: int, _ticks_per_beat: np.ndarray, _tpb_idx: int
) -> Event:
    """
    Build the ``Duration`` event associated with a note.

    :param note: note for which the duration event is created.
    :param _program: program to attach to the event.
    :param _ticks_per_beat: two-column array; the first column is compared to
        the note time to select a row, the second column of that row is the
        key used to look up the duration tokens in
        ``self._tpb_ticks_to_tokens``.
    :param _tpb_idx: row of ``_ticks_per_beat`` from which the search starts.
        The index is only advanced locally; the caller's value is unchanged.
    :return: the ``Duration`` event, placed at the start time of the note.
    """
    # `while` and not `if`: several rows may have to be skipped before
    # reaching the one covering the time of the note.
    row = _tpb_idx
    while note.time >= _ticks_per_beat[row, 0]:
        row += 1

    ticks_per_beat = _ticks_per_beat[row, 1]
    duration_token = self._tpb_ticks_to_tokens[ticks_per_beat][note.duration]
    return Event(
        type_="Duration",
        value=duration_token,
        time=note.start,
        program=_program,
        desc=f"{note.duration} ticks",
    )

@staticmethod
def _insert_program_change_events(events: list[Event]) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions miditok/tokenizations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .mmm import MMM
from .mumidi import MuMIDI
from .octuple import Octuple
from .pertok import PerTok
from .remi import REMI
from .structured import Structured
from .tsd import TSD
Expand All @@ -24,4 +25,5 @@
"CPWord",
"MuMIDI",
"MMM",
"PerTok",
]
Loading

0 comments on commit 967209a

Please sign in to comment.