diff --git a/miditok/constants.py b/miditok/constants.py index 317f4e51..8ef86cb3 100644 --- a/miditok/constants.py +++ b/miditok/constants.py @@ -2,7 +2,7 @@ """ -CURRENT_VERSION_PACKAGE = "2.1.1" # used when saving the config of a tokenizer +CURRENT_VERSION_PACKAGE = "2.1.2" # used when saving the config of a tokenizer MIDI_FILES_EXTENSIONS = [".mid", ".midi", ".MID", ".MIDI"] diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py index 15525950..e21e4590 100644 --- a/miditok/midi_tokenizer.py +++ b/miditok/midi_tokenizer.py @@ -33,8 +33,8 @@ def _in_as_seq(complete: bool = True, decode_bpe: bool = True): - """Decorator creating if necessary and completing a TokSequence object before that the function is called. - This decorator is used by the :py:meth:`miditok.MIDITokenizer.track_to_tokens` method. + r"""Decorator creating if necessary and completing a TokSequence object before that the function is called. + This decorator is made to be used by the :py:meth:`miditok.MIDITokenizer.tokens_to_midi` method. :param complete: will complete the sequence, i.e. complete its ``ids`` , ``tokens`` and ``events`` . :param decode_bpe: will decode BPE, if applicable. This step is performed before completing the sequence. @@ -42,7 +42,7 @@ def _in_as_seq(complete: bool = True, decode_bpe: bool = True): def decorator(function: Callable = None): def wrapper(*args, **kwargs): - self = args[0] + tokenizer = args[0] seq = args[1] if not isinstance(seq, TokSequence) and not all( isinstance(seq_, TokSequence) for seq_ in seq @@ -57,25 +57,38 @@ def wrapper(*args, **kwargs): else: # list of Event, very unlikely arg = ("events", seq) + # Deduce nb of subscript, if tokenizer is multi-voc or unique_track + nb_subscripts = nb_real_subscripts = 1 + if not tokenizer.unique_track: + nb_subscripts += 1 + if tokenizer.is_multi_voc: + nb_subscripts += 1 if isinstance(arg[1][0], list): + nb_real_subscripts += 1 + if isinstance(arg[1][0][0], list): + nb_real_subscripts += 1 + + if not tokenizer.unique_track and nb_subscripts == nb_real_subscripts: seq = [] for obj in arg[1]: kwarg = {arg[0]: obj} seq.append(TokSequence(**kwarg)) - seq[-1].ids_bpe_encoded = self._are_ids_bpe_encoded(seq[-1].ids) - else: + if not tokenizer.is_multi_voc: + seq[-1].ids_bpe_encoded = tokenizer._are_ids_bpe_encoded(seq[-1].ids) + else: # 1 subscript, unique_track and no multi-voc kwarg = {arg[0]: arg[1]} seq = TokSequence(**kwarg) - seq.ids_bpe_encoded = self._are_ids_bpe_encoded(seq.ids) + if not tokenizer.is_multi_voc: + seq.ids_bpe_encoded = tokenizer._are_ids_bpe_encoded(seq.ids) - if self.has_bpe and decode_bpe: - self.decode_bpe(seq) + if tokenizer.has_bpe and decode_bpe: + tokenizer.decode_bpe(seq) if complete: if isinstance(seq, TokSequence): - self.complete_sequence(seq) + tokenizer.complete_sequence(seq) else: for seq_ in seq: - self.complete_sequence(seq_) + tokenizer.complete_sequence(seq_) args = list(args) args[1] = seq @@ -589,6 +602,7 @@ def _bytes_to_tokens( tokens = [tok for toks in tokens for tok in toks] # flatten return tokens + @_in_as_seq() def tokens_to_midi( self, tokens: Union[TokSequence, List, np.ndarray, Any], @@ -639,7 +653,7 @@ def tokens_to_midi( @abstractmethod def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: @@ -1291,7 +1305,7 @@ def tokenize_midi_dataset( @_in_as_seq(complete=False, decode_bpe=False) def tokens_errors( self, tokens: Union[TokSequence, List[Union[int, List[int]]]] - ) -> float: + ) -> Union[float, List[float]]: r"""Checks if a sequence of tokens is made of good token types successions and returns the error ratio (lower is better). The common implementation in MIDITokenizer class will check token types, @@ -1303,6 +1317,10 @@ def tokens_errors( :param tokens: sequence of tokens to check. :return: the error ratio (lower is better). """ + # If list of TokSequence -> recursive + if isinstance(tokens, list): + return [self.tokens_errors(tok_seq) for tok_seq in tokens] + nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: self.decode_bpe(tokens) diff --git a/miditok/tokenizations/cp_word.py b/miditok/tokenizations/cp_word.py index 55e1eed4..58f98f17 100644 --- a/miditok/tokenizations/cp_word.py +++ b/miditok/tokenizations/cp_word.py @@ -287,7 +287,7 @@ def __create_cp_token( def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: @@ -464,7 +464,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: return dic @_in_as_seq() - def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float: + def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]: r"""Checks if a sequence of tokens is made of good token types successions and returns the error ratio (lower is better). The Pitch and Position values are also analyzed: @@ -474,6 +474,9 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl :param tokens: sequence of tokens to check :return: the error ratio (lower is better) """ + # If list of TokSequence -> recursive + if isinstance(tokens, list): + return [self.tokens_errors(tok_seq) for tok_seq in tokens] def cp_token_type(tok: List[int]) -> List[str]: family = self[0, tok[0]].split("_")[1] diff --git a/miditok/tokenizations/midi_like.py b/miditok/tokenizations/midi_like.py index 79457e18..dd36220e 100644 --- a/miditok/tokenizations/midi_like.py +++ b/miditok/tokenizations/midi_like.py @@ -180,7 +180,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), default_duration: int = None, @@ -350,7 +350,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: return dic @_in_as_seq(complete=False, decode_bpe=False) - def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float: + def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]: r"""Checks if a sequence of tokens is made of good token types successions and returns the error ratio (lower is better). The Pitch and Position values are also analyzed: @@ -360,6 +360,10 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl :param tokens: sequence of tokens to check :return: the error ratio (lower is better) """ + # If list of TokSequence -> recursive + if isinstance(tokens, list): + return [self.tokens_errors(tok_seq) for tok_seq in tokens] + nb_tok_predicted = len(tokens) # used to norm the score if self.has_bpe: self.decode_bpe(tokens) diff --git a/miditok/tokenizations/mmm.py b/miditok/tokenizations/mmm.py index 64b9edb4..d6d218d2 100644 --- a/miditok/tokenizations/mmm.py +++ b/miditok/tokenizations/mmm.py @@ -216,7 +216,7 @@ def track_to_tokens(self, track: Instrument) -> List[Event]: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> None: diff --git a/miditok/tokenizations/mumidi.py b/miditok/tokenizations/mumidi.py index 48f3df21..c3c7f7f1 100644 --- a/miditok/tokenizations/mumidi.py +++ b/miditok/tokenizations/mumidi.py @@ -368,7 +368,7 @@ def tokens_to_midi( def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ): diff --git a/miditok/tokenizations/octuple.py b/miditok/tokenizations/octuple.py index 48197015..5a40036e 100644 --- a/miditok/tokenizations/octuple.py +++ b/miditok/tokenizations/octuple.py @@ -359,7 +359,7 @@ def tokens_to_midi( def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: diff --git a/miditok/tokenizations/octuple_mono.py b/miditok/tokenizations/octuple_mono.py index d898ec7e..e74ffef8 100644 --- a/miditok/tokenizations/octuple_mono.py +++ b/miditok/tokenizations/octuple_mono.py @@ -132,7 +132,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: @@ -250,7 +250,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]: return {} # not relevant for this encoding @_in_as_seq() - def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float: + def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]: r"""Checks if a sequence of tokens is made of good token values and returns the error ratio (lower is better). The token types are always the same in Octuple so this method only checks @@ -262,6 +262,10 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl :param tokens: sequence of tokens to check :return: the error ratio (lower is better) """ + # If list of TokSequence -> recursive + if isinstance(tokens, list): + return [self.tokens_errors(tok_seq) for tok_seq in tokens] + err = 0 current_bar = current_pos = -1 current_pitches = [] diff --git a/miditok/tokenizations/remi.py b/miditok/tokenizations/remi.py index c95d929c..98bdc7ad 100644 --- a/miditok/tokenizations/remi.py +++ b/miditok/tokenizations/remi.py @@ -1,9 +1,9 @@ -from typing import List, Tuple, Dict, Optional, Union, Any +from typing import List, Tuple, Dict, Optional import numpy as np from miditoolkit import Instrument, Note, TempoChange -from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq +from ..midi_tokenizer import MIDITokenizer, _out_as_complete_seq from ..classes import TokSequence, Event from ..utils import detect_chords from ..constants import ( @@ -204,7 +204,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: diff --git a/miditok/tokenizations/remi_plus.py b/miditok/tokenizations/remi_plus.py index e7c4102b..6777e0ab 100644 --- a/miditok/tokenizations/remi_plus.py +++ b/miditok/tokenizations/remi_plus.py @@ -292,7 +292,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> None: diff --git a/miditok/tokenizations/structured.py b/miditok/tokenizations/structured.py index 1c50ce3f..0f2cf09b 100644 --- a/miditok/tokenizations/structured.py +++ b/miditok/tokenizations/structured.py @@ -138,7 +138,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence: def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: diff --git a/miditok/tokenizations/tsd.py b/miditok/tokenizations/tsd.py index 8ced93f7..1724cb42 100644 --- a/miditok/tokenizations/tsd.py +++ b/miditok/tokenizations/tsd.py @@ -250,7 +250,7 @@ def _midi_to_tokens( def tokens_to_track( self, - tokens: Union[TokSequence, List, np.ndarray, Any], + tokens: TokSequence, time_division: Optional[int] = TIME_DIVISION, program: Optional[Tuple[int, bool]] = (0, False), ) -> Tuple[Instrument, List[TempoChange]]: diff --git a/setup.py b/setup.py index 28d32a3a..29f5fa45 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ author="Nathan Fradet", url="https://github.com/Natooz/MidiTok", packages=find_packages(exclude=("tests",)), - version="2.1.1", + version="2.1.2", license="MIT", description="A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies", long_description=long_description,