diff --git a/aria/model/model.py b/aria/model/model.py
index 40da175..dae9686 100644
--- a/aria/model/model.py
+++ b/aria/model/model.py
@@ -311,7 +311,7 @@ def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
         # remove torch.compile from the train script as this is not currently
         # supported.
         # Implements gradient checkpoints on Encoder Layers.
-        if self.model_config.grad_checkpoint is True:
+        if self.model_config.grad_checkpoint is True and not use_cache:
             for layer in self.encode_layers:

                 def create_custom_forward(module):
@@ -326,7 +326,6 @@ def custom_forward(*args):
                         preserve_rng_state=True,
                         use_reentrant=True,
                     )
-
         else:
             new_past_kv = []
             past_kv = (
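The model.py change above narrows gradient checkpointing to the training path: checkpointing only saves memory when a backward pass will recompute activations, and `use_cache=True` means cached inference, where it is pure overhead. A minimal standalone sketch of the same guard, using plain PyTorch rather than aria's classes (the block, shapes, and flag names are illustrative only):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def run_block(block: nn.Module, x: torch.Tensor, grad_checkpoint: bool, use_cache: bool):
    # Checkpointing trades compute for memory by re-running the block during
    # backward; with a KV cache (inference) there is no backward pass, so the
    # plain call is used instead.
    if grad_checkpoint and not use_cache:
        return checkpoint(block, x, use_reentrant=True)
    return block(x)

block = nn.Sequential(nn.Linear(8, 8), nn.GELU())
x = torch.randn(2, 8, requires_grad=True)
run_block(block, x, grad_checkpoint=True, use_cache=False).sum().backward()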
diff --git a/aria/run.py b/aria/run.py
index de04216..cc8404f 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -2,26 +2,93 @@
 import argparse
 import os
+import re
 import sys
+import pathlib
+import warnings


 def _parse_sample_args():
     argp = argparse.ArgumentParser(prog="aria sample")
-    argp.add_argument("model", help="name of model config file")
-    argp.add_argument("ckpt_path", help="path to model checkpoint")
-    argp.add_argument("midi_path", help="path to midi file")
+    argp.add_argument("-m", help="name of model config file")
+    argp.add_argument("-c", help="path to model checkpoint")
+    argp.add_argument("-p", help="path to midi file")
     argp.add_argument(
-        "-var", help="number of variations", type=int, required=True
+        "-var",
+        help="number of variations",
+        type=int,
+        default=1,
     )
     argp.add_argument(
-        "-trunc", help="length to truncated prompt", type=int, required=True
+        "-trunc",
+        help="length to truncated prompt",
+        type=int,
+        default=200,
     )
     argp.add_argument("-e", action="store_true", help="enable force end")
-    argp.add_argument("-l", type=int, help="generation length")
+    argp.add_argument("-l", type=int, help="generation length", default=1024)
+    argp.add_argument("-q", action="store_true", help="quantize the model")

     return argp.parse_args(sys.argv[2:])


+def _get_model_name(name: str | None, state: dict):
+    if name is not None:
+        return name
+
+    print("Model name is not provided. Trying to infer from checkpoint...")
+    _defaults = {
+        16: "small",
+        32: "medium",
+        64: "large",
+        96: "xlarge",
+    }
+    try:
+        pattern = re.compile(r"encode_layers\.(\d+)\.")
+        layer_keys = [pattern.search(k) for k in state.keys()]
+        layer_keys = set(p.group(1) for p in layer_keys if p is not None)
+        for i in range(len(layer_keys)):
+            assert str(i) in layer_keys
+
+        if len(layer_keys) in _defaults:
+            print(f"Selecting model name: {_defaults[len(layer_keys)]}")
+            return _defaults[len(layer_keys)]
+        assert False
+    except:
+        raise ValueError("Model name is not provided and cannot be inferred.")
+
+
+def _show_popup(prompt: str, files: list) -> str:
+    for i in range(len(files)):
+        print(f" [{i}] {files[i]}")
+
+    for tries in range(3):  # 3 tries in case of fat fingers
+        try:
+            res = int(input(prompt + f" [0-{len(files) - 1}]: "))
+            assert 0 <= res < len(files)
+            return files[res]
+        except:
+            print("Invalid input. Try again...")
+
+    raise ValueError("Invalid input.")
+
+
+def _get_ckpt_path(ckpt_path: str | None) -> str:
+    if ckpt_path is None:
+        ckpts = list(pathlib.Path(".").glob("*.bin"))
+        ckpt_path = _show_popup("Choose a checkpoint", ckpts)
+    return ckpt_path
+
+
+def _get_midi_path(midi_path: str | None) -> str:
+    if midi_path is None:
+        midis = list(pathlib.Path(".").glob("*.mid")) + list(
+            pathlib.Path(".").glob("*.midi")
+        )
+        midi_path = _show_popup("Choose a midi-file", midis)
+    return midi_path
+
+
 def sample(args):
     """Entrypoint for sampling"""

@@ -34,11 +101,22 @@ def sample(args):
     from aria.data.midi import MidiDict
     from aria.utils import midi_to_audio

-    assert cuda_is_available() is True, "CUDA device not available"
+    if not cuda_is_available():
+        print("CUDA device is not available. Using CPU instead.")
+    else:
+        greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)(
+            greedy_sample
+        )
+    device = (
+        torch.device("cuda") if cuda_is_available() else torch.device("cpu")
+    )
+
+    ckpt_path = _get_ckpt_path(args.c)  # let user input path if not provided
+    model_state = torch.load(ckpt_path, map_location=device)
+    model_name = _get_model_name(
+        args.m, model_state
+    )  # infer model name if not provided

-    model_name = args.model
-    ckpt_path = args.ckpt_path
-    midi_path = args.midi_path
     num_variations = args.var
     truncate_len = args.trunc
     force_end = args.e
@@ -46,8 +124,57 @@ def sample(args):
     tokenizer = TokenizerLazy(return_tensors=True)
     model_config = ModelConfig(**load_model_config(model_name))
     model_config.set_vocab_size(tokenizer.vocab_size)
-    model = TransformerLM(model_config).cuda()
-    model.load_state_dict(torch.load(ckpt_path))
+    model = TransformerLM(model_config).to(device)
+    model.load_state_dict(model_state)
+    if args.q:
+        if device.type != "cpu":
+            warnings.warn(
+                "Quantization is not supported on CUDA devices. Using CPU instead."
+            )
+            device = torch.device("cpu")
+
+        from torch.ao.quantization import get_default_qconfig_mapping
+        from torch.quantization.quantize_fx import prepare_fx, convert_fx
+
+        qconfig_mapping = get_default_qconfig_mapping()
+
+        def _quantize(module, key, input_shape):
+            inp = torch.randn(input_shape, dtype=torch.float, device=device)
+            m = prepare_fx(
+                getattr(module, key), qconfig_mapping, example_inputs=inp
+            )
+            m = convert_fx(m)
+            setattr(module, key, m)
+
+        for i in range(len(model.model.encode_layers)):
+            _quantize(
+                model.model.encode_layers[i],
+                "mixed_qkv",
+                input_shape=(1, 2048, model_config.n_heads),
+            )
+            _quantize(
+                model.model.encode_layers[i],
+                "att_proj_linear",
+                input_shape=(1, 2048, model_config.n_heads),
+            )
+            _quantize(
+                model.model.encode_layers[i],
+                "ff_linear_1",
+                input_shape=(1, 2048, model_config.n_heads),
+            )
+            _quantize(
+                model.model.encode_layers[i],
+                "ff_linear_2",
+                input_shape=(
+                    1,
+                    2048,
+                    model_config.n_heads * model_config.ff_mult,
+                ),
+            )
+
+    midi_path = _get_midi_path(
+        args.p
+    )  # let user input midi path if not provided

     if args.l and 0 < args.l < model.max_seq_len:
         max_gen_len = args.l
@@ -70,6 +197,7 @@ def sample(args):
         model,
         tokenizer,
         prompts,
+        device=device,
         force_end=force_end,
         max_seq_len=model_config.max_seq_len,
         max_gen_len=max_gen_len,
@@ -124,7 +252,7 @@ def _parse_tokenized_dataset_args():
     argp.add_argument("load_path", help="path midi_dict dataset")
     argp.add_argument("save_path", help="path to save dataset")
     argp.add_argument("-s", help="also produce shuffled", action="store_true")
-    argp.add_argument("-l", help="max sequence length", type=int)
+    argp.add_argument("-l", help="max sequence length", type=int, default=2048)

     return argp.parse_args(sys.argv[2:])
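The new `-q` branch above runs FX graph-mode post-training quantization over each encoder layer's linear submodules (`mixed_qkv`, `att_proj_linear`, `ff_linear_1`, `ff_linear_2`). A minimal, self-contained sketch of the same prepare/calibrate/convert pattern on a plain `nn.Linear`; the layer width and input shape are illustrative, `example_inputs` is passed as a one-element tuple (the form the PyTorch docs describe), and the explicit calibration call is an extra step the patch itself skips:

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx

layer = nn.Linear(512, 512)
example_inp = torch.randn(1, 2048, 512)  # (batch, seq, hidden), illustrative

qconfig_mapping = get_default_qconfig_mapping()
prepared = prepare_fx(layer, qconfig_mapping, example_inputs=(example_inp,))
prepared(example_inp)             # calibration pass so observers record ranges
quantized = convert_fx(prepared)  # int8 weights, CPU quantized kernels

with torch.inference_mode():
    print(quantized(example_inp).shape)  # torch.Size([1, 2048, 512])

In the reworked CLI this path is opted into with `-q`, while `-m`, `-c`, and `-p` may be omitted and recovered via `_get_model_name` and the file-picker helpers above.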
diff --git a/aria/sample.py b/aria/sample.py
index 36026b9..886af41 100644
--- a/aria/sample.py
+++ b/aria/sample.py
@@ -43,13 +43,13 @@ def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):


 # temp=0.85, top_p=0.9, cfg_gamma=1.4
-@torch.autocast(device_type="cuda", dtype=torch.float16)
 def greedy_sample(
     model: TransformerLM,
     tokenizer: Tokenizer,
     prompts: List[list],
     max_seq_len: int,
     max_gen_len: int,
+    device: torch.device | None = None,
     cfg_gamma: float | None = 1.4,
     cfg_mode: str | None = None,
     neg_prompts: List[list] | None = None,
@@ -67,6 +67,7 @@ def greedy_sample(
         prompts (List[list]): A list of prompts to sample as a batch.
         max_seq_len (int): Maximum sequence length supported by the model.
         max_gen_len (int): Maximum desired sequence length of the samples.
+        device (torch.device, optional): Device to use. Defaults to None.
         cfg_gamma (float, optional): CFG gamma parameter. Defaults to 1.2.
            This parameter *determines* whether parameters related to CFG are used.
            None: No CFG or interpolation. `cfg_mode, neg_prompts, neg_prompt_len, alpha` are ignored.
@@ -88,6 +89,7 @@ def greedy_sample(
         List[list]: The list of samples, decoded by the tokenizer.
     """
     assert tokenizer.return_tensors is True, "tokenizer must return tensors."
+    device = device or torch.device("cuda")

     model.eval()
     pad_id = tokenizer.pad_id
@@ -121,14 +123,16 @@ def greedy_sample(
             [
                 torch.concat(
                     [
-                        torch.full((neg_max_len - len(neg_seq),), pad_id),
-                        tokenizer.encode(neg_seq),
+                        torch.full(
+                            (neg_max_len - len(neg_seq),), pad_id, device=device
+                        ),
+                        tokenizer.encode(neg_seq).to(device),
                     ]
                 )
                 for neg_seq in neg_prompts
             ],
             axis=0,
-        ).cuda()
+        )
         neg_len = (
             neg_min_len
             if neg_prompt_len is None
@@ -136,9 +140,11 @@ def greedy_sample(
         )
         neg_tokens = neg_prompt_tensors[:, :neg_len]

-    tokens = torch.full((bsz, total_len), pad_id).cuda()
+    tokens = torch.full((bsz, total_len), pad_id, device=device)
     for idx, unencoded_seq in enumerate(prompts):
-        tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
+        tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to(
+            device
+        )

     dim_tok_inserted = [False for _ in range(bsz)]
     input_text_mask = tokens != pad_id
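The sample.py hunks above replace `.cuda()` calls with an explicit `device` argument, defaulting to CUDA (`device = device or torch.device("cuda")`) and allocating tensors directly on that device so the same code path works on CPU-only machines. A small standalone sketch of the same padded-batch pattern; the helper name and the CPU fallback are illustrative and not part of the patch:

import torch

def make_padded_batch(seqs, pad_id, device=None):
    # Like the patched greedy_sample: allocate on the target device up front
    # instead of calling .cuda() afterwards.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_len = max(len(s) for s in seqs)
    tokens = torch.full((len(seqs), max_len), pad_id, device=device)
    for i, s in enumerate(seqs):
        tokens[i, : len(s)] = torch.tensor(s, device=device)
    return tokens

print(make_padded_batch([[1, 2, 3], [4, 5]], pad_id=0))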