Skip to content

Commit

Permalink
fix in in_as_seq, tokens_error handle list of TokSeq, tokens_to_track…
Browse files Browse the repository at this point in the history
… only taking TokSequence input hint
  • Loading branch information
Natooz committed Jul 20, 2023
1 parent 394dc4d commit 089fa74
Show file tree
Hide file tree
Showing 13 changed files with 58 additions and 29 deletions.
2 changes: 1 addition & 1 deletion miditok/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

CURRENT_VERSION_PACKAGE = "2.1.1" # used when saving the config of a tokenizer
CURRENT_VERSION_PACKAGE = "2.1.2" # used when saving the config of a tokenizer

MIDI_FILES_EXTENSIONS = [".mid", ".midi", ".MID", ".MIDI"]

Expand Down
42 changes: 30 additions & 12 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@


def _in_as_seq(complete: bool = True, decode_bpe: bool = True):
"""Decorator creating if necessary and completing a TokSequence object before that the function is called.
This decorator is used by the :py:meth:`miditok.MIDITokenizer.track_to_tokens` method.
r"""Decorator creating if necessary and completing a TokSequence object before that the function is called.
This decorator is made to be used by the :py:meth:`miditok.MIDITokenizer.tokens_to_midi` method.
:param complete: will complete the sequence, i.e. complete its ``ids`` , ``tokens`` and ``events`` .
:param decode_bpe: will decode BPE, if applicable. This step is performed before completing the sequence.
"""

def decorator(function: Callable = None):
def wrapper(*args, **kwargs):
self = args[0]
tokenizer = args[0]
seq = args[1]
if not isinstance(seq, TokSequence) and not all(
isinstance(seq_, TokSequence) for seq_ in seq
Expand All @@ -57,25 +57,38 @@ def wrapper(*args, **kwargs):
else: # list of Event, very unlikely
arg = ("events", seq)

# Deduce nb of subscript, if tokenizer is multi-voc or unique_track
nb_subscripts = nb_real_subscripts = 1
if not tokenizer.unique_track:
nb_subscripts += 1
if tokenizer.is_multi_voc:
nb_subscripts += 1
if isinstance(arg[1][0], list):
nb_real_subscripts += 1
if isinstance(arg[1][0][0], list):
nb_real_subscripts += 1

if not tokenizer.unique_track and nb_subscripts == nb_real_subscripts:
seq = []
for obj in arg[1]:
kwarg = {arg[0]: obj}
seq.append(TokSequence(**kwarg))
seq[-1].ids_bpe_encoded = self._are_ids_bpe_encoded(seq[-1].ids)
else:
if not tokenizer.is_multi_voc:
seq[-1].ids_bpe_encoded = tokenizer._are_ids_bpe_encoded(seq[-1].ids)
else: # 1 subscript, unique_track and no multi-voc
kwarg = {arg[0]: arg[1]}
seq = TokSequence(**kwarg)
seq.ids_bpe_encoded = self._are_ids_bpe_encoded(seq.ids)
if not tokenizer.is_multi_voc:
seq.ids_bpe_encoded = tokenizer._are_ids_bpe_encoded(seq.ids)

if self.has_bpe and decode_bpe:
self.decode_bpe(seq)
if tokenizer.has_bpe and decode_bpe:
tokenizer.decode_bpe(seq)
if complete:
if isinstance(seq, TokSequence):
self.complete_sequence(seq)
tokenizer.complete_sequence(seq)
else:
for seq_ in seq:
self.complete_sequence(seq_)
tokenizer.complete_sequence(seq_)

args = list(args)
args[1] = seq
Expand Down Expand Up @@ -589,6 +602,7 @@ def _bytes_to_tokens(
tokens = [tok for toks in tokens for tok in toks] # flatten
return tokens

@_in_as_seq()
def tokens_to_midi(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
Expand Down Expand Up @@ -639,7 +653,7 @@ def tokens_to_midi(
@abstractmethod
def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down Expand Up @@ -1291,7 +1305,7 @@ def tokenize_midi_dataset(
@_in_as_seq(complete=False, decode_bpe=False)
def tokens_errors(
self, tokens: Union[TokSequence, List[Union[int, List[int]]]]
) -> float:
) -> Union[float, List[float]]:
r"""Checks if a sequence of tokens is made of good token types
successions and returns the error ratio (lower is better).
The common implementation in MIDITokenizer class will check token types,
Expand All @@ -1303,6 +1317,10 @@ def tokens_errors(
:param tokens: sequence of tokens to check.
:return: the error ratio (lower is better).
"""
# If list of TokSequence -> recursive
if isinstance(tokens, list):
return [self.tokens_errors(tok_seq) for tok_seq in tokens]

nb_tok_predicted = len(tokens) # used to norm the score
if self.has_bpe:
self.decode_bpe(tokens)
Expand Down
7 changes: 5 additions & 2 deletions miditok/tokenizations/cp_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __create_cp_token(

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down Expand Up @@ -464,7 +464,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
return dic

@_in_as_seq()
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float:
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]:
r"""Checks if a sequence of tokens is made of good token types
successions and returns the error ratio (lower is better).
The Pitch and Position values are also analyzed:
Expand All @@ -474,6 +474,9 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl
:param tokens: sequence of tokens to check
:return: the error ratio (lower is better)
"""
# If list of TokSequence -> recursive
if isinstance(tokens, list):
return [self.tokens_errors(tok_seq) for tok_seq in tokens]

def cp_token_type(tok: List[int]) -> List[str]:
family = self[0, tok[0]].split("_")[1]
Expand Down
8 changes: 6 additions & 2 deletions miditok/tokenizations/midi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
default_duration: int = None,
Expand Down Expand Up @@ -350,7 +350,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
return dic

@_in_as_seq(complete=False, decode_bpe=False)
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float:
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]:
r"""Checks if a sequence of tokens is made of good token types
successions and returns the error ratio (lower is better).
The Pitch and Position values are also analyzed:
Expand All @@ -360,6 +360,10 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl
:param tokens: sequence of tokens to check
:return: the error ratio (lower is better)
"""
# If list of TokSequence -> recursive
if isinstance(tokens, list):
return [self.tokens_errors(tok_seq) for tok_seq in tokens]

nb_tok_predicted = len(tokens) # used to norm the score
if self.has_bpe:
self.decode_bpe(tokens)
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def track_to_tokens(self, track: Instrument) -> List[Event]:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/mumidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def tokens_to_midi(

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
):
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/octuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def tokens_to_midi(

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down
8 changes: 6 additions & 2 deletions miditok/tokenizations/octuple_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down Expand Up @@ -250,7 +250,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
return {} # not relevant for this encoding

@_in_as_seq()
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> float:
def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> Union[float, List[float]]:
r"""Checks if a sequence of tokens is made of good token values and
returns the error ratio (lower is better).
The token types are always the same in Octuple so this method only checks
Expand All @@ -262,6 +262,10 @@ def tokens_errors(self, tokens: Union[TokSequence, List, np.ndarray, Any]) -> fl
:param tokens: sequence of tokens to check
:return: the error ratio (lower is better)
"""
# If list of TokSequence -> recursive
if isinstance(tokens, list):
return [self.tokens_errors(tok_seq) for tok_seq in tokens]

err = 0
current_bar = current_pos = -1
current_pitches = []
Expand Down
6 changes: 3 additions & 3 deletions miditok/tokenizations/remi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Tuple, Dict, Optional, Union, Any
from typing import List, Tuple, Dict, Optional

import numpy as np
from miditoolkit import Instrument, Note, TempoChange

from ..midi_tokenizer import MIDITokenizer, _in_as_seq, _out_as_complete_seq
from ..midi_tokenizer import MIDITokenizer, _out_as_complete_seq
from ..classes import TokSequence, Event
from ..utils import detect_chords
from ..constants import (
Expand Down Expand Up @@ -204,7 +204,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/remi_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def track_to_tokens(self, track: Instrument) -> TokSequence:

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down
2 changes: 1 addition & 1 deletion miditok/tokenizations/tsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _midi_to_tokens(

def tokens_to_track(
self,
tokens: Union[TokSequence, List, np.ndarray, Any],
tokens: TokSequence,
time_division: Optional[int] = TIME_DIVISION,
program: Optional[Tuple[int, bool]] = (0, False),
) -> Tuple[Instrument, List[TempoChange]]:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author="Nathan Fradet",
url="https://github.com/Natooz/MidiTok",
packages=find_packages(exclude=("tests",)),
version="2.1.1",
version="2.1.2",
license="MIT",
description="A convenient MIDI tokenizer for Deep Learning networks, with multiple encoding strategies",
long_description=long_description,
Expand Down

0 comments on commit 089fa74

Please sign in to comment.