From 722d4beea3dd1f952cabc4762f1ad7a77c1cd848 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 5 Dec 2023 22:41:34 +0000
Subject: [PATCH] Fix bug in build (#75)

* fix

* fix bug in build
---
 aria/data/datasets.py | 6 ++++--
 aria/run.py           | 6 +++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/aria/data/datasets.py b/aria/data/datasets.py
index 67e783b..69ee3d1 100644
--- a/aria/data/datasets.py
+++ b/aria/data/datasets.py
@@ -452,7 +452,8 @@ def _get_seqs(_entry: MidiDict | dict, _tokenizer: Tokenizer):
     try:
         _tokenized_seq = _tokenizer.tokenize(_midi_dict)
     except Exception as e:
-        logger.error(f"Failed to tokenize midi_dict: {e}")
+        logger.info(f"Skipping midi_dict: {e}")
+        return
     else:
         if _tokenizer.unk_tok in _tokenized_seq:
             logger.warning("Unknown token seen while tokenizing midi_dict")
@@ -601,7 +602,8 @@ def _build_epoch(_save_path, _midi_dataset):
         buffer = []
         for entry in get_seqs(tokenizer, _midi_dataset):
-            buffer += entry
+            if entry is not None:
+                buffer += entry
             while len(buffer) >= max_seq_len:
                 writer.write(buffer[:max_seq_len])
                 buffer = buffer[max_seq_len:]
diff --git a/aria/run.py b/aria/run.py
index 214a4ae..61071b4 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -257,7 +257,7 @@ def _parse_pretrain_dataset_args():
     return argp.parse_args(sys.argv[2:])
 
 
-def build_tokenized_dataset(args):
+def build_pretraining_dataset(args):
     from aria.tokenizer import TokenizerLazy
     from aria.data.datasets import PretrainingDataset
 
@@ -318,9 +318,9 @@ def main():
     elif args.command == "midi-dataset":
         build_midi_dataset(args=_parse_midi_dataset_args())
     elif args.command == "pretrain-dataset":
-        build_tokenized_dataset(args=_parse_pretrain_dataset_args())
+        build_pretraining_dataset(args=_parse_pretrain_dataset_args())
     elif args.command == "finetune-dataset":
-        build_tokenized_dataset(args=_parse_finetune_dataset_args())
+        build_finetune_dataset(args=_parse_finetune_dataset_args())
     else:
         print("Unrecognized command")
         parser.print_help()