Changing sample CLI design; Refactor devices; Trying out quantization #70

Merged: 18 commits, merged Nov 30, 2023
aria/model/model.py (3 changes: 1 addition, 2 deletions)
@@ -311,7 +311,7 @@ def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
# remove torch.compile from the train script as this is not currently
# supported.
# Implements gradient checkpoints on Encoder Layers.
if self.model_config.grad_checkpoint is True:
if self.model_config.grad_checkpoint is True and not use_cache:
for layer in self.encode_layers:

def create_custom_forward(module):
@@ -326,7 +326,6 @@ def custom_forward(*args):
preserve_rng_state=True,
use_reentrant=True,
)

else:
new_past_kv = []
past_kv = (
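
The hunk above stops checkpointing encoder layers once the KV cache is requested. A minimal sketch of the reasoning, not taken from this repository (the Linear layer and shapes are made up): activation checkpointing only buys memory back during backward, so the cached-inference path simply calls the layer directly.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

y_train = checkpoint(layer, x, use_reentrant=True)  # training: activations recomputed in backward
with torch.no_grad():
    y_infer = layer(x)  # use_cache / inference path: no gradients, so checkpointing is skipped
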
aria/run.py (154 changes: 141 additions, 13 deletions)
@@ -2,26 +2,93 @@

import argparse
import os
import re
import sys
import pathlib
import warnings


def _parse_sample_args():
argp = argparse.ArgumentParser(prog="aria sample")
argp.add_argument("model", help="name of model config file")
argp.add_argument("ckpt_path", help="path to model checkpoint")
argp.add_argument("midi_path", help="path to midi file")
argp.add_argument("-m", help="name of model config file")
argp.add_argument("-c", help="path to model checkpoint")
argp.add_argument("-p", help="path to midi file")
argp.add_argument(
"-var", help="number of variations", type=int, required=True
"-var",
help="number of variations",
type=int,
default=1,
)
argp.add_argument(
"-trunc", help="length to truncated prompt", type=int, required=True
"-trunc",
help="length to truncated prompt",
type=int,
default=200,
)
argp.add_argument("-e", action="store_true", help="enable force end")
argp.add_argument("-l", type=int, help="generation length")
argp.add_argument("-l", type=int, help="generation length", default=1024)
argp.add_argument("-q", action="store_true", help="quantize the model")

return argp.parse_args(sys.argv[2:])
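
As an illustration of the reworked, flag-based CLI, a minimal sketch (not part of this diff; it assumes _parse_sample_args is in scope and that the script is dispatched as `run.py sample ...`, which is why the parser reads sys.argv[2:]):

import sys

sys.argv = ["run.py", "sample", "-c", "model.bin", "-p", "prompt.mid", "-q"]
args = _parse_sample_args()
assert args.var == 1 and args.trunc == 200 and args.l == 1024  # new defaults
assert args.m is None  # model name omitted, to be inferred from the checkpoint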


def _get_model_name(name: str | None, state: dict):
if name is not None:
return name

print("Model name is not provided. Trying to infer from checkpoint...")
_defaults = {
16: "small",
32: "medium",
64: "large",
96: "xlarge",
}
try:
pattern = re.compile(r"encode_layers\.(\d+)\.")
layer_keys = [pattern.search(k) for k in state.keys()]
layer_keys = set(p.group(1) for p in layer_keys if p is not None)
for i in range(len(layer_keys)):
assert str(i) in layer_keys

if len(layer_keys) in _defaults:
print(f"Selecting model name: {_defaults[len(layer_keys)]}")
return _defaults[len(layer_keys)]
assert False
except:
raise ValueError("Model name is not provided and cannot be inferred.")
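
A quick sketch of the inference above (hypothetical key names; only the encode_layers.<n>. prefix matters): a checkpoint whose keys span 32 encoder layers resolves to the "medium" config.

fake_state = {f"encode_layers.{i}.ff_linear_1.weight": None for i in range(32)}
assert _get_model_name(None, fake_state) == "medium"  # 32 layers -> "medium"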


def _show_popup(prompt: str, files: list) -> str:
for i in range(len(files)):
print(f" [{i}] {files[i]}")

for tries in range(3): # 3 tries in case of fat fingers
try:
res = int(input(prompt + f" [0-{len(files) - 1}]: "))
assert 0 <= res < len(files)
return files[res]
except:
print("Invalid input. Try again...")

raise ValueError("Invalid input.")


def _get_ckpt_path(ckpt_path: str | None) -> str:
if ckpt_path is None:
ckpts = list(pathlib.Path(".").glob("*.bin"))
ckpt_path = _show_popup("Choose a checkpoint", ckpts)
return ckpt_path


def _get_midi_path(midi_path: str | None) -> str:
if midi_path is None:
midis = list(pathlib.Path(".").glob("*.mid")) + list(
pathlib.Path(".").glob("*.midi")
)
midi_path = _show_popup("Choose a midi-file", midis)
return midi_path


def sample(args):
"""Entrypoint for sampling"""

@@ -34,20 +101,80 @@ def sample(args):
from aria.data.midi import MidiDict
from aria.utils import midi_to_audio

assert cuda_is_available() is True, "CUDA device not available"
if not cuda_is_available():
print("CUDA device is not available. Using CPU instead.")
else:
greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)(
greedy_sample
)
device = (
torch.device("cuda") if cuda_is_available() else torch.device("cpu")
)

ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided
model_state = torch.load(ckpt_path, map_location=device)
model_name = _get_model_name(
args.m, model_state
) # infer model name if not provided

model_name = args.model
ckpt_path = args.ckpt_path
midi_path = args.midi_path
num_variations = args.var
truncate_len = args.trunc
force_end = args.e

tokenizer = TokenizerLazy(return_tensors=True)
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config).cuda()
model.load_state_dict(torch.load(ckpt_path))
model = TransformerLM(model_config).to(device)
model.load_state_dict(model_state)
if args.q:
if device.type != "cpu":
warnings.warn(
"Quantization is not supported on CUDA devices. Using CPU instead."
)
device = torch.device("cpu")

from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx

qconfig_mapping = get_default_qconfig_mapping()

def _quantize(module, key, input_shape):
inp = torch.randn(input_shape, dtype=torch.float, device=device)
m = prepare_fx(
getattr(module, key), qconfig_mapping, example_inputs=inp
)
m = convert_fx(m)
setattr(module, key, m)

for i in range(len(model.model.encode_layers)):
_quantize(
model.model.encode_layers[i],
"mixed_qkv",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"att_proj_linear",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"ff_linear_1",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"ff_linear_2",
input_shape=(
1,
2048,
model_config.n_heads * model_config.ff_mult,
),
)
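
For reference, the prepare_fx / convert_fx pattern used per-submodule above looks like this on a standalone nn.Linear. This is a sketch, not code from the PR: the 512-wide layer, the example shapes, and the single calibration pass are all made up.

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

layer = nn.Linear(512, 512).eval()
example_inputs = (torch.randn(1, 2048, 512),)
prepared = prepare_fx(layer, get_default_qconfig_mapping(), example_inputs)
prepared(torch.randn(1, 2048, 512))  # calibration pass so observers record activation ranges
quantized = convert_fx(prepared)     # swaps in an int8 quantized linear (CPU backends)
out = quantized(torch.randn(1, 2048, 512))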

midi_path = _get_midi_path(
args.p
) # let user input midi path if not provided

if args.l and 0 < args.l < model.max_seq_len:
max_gen_len = args.l
@@ -70,6 +197,7 @@ def _parse_tokenized_dataset_args():
model,
tokenizer,
prompts,
device=device,
force_end=force_end,
max_seq_len=model_config.max_seq_len,
max_gen_len=max_gen_len,
@@ -124,7 +252,7 @@ def _parse_tokenized_dataset_args():
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_path", help="path to save dataset")
argp.add_argument("-s", help="also produce shuffled", action="store_true")
argp.add_argument("-l", help="max sequence length", type=int)
argp.add_argument("-l", help="max sequence length", type=int, default=2048)

return argp.parse_args(sys.argv[2:])

aria/sample.py (18 changes: 12 additions, 6 deletions)
@@ -43,13 +43,13 @@ def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
# temp=0.85, top_p=0.9, cfg_gamma=1.4


@torch.autocast(device_type="cuda", dtype=torch.float16)
def greedy_sample(
model: TransformerLM,
tokenizer: Tokenizer,
prompts: List[list],
max_seq_len: int,
max_gen_len: int,
device: torch.device | None = None,
cfg_gamma: float | None = 1.4,
cfg_mode: str | None = None,
neg_prompts: List[list] | None = None,
@@ -67,6 +67,7 @@
prompts (List[list]): A list of prompts to sample as a batch.
max_seq_len (int): Maximum sequence length supported by the model.
max_gen_len (int): Maximum desired sequence length of the samples.
device (torch.device, optional): Device to use. Defaults to None.
cfg_gamma (float, optional): CFG gamma parameter. Defaults to 1.2.
This parameter *determines* whether parameters related to CFG are used.
None: No CFG or interpolation. `cfg_mode, neg_prompts, neg_prompt_len, alpha` are ignored.
@@ -88,6 +89,7 @@
List[list]: The list of samples, decoded by the tokenizer.
"""
assert tokenizer.return_tensors is True, "tokenizer must return tensors."
device = device or torch.device("cuda")
model.eval()

pad_id = tokenizer.pad_id
@@ -121,24 +123,28 @@
[
torch.concat(
[
torch.full((neg_max_len - len(neg_seq),), pad_id),
tokenizer.encode(neg_seq),
torch.full(
(neg_max_len - len(neg_seq),), pad_id, device=device
),
tokenizer.encode(neg_seq).to(device),
]
)
for neg_seq in neg_prompts
],
axis=0,
).cuda()
)
neg_len = (
neg_min_len
if neg_prompt_len is None
else min(neg_min_len, neg_prompt_len)
)
neg_tokens = neg_prompt_tensors[:, :neg_len]

tokens = torch.full((bsz, total_len), pad_id).cuda()
tokens = torch.full((bsz, total_len), pad_id, device=device)
for idx, unencoded_seq in enumerate(prompts):
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to(
device
)

dim_tok_inserted = [False for _ in range(bsz)]
input_text_mask = tokens != pad_id
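
A hypothetical call with the new device argument, mirroring the call site in aria/run.py above; model, tokenizer, prompts, and model_config are assumed to be set up as in that script, and max_gen_len here is arbitrary.

import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
results = greedy_sample(
    model,
    tokenizer,
    prompts,
    max_seq_len=model_config.max_seq_len,
    max_gen_len=512,
    device=device,  # tensors are now built on this device instead of hard-coded .cuda()
    force_end=False,
)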