From 023071572e60c385dbc820974109ca9e946e733f Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 9 Nov 2023 23:44:57 +0000 Subject: [PATCH 01/14] MidiDataset can initialize with an iterator and only expand when necessary. --- aria/data/datasets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aria/data/datasets.py b/aria/data/datasets.py index ca1db35..1ca4a65 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -52,14 +52,14 @@ class MidiDataset: entries (list[MidiDict]): MidiDict objects to be stored. """ - def __init__(self, entries: list[MidiDict]): + def __init__(self, entries: list[MidiDict] | Iterable): self.entries = entries def __len__(self): - return len(self.entries) + return len(list(self.entries)) def __getitem__(self, ind: int): - return self.entries[ind] + return list(self.entries)[ind] def __iter__(self): yield from self.entries @@ -74,12 +74,12 @@ def save(self, save_path: str): @classmethod def load(cls, load_path: str): """Loads dataset from JSON file.""" - midi_dicts = [] - with jsonlines.open(load_path) as reader: - for entry in reader: - midi_dicts.append(MidiDict.from_msg_dict(entry)) + def _load(): + with jsonlines.open(load_path) as reader: + for entry in reader: + yield MidiDict.from_msg_dict(entry) - return cls(midi_dicts) + return cls(_load()) @classmethod def split_from_file( From c1d170daf4a1b1dbb5fee80e60ff718814fb0dd2 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 9 Nov 2023 23:52:10 +0000 Subject: [PATCH 02/14] reduce some memory overhead (we are starting to have >100k MidiDict and may get more in the future) --- aria/data/midi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aria/data/midi.py b/aria/data/midi.py index 979619f..8db1fb1 100644 --- a/aria/data/midi.py +++ b/aria/data/midi.py @@ -1,5 +1,5 @@ """Utils for data/MIDI processing.""" - +import functools import hashlib import json import re @@ -114,8 +114,10 @@ def __init__( } ] + @functools.cached_property + def program_to_instrument(self): # This combines the individual dictionaries into one - self.program_to_instrument = ( + return ( {i: "piano" for i in range(0, 7 + 1)} | {i: "chromatic" for i in range(8, 15 + 1)} | {i: "organ" for i in range(16, 23 + 1)} From deec979fa8a13c66c9f314b286bbe9444597b1db Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Fri, 10 Nov 2023 10:46:51 +0000 Subject: [PATCH 03/14] classmethod+property is better... 
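The memory argument, spelled out: every MidiDict previously built and stored its own copy of the roughly 128-entry program-to-instrument table, so with >100k MidiDicts the same dict was duplicated >100k times. Exposing it as a class-level property keeps one shared definition and stores nothing on the instance. A minimal sketch of the pattern, using a hypothetical Toy class rather than the real MidiDict:

    class Toy:
        @classmethod
        @property
        def program_to_instrument(cls):
            # shared lookup, rebuilt on access, never stored per instance
            return {i: "piano" for i in range(0, 7 + 1)}

    assert Toy.program_to_instrument[0] == "piano"

Worth noting: chaining classmethod and property this way works on Python 3.9 through 3.12, but the combination was deprecated in 3.11 and scheduled for removal in 3.13, so a plain class attribute or module-level constant is the more future-proof spelling.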
--- aria/data/midi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aria/data/midi.py b/aria/data/midi.py index 8db1fb1..e1c1ca9 100644 --- a/aria/data/midi.py +++ b/aria/data/midi.py @@ -114,8 +114,9 @@ def __init__( } ] - @functools.cached_property - def program_to_instrument(self): + @classmethod + @property + def program_to_instrument(cls): # This combines the individual dictionaries into one return ( {i: "piano" for i in range(0, 7 + 1)} From 6e58f0985b8f445adb7e2ee8de10c325a1ecb930 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Fri, 10 Nov 2023 10:47:19 +0000 Subject: [PATCH 04/14] remove functools import --- aria/data/midi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aria/data/midi.py b/aria/data/midi.py index e1c1ca9..8ff5f8e 100644 --- a/aria/data/midi.py +++ b/aria/data/midi.py @@ -1,5 +1,4 @@ """Utils for data/MIDI processing.""" -import functools import hashlib import json import re From 7073a1aceeebb73a7fc96163052b10e5c6bb2081 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Fri, 10 Nov 2023 13:35:47 +0000 Subject: [PATCH 05/14] use separate workers to build dataset instead of process pool --- aria/data/datasets.py | 48 ++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/aria/data/datasets.py b/aria/data/datasets.py index 1ca4a65..70d0888 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -13,11 +13,12 @@ from typing import Callable, Iterable from collections import defaultdict from copy import deepcopy -from multiprocessing import Pool +from multiprocessing import Pool, Process, Queue from aria.config import load_config from aria.tokenizer import Tokenizer, TokenizerLazy from aria.data.midi import MidiDict, get_test_fn +import tqdm def setup_logger(): @@ -529,6 +530,13 @@ def build( TokenizedDataset: Dataset saved midi_dataset and saved at save_path. 
""" + def _worker(input_queue, output_queue, tokenizer): + while True: + item = input_queue.get() + if item is None: + break + output_queue.put(_get_tokenized_seqs(item, tokenizer)) + def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable): # Gets tokenized sequences using multiprocessing @@ -536,17 +544,33 @@ def _get_tokenized_seqs_mp(_midi_dict_iter: Iterable): # and stride logic in _get_tokenized_seqs assert isinstance(tokenizer, TokenizerLazy), "Unsupported tokenizer" - with Pool() as pool: - results = pool.imap( - functools.partial(_get_tokenized_seqs, tokenizer=tokenizer), - _midi_dict_iter, - ) - - for idx, tokenized_seq in enumerate(results): - yield tokenized_seq - - if idx % 50 == 0 and idx != 0: - logger.info(f"Processed MidiDicts: {idx}") + iq = Queue() + oq = Queue() + + _num_proc = os.cpu_count() + workers = [Process(target=functools.partial(_worker, tokenizer=tokenizer), args=(iq, oq)) for _ in + range(_num_proc)] + for w in workers: + w.start() + + def _enqueue(iq): + for midi_dict in _midi_dict_iter: + iq.put(midi_dict) + for i in range(_num_proc): + iq.put(None) + + enqueue = Process(target=_enqueue, args=(iq,)) + enqueue.start() + + with tqdm.tqdm() as t: + while True: + try: + result = oq.get(timeout=1000) + t.update(1) + yield result + except oq.Empty: + if not any(proc.is_alive() for proc in workers): + break logger = setup_logger() From 8bbc14b2f7498db3fd8f079a7531a80036b3c33b Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Fri, 10 Nov 2023 15:05:06 +0000 Subject: [PATCH 06/14] add jsonl.zst support; unit test; fix bug --- aria/data/datasets.py | 14 ++++++--- aria/data/jsonl_zst.py | 71 ++++++++++++++++++++++++++++++++++++++++++ tests/test_data.py | 18 +++++++++++ 3 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 aria/data/jsonl_zst.py diff --git a/aria/data/datasets.py b/aria/data/datasets.py index fb69159..a5fab21 100644 --- a/aria/data/datasets.py +++ b/aria/data/datasets.py @@ -57,10 +57,14 @@ def __init__(self, entries: list[MidiDict] | Iterable): self.entries = entries def __len__(self): - return len(list(self.entries)) + if not isinstance(self.entries, list): + self.entries = list(self.entries) + return len(self.entries) def __getitem__(self, ind: int): - return list(self.entries)[ind] + if not isinstance(self.entries, list): + self.entries = list(self.entries) + return self.entries[ind] def __iter__(self): yield from self.entries @@ -564,11 +568,11 @@ def _enqueue(iq): with tqdm.tqdm() as t: while True: - try: - result = oq.get(timeout=1000) + if not oq.empty(): + result = oq.get() t.update(1) yield result - except oq.Empty: + else: if not any(proc.is_alive() for proc in workers): break diff --git a/aria/data/jsonl_zst.py b/aria/data/jsonl_zst.py new file mode 100644 index 0000000..273776f --- /dev/null +++ b/aria/data/jsonl_zst.py @@ -0,0 +1,71 @@ +import builtins +import contextlib +import io +import zstandard +import jsonlines +import json + + +class Reader: + """Reader for the jsonl.zst format.""" + + def __init__(self, path: str): + """Initializes the reader. + + Args: + path (str): Path to the file. + """ + self.path = path + + def __iter__(self): + with builtins.open(self.path, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + yield from jsonlines.Reader(reader) + + +class Writer: + """Writer for the jsonl.zst format.""" + + def __init__(self, path: str): + """Initializes the writer. + + Args: + path (str): Path to the file. 
+ """ + self.path = path + + def __enter__(self): + self.fh = builtins.open(self.path, 'wb') + self.cctx = zstandard.ZstdCompressor() + self.compressor = self.cctx.stream_writer(self.fh) + return self + + def write(self, obj): + self.compressor.write(json.dumps(obj).encode('UTF-8') + b'\n') + + def __exit__(self, exc_type, exc_value, traceback): + self.compressor.flush(zstandard.FLUSH_FRAME) + self.fh.flush() + self.compressor.close() + self.fh.close() + + +@contextlib.contextmanager +def open(path: str, mode: str = "r"): + """Read/Write a jsonl.zst file. + + Args: + path (str): Path to the file. + mode (str): Mode to open the file in. Only 'r' and 'w' are supported. + + Returns: + Reader or Writer: Reader if mode is 'r', Writer if mode is 'w'. + """ + if mode == 'r': + yield Reader(path) + elif mode == 'w': + with Writer(path) as writer: + yield writer + else: + raise ValueError(f"Unsupported mode '{mode}'") diff --git a/tests/test_data.py b/tests/test_data.py index a61c7e7..0efebec 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -5,6 +5,7 @@ from aria import tokenizer from aria.data import datasets from aria.data.midi import MidiDict +from aria.data import jsonl_zst if not os.path.isdir("tests/test_results"): os.makedirs("tests/test_results") @@ -245,6 +246,23 @@ def test_augmentation(self): tokenized_dataset.close() +class TestReaderWriter(unittest.TestCase): + def test_jsonl_zst(self): + data = [{"a": i, "b": i+1} for i in range(0, 100, 4)] + filename = "tests/test_results/test.jsonl.zst" + # if test.jsonl.zst exists, delete it + if os.path.isfile(filename): + os.remove(filename) + with jsonl_zst.open(filename, "w") as f: + for d in data: + f.write(d) + with jsonl_zst.open(filename, "r") as f: + for d, d2 in zip(data, f): + self.assertEqual(d, d2) + # Remove the file + os.remove(filename) + + if __name__ == "__main__": if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") From 95f492dd8bd1f076664e3a058f0bb6092afb8a78 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Fri, 10 Nov 2023 16:26:36 +0000 Subject: [PATCH 07/14] receive context length via commandline. It's more convenient than digging into the config file every time. 
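For reference, the equivalent direct call after this change, with placeholder file names (the real paths come from the positional load_path/save_path arguments): max_seq_len is now taken from the -l flag (default 2048) rather than from load_config()["data"]["dataset_gen_args"]["max_seq_len"]:

    from aria.tokenizer import TokenizerLazy
    from aria.data.datasets import TokenizedDataset

    dataset = TokenizedDataset.build(
        tokenizer=TokenizerLazy(),
        save_path="tokenized.jsonl",           # placeholder path
        midi_dataset_path="midi_dicts.jsonl",  # placeholder path
        max_seq_len=4096,                      # was read from the config file
        overwrite=True,
    )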
--- aria/run.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aria/run.py b/aria/run.py index 7a9fcfd..53ede92 100644 --- a/aria/run.py +++ b/aria/run.py @@ -124,6 +124,7 @@ def _parse_tokenized_dataset_args(): argp.add_argument("load_path", help="path midi_dict dataset") argp.add_argument("save_path", help="path to save dataset") argp.add_argument("-s", help="also produce shuffled", action="store_true") + argp.add_argument("-l", help="max sequence length", type=int, default=2048) return argp.parse_args(sys.argv[2:]) @@ -131,15 +132,13 @@ def _parse_tokenized_dataset_args(): def build_tokenized_dataset(args): from aria.tokenizer import TokenizerLazy from aria.data.datasets import TokenizedDataset - from aria.config import load_config - config = load_config()["data"]["dataset_gen_args"] tokenizer = TokenizerLazy() dataset = TokenizedDataset.build( tokenizer=tokenizer, save_path=args.save_path, midi_dataset_path=args.load_path, - max_seq_len=config["max_seq_len"], + max_seq_len=args.l, overwrite=True, ) if args.s: From 93301fafde4970832c5d6fff0af88181d1a3ade7 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Sun, 12 Nov 2023 21:08:06 +0000 Subject: [PATCH 08/14] fix a minor output format mismatch when grad_checkpoint is true --- aria/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aria/model/model.py b/aria/model/model.py index b32f5f8..bcdbb9a 100644 --- a/aria/model/model.py +++ b/aria/model/model.py @@ -320,7 +320,7 @@ def custom_forward(*args): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states, _ = torch.utils.checkpoint.checkpoint( create_custom_forward(layer), hidden_states, preserve_rng_state=True, From b42543d83c3ef13c20f56cdf63438e7ae01d1d83 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Wed, 22 Nov 2023 19:28:04 +0000 Subject: [PATCH 09/14] remove hardcoded cuda() as well as autocast for cpu inferencing; int8 quantization works; fix a bug on gradient_checkpointing along with use_cache --- aria/model/model.py | 3 +- aria/run.py | 105 +++++++++++++++++++++++++++++++++++++++----- aria/sample.py | 5 +-- requirements.txt | 4 +- 4 files changed, 99 insertions(+), 18 deletions(-) diff --git a/aria/model/model.py b/aria/model/model.py index bcdbb9a..995a2dd 100644 --- a/aria/model/model.py +++ b/aria/model/model.py @@ -311,7 +311,7 @@ def forward(self, src: torch.Tensor, use_cache=False, past_kv=None): # remove torch.compile from the train script as this is not currently # supported. # Implements gradient checkpoints on Encoder Layers. 
- if self.model_config.grad_checkpoint is True: + if self.model_config.grad_checkpoint is True and not use_cache: for layer in self.encode_layers: def create_custom_forward(module): @@ -326,7 +326,6 @@ def custom_forward(*args): preserve_rng_state=True, use_reentrant=True, ) - else: new_past_kv = [] past_kv = ( diff --git a/aria/run.py b/aria/run.py index 53ede92..53f91b7 100644 --- a/aria/run.py +++ b/aria/run.py @@ -2,26 +2,84 @@ import argparse import os +import re import sys +import pathlib def _parse_sample_args(): argp = argparse.ArgumentParser(prog="aria sample") - argp.add_argument("model", help="name of model config file") - argp.add_argument("ckpt_path", help="path to model checkpoint") - argp.add_argument("midi_path", help="path to midi file") + argp.add_argument("-m", help="name of model config file") + argp.add_argument("-c", help="path to model checkpoint") + argp.add_argument("-p", help="path to midi file") argp.add_argument( - "-var", help="number of variations", type=int, required=True + "-var", help="number of variations", type=int, default=1, ) argp.add_argument( - "-trunc", help="length to truncated prompt", type=int, required=True + "-trunc", help="length to truncated prompt", type=int, default=200, ) argp.add_argument("-e", action="store_true", help="enable force end") - argp.add_argument("-l", type=int, help="generation length") + argp.add_argument("-l", type=int, help="generation length", default=1024) + argp.add_argument("-q", action="store_true", help="quantize the model") return argp.parse_args(sys.argv[2:]) +def _get_model_name(name: str | None, state: dict): + if name is not None: + return name + + print("Model name is not provided. Trying to infer from checkpoint...") + _defaults = { + 16: "small", + 32: "medium", + 64: "large", + 96: "xlarge", + } + try: + pattern = re.compile(r"encode_layers\.(\d+)\.") + layer_keys = [pattern.search(k) for k in state.keys()] + layer_keys = set(p.group(1) for p in layer_keys if p is not None) + for i in range(len(layer_keys)): + assert str(i) in layer_keys + + if len(layer_keys) in _defaults: + print(f"Selecting model name: {_defaults[len(layer_keys)]}") + return _defaults[len(layer_keys)] + assert False + except: + raise ValueError("Model name is not provided and cannot be inferred.") + + +def _show_popup(prompt: str, files: list) -> str: + for i in range(len(files)): + print(f" [{i}] {files[i]}") + + for tries in range(3): # 3 tries in case of fat fingers + try: + res = int(input(prompt + f" [0-{len(files) - 1}]: ")) + assert 0 <= res < len(files) + return files[res] + except: + print("Invalid input. Try again...") + + raise ValueError("Invalid input.") + + +def _get_ckpt_path(ckpt_path: str | None) -> str: + if ckpt_path is None: + ckpts = list(pathlib.Path(".").glob("*.bin")) + ckpt_path = _show_popup("Choose a checkpoint", ckpts) + return ckpt_path + + +def _get_midi_path(midi_path: str | None) -> str: + if midi_path is None: + midis = list(pathlib.Path(".").glob("*.mid")) + list(pathlib.Path(".").glob("*.midi")) + midi_path = _show_popup("Choose a midi-file", midis) + return midi_path + + def sample(args): """Entrypoint for sampling""" @@ -34,11 +92,16 @@ def sample(args): from aria.data.midi import MidiDict from aria.utils import midi_to_audio - assert cuda_is_available() is True, "CUDA device not available" + if not cuda_is_available(): + print("CUDA device is not available. 
Using CPU instead.") + else: + greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)(greedy_sample) + device = torch.device("cuda") if cuda_is_available() else torch.device("cpu") + + ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided + model_state = torch.load(ckpt_path, map_location=device) + model_name = _get_model_name(args.m, model_state) # infer model name if not provided - model_name = args.model - ckpt_path = args.ckpt_path - midi_path = args.midi_path num_variations = args.var truncate_len = args.trunc force_end = args.e @@ -46,8 +109,26 @@ def sample(args): tokenizer = TokenizerLazy(return_tensors=True) model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config).cuda() - model.load_state_dict(torch.load(ckpt_path)) + model = TransformerLM(model_config).to(device) + model.load_state_dict(model_state) + if args.q: + from torch.ao.quantization import get_default_qconfig_mapping + from torch.quantization.quantize_fx import prepare_fx, convert_fx + qconfig_mapping = get_default_qconfig_mapping() + + def _quantize(module, key, input_shape): + inp = torch.randn(input_shape, dtype=torch.float, device=device) + m = prepare_fx(getattr(module, key), qconfig_mapping, example_inputs=inp) + m = convert_fx(m) + setattr(module, key, m) + + for i in range(len(model.model.encode_layers)): + _quantize(model.model.encode_layers[i], "mixed_qkv", input_shape=(1, 2048, model_config.n_heads)) + _quantize(model.model.encode_layers[i], "att_proj_linear", input_shape=(1, 2048, model_config.n_heads)) + _quantize(model.model.encode_layers[i], "ff_linear_1", input_shape=(1, 2048, model_config.n_heads)) + _quantize(model.model.encode_layers[i], "ff_linear_2", input_shape=(1, 2048, model_config.n_heads * model_config.ff_mult)) + + midi_path = _get_midi_path(args.p) # let user input midi path if not provided if args.l and 0 < args.l < model.max_seq_len: max_gen_len = args.l diff --git a/aria/sample.py b/aria/sample.py index 36026b9..275ba57 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -43,7 +43,6 @@ def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len): # temp=0.85, top_p=0.9, cfg_gamma=1.4 -@torch.autocast(device_type="cuda", dtype=torch.float16) def greedy_sample( model: TransformerLM, tokenizer: Tokenizer, @@ -128,7 +127,7 @@ def greedy_sample( for neg_seq in neg_prompts ], axis=0, - ).cuda() + ) neg_len = ( neg_min_len if neg_prompt_len is None @@ -136,7 +135,7 @@ def greedy_sample( ) neg_tokens = neg_prompt_tensors[:, :neg_len] - tokens = torch.full((bsz, total_len), pad_id).cuda() + tokens = torch.full((bsz, total_len), pad_id) for idx, unencoded_seq in enumerate(prompts): tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq) diff --git a/requirements.txt b/requirements.txt index e211c7c..f7d9ea4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,6 @@ torch >= 2.1 accelerate mido jsonlines -pydub \ No newline at end of file +pydub +bitsandbytes +scipy \ No newline at end of file From 75c7a65eba58fe498f93bbf7c3ecd7619d28f748 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Wed, 22 Nov 2023 19:45:40 +0000 Subject: [PATCH 10/14] fixing device --- aria/run.py | 1 + aria/sample.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/aria/run.py b/aria/run.py index 53f91b7..27e24e1 100644 --- a/aria/run.py +++ b/aria/run.py @@ -151,6 +151,7 @@ def _quantize(module, key, input_shape): model, tokenizer, prompts, + 
device=device, force_end=force_end, max_seq_len=model_config.max_seq_len, max_gen_len=max_gen_len, diff --git a/aria/sample.py b/aria/sample.py index 275ba57..4cf76e2 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -49,6 +49,7 @@ def greedy_sample( prompts: List[list], max_seq_len: int, max_gen_len: int, + device: torch.device | None = None, cfg_gamma: float | None = 1.4, cfg_mode: str | None = None, neg_prompts: List[list] | None = None, @@ -66,6 +67,7 @@ def greedy_sample( prompts (List[list]): A list of prompts to sample as a batch. max_seq_len (int): Maximum sequence length supported by the model. max_gen_len (int): Maximum desired sequence length of the samples. + device (torch.device, optional): Device to use. Defaults to None. cfg_gamma (float, optional): CFG gamma parameter. Defaults to 1.2. This parameter *determines* whether parameters related to CFG are used. None: No CFG or interpolation. `cfg_mode, neg_prompts, neg_prompt_len, alpha` are ignored. @@ -87,6 +89,7 @@ def greedy_sample( List[list]: The list of samples, decoded by the tokenizer. """ assert tokenizer.return_tensors is True, "tokenizer must return tensors." + device = device or torch.device("cuda") model.eval() pad_id = tokenizer.pad_id @@ -120,7 +123,7 @@ def greedy_sample( [ torch.concat( [ - torch.full((neg_max_len - len(neg_seq),), pad_id), + torch.full((neg_max_len - len(neg_seq),), pad_id, device=device), tokenizer.encode(neg_seq), ] ) @@ -135,7 +138,7 @@ def greedy_sample( ) neg_tokens = neg_prompt_tensors[:, :neg_len] - tokens = torch.full((bsz, total_len), pad_id) + tokens = torch.full((bsz, total_len), pad_id, device=device) for idx, unencoded_seq in enumerate(prompts): tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq) From caace6c6e4c55a3e536137a1332c77d7f58659fd Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Wed, 22 Nov 2023 19:47:32 +0000 Subject: [PATCH 11/14] fixing device --- aria/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aria/sample.py b/aria/sample.py index 4cf76e2..1058cea 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -124,7 +124,7 @@ def greedy_sample( torch.concat( [ torch.full((neg_max_len - len(neg_seq),), pad_id, device=device), - tokenizer.encode(neg_seq), + tokenizer.encode(neg_seq).to(device), ] ) for neg_seq in neg_prompts @@ -140,7 +140,7 @@ def greedy_sample( tokens = torch.full((bsz, total_len), pad_id, device=device) for idx, unencoded_seq in enumerate(prompts): - tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq) + tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to(device) dim_tok_inserted = [False for _ in range(bsz)] input_text_mask = tokens != pad_id From e33ef69ccc40055ee411eef845ab960169767f02 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Wed, 22 Nov 2023 20:07:22 +0000 Subject: [PATCH 12/14] bitsandbytes unnecessary now. 
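Context: sampling-time quantization (patches 09 and 13) goes through torch's built-in FX graph-mode flow and runs on CPU, so the bitsandbytes and scipy pins added earlier are no longer needed. A minimal sketch of that flow on a standalone layer, with a hypothetical module and shapes loosely mirroring _quantize() in aria/run.py:

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    lin = torch.nn.Linear(512, 512)   # stand-in for e.g. ff_linear_1
    example = torch.randn(1, 2048, 512)
    prepared = prepare_fx(
        lin, get_default_qconfig_mapping(), example_inputs=(example,)
    )
    prepared(example)        # calibration pass so the observers see real ranges
    quantized = convert_fx(prepared)
    out = quantized(example)  # executes as an int8 linear on CPU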
--- requirements.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index f7d9ea4..e211c7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,4 @@ torch >= 2.1 accelerate mido jsonlines -pydub -bitsandbytes -scipy \ No newline at end of file +pydub \ No newline at end of file From 62d73098ee45db706c699a5023f906a40d829615 Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 30 Nov 2023 17:45:56 +0000 Subject: [PATCH 13/14] Add a warning to force CPU when quantization is used in aria.run sample --- aria/run.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aria/run.py b/aria/run.py index 27e24e1..224a057 100644 --- a/aria/run.py +++ b/aria/run.py @@ -5,6 +5,7 @@ import re import sys import pathlib +import warnings def _parse_sample_args(): @@ -112,6 +113,10 @@ def sample(args): model = TransformerLM(model_config).to(device) model.load_state_dict(model_state) if args.q: + if device.type != 'cpu': + warnings.warn("Quantization is not supported on CUDA devices. Using CPU instead.") + device = torch.device("cpu") + from torch.ao.quantization import get_default_qconfig_mapping from torch.quantization.quantize_fx import prepare_fx, convert_fx qconfig_mapping = get_default_qconfig_mapping() From da5eee8899b78dbf3fcb1f04fdaf6f9d67842adb Mon Sep 17 00:00:00 2001 From: Honglu Fan Date: Thu, 30 Nov 2023 18:12:01 +0000 Subject: [PATCH 14/14] formatting; Also add black formatter to Makefile --- Makefile | 5 +++- aria/run.py | 71 +++++++++++++++++++++++++++++++++++++++----------- aria/sample.py | 8 ++++-- 3 files changed, 66 insertions(+), 18 deletions(-) diff --git a/Makefile b/Makefile index d4032d8..6adfd91 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,7 @@ test: python -m unittest tests/test_*.py -PHONY: test \ No newline at end of file +style: + black --line-length 80 aria + +PHONY: test style \ No newline at end of file diff --git a/aria/run.py b/aria/run.py index 224a057..cc8404f 100644 --- a/aria/run.py +++ b/aria/run.py @@ -14,10 +14,16 @@ def _parse_sample_args(): argp.add_argument("-c", help="path to model checkpoint") argp.add_argument("-p", help="path to midi file") argp.add_argument( - "-var", help="number of variations", type=int, default=1, + "-var", + help="number of variations", + type=int, + default=1, ) argp.add_argument( - "-trunc", help="length to truncated prompt", type=int, default=200, + "-trunc", + help="length to truncated prompt", + type=int, + default=200, ) argp.add_argument("-e", action="store_true", help="enable force end") argp.add_argument("-l", type=int, help="generation length", default=1024) @@ -76,7 +82,9 @@ def _get_ckpt_path(ckpt_path: str | None) -> str: def _get_midi_path(midi_path: str | None) -> str: if midi_path is None: - midis = list(pathlib.Path(".").glob("*.mid")) + list(pathlib.Path(".").glob("*.midi")) + midis = list(pathlib.Path(".").glob("*.mid")) + list( + pathlib.Path(".").glob("*.midi") + ) midi_path = _show_popup("Choose a midi-file", midis) return midi_path @@ -96,12 +104,18 @@ def sample(args): if not cuda_is_available(): print("CUDA device is not available. 
Using CPU instead.") else: - greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)(greedy_sample) - device = torch.device("cuda") if cuda_is_available() else torch.device("cpu") + greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)( + greedy_sample + ) + device = ( + torch.device("cuda") if cuda_is_available() else torch.device("cpu") + ) ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided model_state = torch.load(ckpt_path, map_location=device) - model_name = _get_model_name(args.m, model_state) # infer model name if not provided + model_name = _get_model_name( + args.m, model_state + ) # infer model name if not provided num_variations = args.var truncate_len = args.trunc @@ -113,27 +127,54 @@ def sample(args): model = TransformerLM(model_config).to(device) model.load_state_dict(model_state) if args.q: - if device.type != 'cpu': - warnings.warn("Quantization is not supported on CUDA devices. Using CPU instead.") + if device.type != "cpu": + warnings.warn( + "Quantization is not supported on CUDA devices. Using CPU instead." + ) device = torch.device("cpu") from torch.ao.quantization import get_default_qconfig_mapping from torch.quantization.quantize_fx import prepare_fx, convert_fx + qconfig_mapping = get_default_qconfig_mapping() def _quantize(module, key, input_shape): inp = torch.randn(input_shape, dtype=torch.float, device=device) - m = prepare_fx(getattr(module, key), qconfig_mapping, example_inputs=inp) + m = prepare_fx( + getattr(module, key), qconfig_mapping, example_inputs=inp + ) m = convert_fx(m) setattr(module, key, m) for i in range(len(model.model.encode_layers)): - _quantize(model.model.encode_layers[i], "mixed_qkv", input_shape=(1, 2048, model_config.n_heads)) - _quantize(model.model.encode_layers[i], "att_proj_linear", input_shape=(1, 2048, model_config.n_heads)) - _quantize(model.model.encode_layers[i], "ff_linear_1", input_shape=(1, 2048, model_config.n_heads)) - _quantize(model.model.encode_layers[i], "ff_linear_2", input_shape=(1, 2048, model_config.n_heads * model_config.ff_mult)) - - midi_path = _get_midi_path(args.p) # let user input midi path if not provided + _quantize( + model.model.encode_layers[i], + "mixed_qkv", + input_shape=(1, 2048, model_config.n_heads), + ) + _quantize( + model.model.encode_layers[i], + "att_proj_linear", + input_shape=(1, 2048, model_config.n_heads), + ) + _quantize( + model.model.encode_layers[i], + "ff_linear_1", + input_shape=(1, 2048, model_config.n_heads), + ) + _quantize( + model.model.encode_layers[i], + "ff_linear_2", + input_shape=( + 1, + 2048, + model_config.n_heads * model_config.ff_mult, + ), + ) + + midi_path = _get_midi_path( + args.p + ) # let user input midi path if not provided if args.l and 0 < args.l < model.max_seq_len: max_gen_len = args.l diff --git a/aria/sample.py b/aria/sample.py index 1058cea..886af41 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -123,7 +123,9 @@ def greedy_sample( [ torch.concat( [ - torch.full((neg_max_len - len(neg_seq),), pad_id, device=device), + torch.full( + (neg_max_len - len(neg_seq),), pad_id, device=device + ), tokenizer.encode(neg_seq).to(device), ] ) @@ -140,7 +142,9 @@ def greedy_sample( tokens = torch.full((bsz, total_len), pad_id, device=device) for idx, unencoded_seq in enumerate(prompts): - tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to(device) + tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to( + device + ) dim_tok_inserted = [False for _ in range(bsz)] 
input_text_mask = tokens != pad_id
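Taken together, sampling is now device-aware end to end: greedy_sample accepts an optional device argument (defaulting to CUDA) and allocates its prompt and negative-prompt tensors there, so CPU-only inference no longer trips over hardcoded .cuda() calls. A rough caller-side sketch, assuming model, tokenizer, prompts, and model_config are set up as in aria/run.py and with illustrative argument values:

    import torch
    from aria.sample import greedy_sample

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = greedy_sample(
        model.to(device),
        tokenizer,
        prompts,
        max_seq_len=model_config.max_seq_len,
        max_gen_len=512,
        device=device,
        force_end=False,
    )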