From 42a6eb918627d869e22e4ff99898aa95a46f85dc Mon Sep 17 00:00:00 2001 From: Heng Guo Date: Mon, 30 Dec 2024 10:31:12 +0800 Subject: [PATCH 1/2] vlm 70B+ in single card (#395) --- auto_round/mllm/autoround_mllm.py | 42 ++++++++++++++++++++--------- auto_round/script/mllm.py | 3 ++- auto_round/special_model_handler.py | 10 +++++++ test_cuda/test_support_vlms.py | 9 +++++++ 4 files changed, 51 insertions(+), 13 deletions(-) diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py index 9ba6c414..429e914a 100644 --- a/auto_round/mllm/autoround_mllm.py +++ b/auto_round/mllm/autoround_mllm.py @@ -14,6 +14,7 @@ from typing import Optional, Union from tqdm import tqdm +from copy import deepcopy import torch @@ -24,28 +25,45 @@ to_dtype, get_multimodal_block_names, find_matching_blocks, - extract_block_names_to_str + extract_block_names_to_str, + clear_memory ) from ..autoround import AutoRound from .template import get_template, Template +from auto_round.special_model_handler import SUPPORT_ONLY_TEXT_MODELS from .mllm_dataset import get_mllm_dataloader from ..low_cpu_mem.utils import get_layers_before_block -def _only_text_test(model, tokenizer, device): +def _only_text_test(model, tokenizer, device, model_type): """Test if the model whether can use text-only datasets.""" + + if model_type in SUPPORT_ONLY_TEXT_MODELS: # save time + return True + + new_tokenizer = deepcopy(tokenizer) + device = detect_device(device) + text = ["only text", "test"] + new_tokenizer.padding_side = 'left' + if new_tokenizer.pad_token is None: + new_tokenizer.pad_token = new_tokenizer.eos_token + inputs = new_tokenizer(text, return_tensors="pt", padding=True, truncation=True) + try: - device = detect_device(device) - text = ["only text", "test"] - tokenizer.padding_side = 'left' - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - if device.split(':')[0] != model.device.type: - model = model.to(device) - inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device) + inputs = inputs.to(device) + model = model.to(device) model(**inputs) return True - except: + except RuntimeError as e: + if "CUDA out of memory" in str(e): + model = model.to("cpu") + inputs = inputs.to("cpu") + try: + model(**inputs) + except: + return False + return False + except Exception as e: return False @@ -165,7 +183,7 @@ def __init__( if isinstance(dataset, str): if quant_nontext_module or \ (dataset in CALIB_DATASETS.keys() and not \ - _only_text_test(model, tokenizer, device)): + _only_text_test(model, tokenizer, device, self.template.model_type)): if quant_nontext_module: logger.warning(f"Text only dataset cannot be used for calibrating non-text modules," "switching to liuhaotian/llava_conv_58k") diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py index 634518be..8541691e 100644 --- a/auto_round/script/mllm.py +++ b/auto_round/script/mllm.py @@ -270,7 +270,8 @@ def tune(args): os.environ["CUDA_VISIBLE_DEVICES"] = args.device args.device = ",".join(map(str, range(len(devices)))) devices = args.device.replace(" ", "").split(',') - use_auto_mapping = True + if len(devices) > 1: + use_auto_mapping = True ##for 70B model on single card, use auto will cause some layer offload to cpu elif args.device == "auto": use_auto_mapping == True diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 36aa411a..95471c71 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -18,6 +18,16 @@ mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size skippable_cache_keys = ("past_key_value",) +SUPPORT_ONLY_TEXT_MODELS = [ + "phi3_v", + "cogvlm2", + "llava", + "qwen2_vl", + "deepseek_vl_v2", + "chatglm", + "idefics3" +] + def to_device(input, device=torch.device("cpu")): """Moves input data to the specified device. diff --git a/test_cuda/test_support_vlms.py b/test_cuda/test_support_vlms.py index 81fc3f4a..91e69aa3 100644 --- a/test_cuda/test_support_vlms.py +++ b/test_cuda/test_support_vlms.py @@ -261,6 +261,15 @@ def test_cogvlm(self): response = response.split("<|end_of_text|>")[0] print(response) shutil.rmtree(quantized_model_path, ignore_errors=True) + + def test_72b(self): + model_path = "/data5/models/Qwen2-VL-72B-Instruct/" + res = os.system( + f"cd .. && {self.python_path} -m auto_round --mllm " + f"--model {model_path} --iter 1 --nsamples 1 --bs 1 --output_dir {self.save_dir} --device {self.device}" + ) + self.assertFalse(res > 0 or res == -1, msg="qwen2-72b tuning fail") + shutil.rmtree(self.save_dir, ignore_errors=True) if __name__ == "__main__": unittest.main() \ No newline at end of file From 01b779c01d944d4d3c154c4add40e6c24f4d4c95 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Mon, 30 Dec 2024 14:00:59 +0800 Subject: [PATCH 2/2] enhance calibration dataset and add awq pre quantization warning (#396) --- README.md | 20 --- auto_round/autoround.py | 6 +- auto_round/backend.py | 2 +- auto_round/calib_dataset.py | 316 ++++++++++++++++++++++++++---------- auto_round/script/llm.py | 12 +- auto_round/script/mllm.py | 26 ++- auto_round/utils.py | 50 +++++- docs/step_by_step.md | 25 +-- test/test_calib_dataset.py | 31 +++- 9 files changed, 341 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 097d13e6..81dac94b 100644 --- a/README.md +++ b/README.md @@ -113,26 +113,6 @@ auto-round-fast \ -#### Formats - -**AutoRound Format**: This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision -inference. [2,4] -bits are supported. It also benefits -from the Marlin kernel, which can boost inference performance notably. However, it has not yet gained widespread -community adoption. - -**AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the -community, [2,3,4,8] bits are supported. It also benefits -from the Marlin kernel, which can boost inference performance notably. However, **the -asymmetric kernel has issues** that can cause considerable accuracy drops, particularly at 2-bit quantization and small -models. -Additionally, symmetric quantization tends to perform poorly at 2-bit precision. - -**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely -adopted -within the community, only 4-bits quantization is supported. It features -specialized layer fusion tailored for Llama models. - ### API Usage (Gaudi2/CPU/GPU) ```python diff --git a/auto_round/autoround.py b/auto_round/autoround.py index bd9ef47b..68ff8b1f 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -965,7 +965,7 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp unwrapper_layer(self.model, wrapper_linear, layer_name, best_params) mv_module_from_gpu(layer, self.low_cpu_mem_usage) dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" - logger.info(dump_info) + logger.debug(dump_info) def register_act_max_hook(self, model): def get_act_max_hook(module, input, output): @@ -1045,7 +1045,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " f"layers in the block" ) - logger.info(dump_info) + logger.debug(dump_info) return output, output if self.lr_scheduler is None: @@ -1136,7 +1136,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" ) - logger.info(dump_info) + logger.debug(dump_info) if len(unquantized_layer_names) != 0: logger.info(f"{unquantized_layer_names} have not been quantized") with torch.no_grad(): diff --git a/auto_round/backend.py b/auto_round/backend.py index fe05cd79..a19b7b87 100644 --- a/auto_round/backend.py +++ b/auto_round/backend.py @@ -145,7 +145,7 @@ def check_auto_round_exllamav2_installed(): BackendInfos['awq:gemm'] = BackendInfo(device=["cuda"], sym=[True, False], ##actrally is gemm packing_format="awq", bits=[4], group_size=None, - priority=4, feature_checks=[feature_num_greater_checker_1024], + priority=4, alias=["auto_awq:gemm", "auto_round:awq:gemm", "auto_round:auto_awq:gemm", "awq", "auto_awq", "auto_round:awq", "auto_round:auto_awq"], requirements=["autoawq"] diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index aad25b84..e2376fef 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -16,6 +16,8 @@ import random import torch +from datasets import IterableDataset + torch.use_deterministic_algorithms(True, warn_only=True) from torch.utils.data import DataLoader @@ -43,40 +45,47 @@ def register(dataset): return register -def get_tokenizer_function(tokenizer, seqlen, apply_template=False): +def apply_chat_template_to_samples(samples, tokenizer, seqlen): + from jinja2 import Template + chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \ + else tokenizer.default_chat_template + template = Template(chat_template) + rendered_messages = [] + for text in samples: + message = [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": text}] + rendered_message = template.render(messages=message, add_generation_prompt=True, \ + bos_token=tokenizer.bos_token) + rendered_messages.append(rendered_message) + example = tokenizer(rendered_messages, truncation=True, max_length=seqlen) + return example + + +def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False): """Returns a default tokenizer function. Args: tokenizer: The tokenizer to be used for tokenization. seqlen: The maximum sequence length. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of seqlen to the "text" field of examples. """ - def default_tokenizer_function(examples, apply_template=apply_template): - if not apply_template: + def default_tokenizer_function(examples, apply_chat_template=apply_chat_template): + if not apply_chat_template: example = tokenizer(examples["text"], truncation=True, max_length=seqlen) else: - from jinja2 import Template # pylint: disable=E0401 - chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \ - else tokenizer.default_chat_template - template = Template(chat_template) - rendered_messages = [] - for text in examples["text"]: - message = [{"role": "user", "content": text}] - rendered_message = template.render(messages=message, add_generation_prompt=True, \ - bos_token=tokenizer.bos_token) - rendered_messages.append(rendered_message) - example = tokenizer(rendered_messages, truncation=True, max_length=seqlen) + example = apply_chat_template_to_samples(examples["text"], tokenizer, seqlen) return example return default_tokenizer_function @register_dataset("NeelNanda/pile-10k") -def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, apply_template=False): +def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, + apply_chat_template=False): """Returns a dataloader for the specified dataset and split. Args: @@ -85,7 +94,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. @@ -93,7 +102,7 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split from datasets import load_dataset split = "train" - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) calib_dataset = load_dataset(dataset_name, split=split) calib_dataset = calib_dataset.shuffle(seed=seed) @@ -102,6 +111,83 @@ def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split return calib_dataset +@register_dataset("BAAI/CCI3-HQ") +def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + from datasets import load_dataset + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) + + calib_dataset = load_dataset(dataset_name, split='train', streaming=True) + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("codeparrot/github-code-clean") +def get_github_code_clean_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42, + apply_chat_template=False): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + + def get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=False): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length + of seqlen to the "code" field of examples. + """ + + def default_tokenizer_function(examples, apply_chat_template=apply_chat_template): + if not apply_chat_template: + example = tokenizer(examples["code"], truncation=True, max_length=seqlen) + else: + example = apply_chat_template_to_samples(examples["code"], tokenizer, seqlen) + return example + + return default_tokenizer_function + + from datasets import load_dataset + + tokenizer_function = get_default_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) + + calib_dataset = load_dataset(dataset_name, split='train', streaming=True) + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + @register_dataset("madao33/new-title-chinese") def get_new_chinese_title_dataset( tokenizer, @@ -109,7 +195,7 @@ def get_new_chinese_title_dataset( dataset_name="madao33/new-title-chinese", split=None, seed=42, - apply_template=False + apply_chat_template=False ): """Returns a dataloader for the specified dataset and split. @@ -119,39 +205,29 @@ def get_new_chinese_title_dataset( data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. """ - def get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template): + def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template): """Returns a default tokenizer function. Args: tokenizer: The tokenizer to be used for tokenization. seqlen: The maximum sequence length. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of seqlen to the "text" field of examples. """ - def default_tokenizer_function(examples, apply_template=apply_template): - if not apply_template: + def default_tokenizer_function(examples, apply_chat_template=apply_chat_template): + if not apply_chat_template: example = tokenizer(examples["content"], truncation=True, max_length=seqlen) else: - from jinja2 import Template - chat_template = tokenizer.chat_template if tokenizer.chat_template is not None \ - else tokenizer.default_chat_template - template = Template(chat_template) - rendered_messages = [] - for text in examples["text"]: - message = [{"role": "user", "content": text}] - rendered_message = template.render(messages=message, add_generation_prompt=True, \ - bos_token=tokenizer.bos_token) - rendered_messages.append(rendered_message) - example = tokenizer(rendered_messages, truncation=True, max_length=seqlen) + example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen) return example return default_tokenizer_function @@ -159,7 +235,7 @@ def default_tokenizer_function(examples, apply_template=apply_template): split = "train" from datasets import load_dataset - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) calib_dataset = load_dataset(dataset_name, split=split) calib_dataset = calib_dataset.shuffle(seed=seed) @@ -169,7 +245,7 @@ def default_tokenizer_function(examples, apply_template=apply_template): @register_dataset("mbpp") -def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_template=False): +def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_chat_template=False): """Returns a dataloader for the specified dataset and split. Args: @@ -178,14 +254,14 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42 data_name: The name of the dataset. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. """ from datasets import load_dataset - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) samples = [] splits = split @@ -208,34 +284,37 @@ def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42 @register_dataset("local") -def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_template=False): +def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_chat_template=False): """Returns a dataloader for a custom dataset and split. We allow the input of a json or text file containing a processed text sample each line. Args: tokenizer: The tokenizer to be used for tokenization. seqlen: The maximum sequence length. - data_name: The name or path of the dataset, which is a jsonl file. + data_name: The name or path of the dataset, which is a json or jsonl file. split: The data split to be used (e.g., "train", "test"). seed: The random seed for shuffling the dataset. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: A dataloader for a custom dataset and split, using the provided tokenizer and sequence length. """ - tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_template=apply_template) + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template) def load_local_data(data_path): if data_path.endswith(".json"): with open(data_path, "r") as f: data = json.load(f) return data - elif data_path.endswith(".txt"): + elif data_path.endswith(".jsonl"): + data = [] with open(data_path) as f: - data = [line for line in f] + for line in f: + sample = json.loads(line) + data.append(sample) return data else: - logger.error("invalid local file type, for now only support json format data file.") + logger.error("invalid local file type, for now only support json/jsonl format data file.") samples = [] dataset = load_local_data(dataset_name) @@ -267,6 +346,67 @@ def load_local_data(data_path): return calib_dataset +def get_dataset_len(dataset): + """Calculates the length of a dataset. + + Args: + dataset: The dataset object, which can be any iterable or collection. + + Returns: + int: The length of the dataset. + + Raises: + If the dataset does not support `len()`, iterates through it to count the number of elements. + """ + try: + dataset_len = len(dataset) + return dataset_len + except: + cnt = 0 + for _ in dataset: + cnt += 1 + return cnt + + +def select(dataset, indices): + """Selects specific elements from a dataset based on given indices. + + Args: + dataset: The dataset object to iterate over. + indices: An iterable of integers specifying the indices to select. + + Yields: + Elements of the dataset corresponding to the specified indices. + + Notes: + Stops iterating once the highest index in `indices` has been processed + to optimize performance. + """ + indices = set(indices) + for idx, sample in enumerate(dataset): + if idx in indices: + yield sample + if idx > max(indices): + break + + +def select_dataset(dataset, indices): + """Selects elements from a dataset using its native `select` method, if available. + + Args: + dataset: The dataset object, which may have a `select` method. + indices: An iterable of integers specifying the indices to select. + + Returns: + A subset of the dataset, either using the dataset's `select` method or the + `select` function defined above as a fallback. + """ + try: + return dataset.select(indices) + except: + return select(dataset, indices) + + def get_dataloader( tokenizer, seqlen, @@ -287,7 +427,7 @@ def get_dataloader( seed (int, optional): The random seed for reproducibility. Defaults to 42. bs (int, optional): The batch size. Defaults to 4. nsamples (int, optional): The total number of samples to include. Defaults to 512. - apply_template: Whether to apply chat template in tokenization. + apply_chat_template: Whether to apply chat template in tokenization. Returns: DataLoader: The DataLoader for the calibrated dataset. @@ -302,7 +442,7 @@ def filter_func(example): return False input_ids = example["input_ids"][:seqlen] input_ids_list = input_ids.tolist() - if input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + if len(input_ids_list) > 1 and input_ids_list.count(input_ids_list[-1]) > seqlen // 2: return False return True @@ -353,7 +493,7 @@ def concat_dataset_element(dataset): for name in dataset_names: split = None do_concat = False - apply_template = False + apply_chat_template = False if ":" in name: split_list = name.split(":") name, split_list = name.split(":")[0], name.split(":")[1:] @@ -365,8 +505,8 @@ def concat_dataset_element(dataset): data_lens[name] = int(values[0]) if key == "concat": do_concat = False if (len(values) > 0 and values[0].lower() == 'false') else True - if key == "apply_template": - apply_template = False if (len(values) > 0 and values[0].lower() == 'false') else True + if key == "apply_chat_template": + apply_chat_template = False if (len(values) > 0 and values[0].lower() == 'false') else True if is_local_path(name): get_dataset = CALIB_DATASETS.get("local") else: @@ -384,44 +524,56 @@ def concat_dataset_element(dataset): seed=seed, split=split, dataset_name=name, - apply_template=apply_template, + apply_chat_template=apply_chat_template, ) - dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + if not isinstance(dataset, IterableDataset): + dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) if do_concat: dataset = concat_dataset_element(dataset) dataset = dataset.filter(filter_func) if name in data_lens: - dataset = dataset.select(range(data_lens[name])) + dataset = select_dataset(dataset, range(data_lens[name])) datasets.append(dataset) - indices = range(len(datasets)) - res = sorted(zip(indices, datasets), key=lambda x: len(x[1])) - indices = [item[0] for item in res] - datasets = [item[1] for item in res] - dataset_names = [dataset_names[index] for index in indices] - cnt = 0 if not data_lens else sum(data_lens.values()) - dataset_cnt_info = {} - if cnt > nsamples: - cnt = 0 - - for i in range(len(datasets)): - name = dataset_names[i].split(':')[0] - if name not in data_lens: - target_cnt = (nsamples - cnt) // (len(datasets) - len(data_lens)) if data_lens \ - else (nsamples - cnt) // (len(datasets) - i) - target_cnt = min(target_cnt, len(datasets[i])) - cnt += target_cnt - else: - target_cnt = data_lens[name] - datasets[i] = datasets[i].select(range(target_cnt)) - dataset_cnt_info[name] = target_cnt - if len(datasets) > 1: - from datasets import concatenate_datasets - - dataset_final = concatenate_datasets(datasets) - dataset_final = dataset_final.shuffle(seed=seed) - logger.info(dataset_cnt_info) - else: + if len(datasets) == 1: dataset_final = datasets[0] + else: + indices = range(len(datasets)) + lens = [] + for i in range(len(datasets)): + cnt = get_dataset_len(datasets[i]) + lens.append(cnt) + res = sorted(zip(indices, lens), key=lambda x: x[1]) + + # res = sorted(zip(indices, datasets), key=lambda x: len(x[1])) + indices = [item[0] for item in res] + datasets = [datasets[item[0]] for item in res] + dataset_names = [dataset_names[index] for index in indices] + cnt = 0 if not data_lens else sum(data_lens.values()) + dataset_cnt_info = {} + if cnt > nsamples: + cnt = 0 + + for i in range(len(datasets)): + name = dataset_names[i].split(':')[0] + if name not in data_lens: + target_cnt = (nsamples - cnt) // (len(datasets) - len(data_lens)) if data_lens \ + else (nsamples - cnt) // (len(datasets) - i) + target_cnt = min(target_cnt, lens[i]) + cnt += target_cnt + else: + target_cnt = data_lens[name] + datasets[i] = select_dataset(dataset, range(target_cnt)) + dataset_cnt_info[name] = target_cnt + if len(datasets) > 1: + from datasets import concatenate_datasets + + dataset_final = concatenate_datasets(datasets) + dataset_final = dataset_final.shuffle(seed=seed) + logger.info(dataset_cnt_info) + else: + dataset_final = datasets[0] + + # dataset_final = datasets[0] @torch.no_grad() def collate_batch(batch): diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py index 56456446..4d13ffb3 100644 --- a/auto_round/script/llm.py +++ b/auto_round/script/llm.py @@ -170,7 +170,7 @@ def __init__(self, *args, **kwargs): "set --device 0,1,2 to use multiple cards.") self.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," \ - "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge", + "openbookqa,boolq,arc_easy,arc_challenge", help="lm-eval tasks") self.add_argument("--disable_trust_remote_code", action='store_true', help="whether to disable trust_remote_code") @@ -271,6 +271,8 @@ def tune(args): if "marlin" in args.format and args.asym is True: assert False, "marlin backend only supports sym quantization, please remove --asym" + + ##must set this before import torch import os devices = args.device.replace(" ", "").split(',') @@ -416,6 +418,7 @@ def tune(args): ##TODO gptq could support some mixed precision config logger.warning(f"mixed precision exporting does not support {format} currently") + lm_head_layer_name = "lm_head" for n, _ in model.named_modules(): lm_head_layer_name = n @@ -430,6 +433,7 @@ def tune(args): f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " f"supported currently") break + if args.quant_lm_head: layer_config[lm_head_layer_name] = {"bits": args.bits} for format in formats: @@ -438,6 +442,12 @@ def tune(args): raise ValueError( f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}") + if "auto_awq" in args.format: + from auto_round.utils import check_awq_gemm_compatibility + awq_supported, info = check_awq_gemm_compatibility(model,args.bits,args.group_size, not args.asym, layer_config) + if not awq_supported: + logger.warning(f"The AutoAWQ format may not be supported due to {info}") + autoround = round( model, tokenizer, args.bits, args.group_size, sym=not args.asym, batch_size=args.batch_size, dataset=args.dataset, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py index 8541691e..7036ec5b 100644 --- a/auto_round/script/mllm.py +++ b/auto_round/script/mllm.py @@ -46,10 +46,10 @@ def __init__(self, *args, **kwargs): self.add_argument("--device", "--devices", default="0", type=str, help="the device to be used for tuning. " - "Currently, device settings support CPU, GPU, and HPU." - "The default is set to cuda:0," - "allowing for automatic detection and switch to HPU or CPU." - "set --device 0,1,2 to use multiple cards.") + "Currently, device settings support CPU, GPU, and HPU." + "The default is set to cuda:0," + "allowing for automatic detection and switch to HPU or CPU." + "set --device 0,1,2 to use multiple cards.") self.add_argument("--asym", action='store_true', help="whether to use asym quantization") @@ -163,8 +163,6 @@ def __init__(self, *args, **kwargs): self.add_argument("--to_quant_block_names", default=None, type=str, help="Names of quantitative blocks, please use commas to separate them.") - - def setup_parser(): parser = BasicArgumentParser() @@ -191,24 +189,24 @@ def setup_parser(): def setup_lmeval_parser(): parser = argparse.ArgumentParser() parser.add_argument("--model", "--model_name", "--model_name_or_path", - help="model name or path") + help="model name or path") parser.add_argument("--tasks", type=str, default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE", help="eval tasks for VLMEvalKit.") # Args that only apply to Video Dataset parser.add_argument("--nframe", type=int, default=8, help="the number of frames to sample from a video," - " only applicable to the evaluation of video benchmarks.") + " only applicable to the evaluation of video benchmarks.") parser.add_argument("--pack", action='store_true', help="a video may associate with multiple questions, if pack==True," - " will ask all questions for a video in a single") + " will ask all questions for a video in a single") parser.add_argument("--fps", type=float, default=-1, help="set the fps for a video.") # Work Dir # Infer + Eval or Infer Only parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'], help="when mode set to 'all', will perform both inference and evaluation;" - " when set to 'infer' will only perform the inference.") + " when set to 'infer' will only perform the inference.") parser.add_argument('--eval_data_dir', type=str, default=None, help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData') # API Kwargs, Apply to API VLMs and Judge API LLMs @@ -227,7 +225,7 @@ def setup_lmeval_parser(): parser.add_argument('--rerun', action='store_true', help="if true, will remove all evaluation temp files and rerun.") parser.add_argument("--output_dir", default="./eval_result", type=str, - help="the directory to save quantized model") + help="the directory to save quantized model") args = parser.parse_args() return args @@ -285,7 +283,7 @@ def tune(args): processor, image_processor = None, None config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code) if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration": - from llava.model.builder import load_pretrained_model # pylint: disable=E0401 + from llava.model.builder import load_pretrained_model # pylint: disable=E0401 tokenizer, model, image_processor, _ = load_pretrained_model( model_name, model_base=None, model_name=model_name, torch_dtype=torch_dtype) @@ -394,7 +392,6 @@ def tune(args): if "--truncation" not in sys.argv: args.truncation = None - autoround = round(model, tokenizer, processor=processor, image_processor=image_processor, dataset=args.dataset, extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size, @@ -459,7 +456,7 @@ def setup_lmms_parser(): help="To get full list of tasks, use the command lmms-eval --tasks list", ) parser.add_argument("--output_dir", default="./eval_result", type=str, - help="the directory to save quantized model") + help="the directory to save quantized model") parser.add_argument( "--num_fewshot", type=int, @@ -515,4 +512,3 @@ def lmms_eval(args): apply_chat_template=False, ) return results - diff --git a/auto_round/utils.py b/auto_round/utils.py index 92f0c0a1..94e73b28 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -318,6 +318,7 @@ def validate_modules(module_names, quant_vision=False, vison_blocks_names=None): "or raise an issue at https://github.com/intel/auto-round/issues.") return + def get_common_prefix(paths): # Split each path into components and find the common prefix split_paths = [path.split('.') for path in paths] @@ -326,8 +327,9 @@ def get_common_prefix(paths): common_prefix = [comp for comp, other in zip(common_prefix, path) if comp == other] return '.'.join(common_prefix) + def extract_block_names_to_str(quant_block_list): - if not isinstance(quant_block_list, (list,tuple)): + if not isinstance(quant_block_list, (list, tuple)): return None # Extract common prefix for each list prefixes = [get_common_prefix(blocks) for blocks in quant_block_list] @@ -365,7 +367,7 @@ def find_matching_blocks(model, all_blocks, to_quant_block_names): target_blocks.append(matched_sublist) if not target_blocks: raise ValueError("No block names matched. Please check the input for to_quant_block_name," \ - "or set to_quant_block_name to None to automatically match quantizable blocks.") + "or set to_quant_block_name to None to automatically match quantizable blocks.") return target_blocks @@ -966,6 +968,7 @@ def torch_version_at_least(version_string): TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") + # Note on HPU usage: # There are two modes available for enabling auto-round on HPU: # 1. Compile Mode @@ -977,11 +980,11 @@ def torch_version_at_least(version_string): def _check_hpu_compile_mode(): assert ( - os.getenv("PT_HPU_LAZY_MODE") == "0" + os.getenv("PT_HPU_LAZY_MODE") == "0" ), "Please set `PT_HPU_LAZY_MODE=0` to use HPU compile mode" # Note: this is a temporary solution, will be removed in the future assert ( - os.getenv("PT_ENABLE_INT64_SUPPORT") == "1" + os.getenv("PT_ENABLE_INT64_SUPPORT") == "1" ), "Please set `PT_ENABLE_INT64_SUPPORT=1` to use HPU compile mode" @@ -1107,9 +1110,46 @@ def get_fp_layer_names(model, fp_layers): not_to_quantized_layers.append(fp_layer) continue if fp_layer[-1].isdigit(): - fp_layer = fp_layer + "." ##ticky setting + fp_layer = fp_layer + "." ##tricky setting for name in all_layer_names: if fp_layer in name: not_to_quantized_layers.append(name) return not_to_quantized_layers + + +def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=None): + """Checks if a model is compatible with the AutoAWQ GEMM kernel. + + Args: + model: The model object to evaluate, typically a PyTorch model. + bits (int): The number of bits for quantization (must be 4 for compatibility). + group_size (int): The group size for quantization. + sym (bool): Whether symmetric quantization is used (not utilized in the current function logic). + layer_configs (dict, optional): A dictionary mapping layer names to configurations, where each + configuration can specify a custom number of bits for the layer. + + Returns: + tuple: A tuple containing: + - bool: `True` if the model is compatible, `False` otherwise. + - str: An error message describing why the model is incompatible, or an empty string if compatible. + """ + if bits != 4: + return False, f"AutoAWQ GEMM kernel only supports 4 bits" + for n, m in model.named_modules(): + if isinstance(m, transformers.modeling_utils.Conv1D): + return False, "AutoAWQ GEMM kernel does not support conv1d" + + layer_names = get_layer_names_in_block(model) + for layer_name in layer_names: + if layer_configs is not None and layer_name in layer_configs.keys() and layer_configs[layer_name].get("bits", + bits) > 8: + continue + + layer = get_module(model, layer_name) + if layer.in_features % group_size != 0: + return False, f"Layer {layer_name} in_features is not multiple of group_size {group_size}" + if layer.out_features % (32 // bits) != 0: + return False, f"Layer {layer_name} out_features is not multiple of 32 // bits" + + return True, "" diff --git a/docs/step_by_step.md b/docs/step_by_step.md index e66a1ffb..1d8fbd87 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -45,16 +45,19 @@ See more about loading [huggingface dataset](https://huggingface.co/docs/dataset tokens.append(token) return tokens ~~~ - -We support combination of different datasets and parametrization of calibration datasets by using "--dataset ./tmp.json: -concat,NeelNanda/pile-10k:split=train+val:num=256,mbpp:concat=True:num=128:apply_template". Both local calibration file -and huggingface dataset are supported. Through parametrization, users could specify splits of a dataset by setting " -split=split1+split2". A concatenation option could enable users to merge calibration samples, a process commonly used to -enhance calibration reliability. An 'apply_template' option would enable users to apply chat_template to calibration -data before tokenization and is widely used by instruct-models in generation. Please note that samples shorter than -args.seqlen will be dropped when concatenation option is not enabled. -Please use ',' to split datasets, ':' to split parameters of a dataset and '+' to add values for one targeted parameter. - + **Dataset combination**:We support combination of different datasets and parametrization of calibration datasets by using "--dataset ./tmp.json: + concat,NeelNanda/pile-10k:split=train+val:num=256,mbpp:concat=True:num=128:apply_chat_template". Both local calibration file + and huggingface dataset are supported. Through parametrization, users could specify splits of a dataset by setting " + split=split1+split2". + + **Samples concatenation**: A concatenation option could enable users to merge calibration samples. '--dataset NeelNanda/pile-10k:concat=True' + + **Apply chat template**: '--dataset NeelNanda/pile-10k:apply_chat_template' would enable users to apply chat_template to calibration + data before tokenization and is widely used by instruct-models in generation. Please note that samples shorter than + args.seqlen will be dropped when concatenation option is not enabled. + + Please use ',' to split datasets, ':' to split parameters of a dataset and '+' to add values for one targeted parameter. +
@@ -128,7 +131,7 @@ Please use ',' to split datasets, ':' to split parameters of a dataset and '+' t - To leverage auto-gptq marlin kernel, you need to install auto-gptq from source and export the model without sharding. ```bash - auto-round --model facebook/opt-125m --sym --bits 4 --group_size 128 --format "gptq:marlin" + auto-round --model facebook/opt-125m --sym --bits 4 --group_size 128 --format "auto_gptq:marlin" ``` - **Utilize the AdamW Optimizer:** diff --git a/test/test_calib_dataset.py b/test/test_calib_dataset.py index 2e7b887e..291d535c 100644 --- a/test/test_calib_dataset.py +++ b/test/test_calib_dataset.py @@ -30,11 +30,13 @@ def setUpClass(self): with open(self.json_file, "w") as json_file: json.dump(json_data, json_file, indent=4) - self.text_file = "./saved/tmp.txt" - txt_data = ["awefdsfsddfd", "fdfdfsdfdfdfd", "dfdsfsdfdfdfdf"] - with open(self.text_file, "w") as text_file: - for data in txt_data: - text_file.write(data + "\n") + jsonl_data = [{"text": "哈哈,開心點"}, {"text": "hello world"}] + os.makedirs("./saved", exist_ok=True) + self.jsonl_file = "./saved/tmp.jsonl" + with open(self.jsonl_file, "w") as jsonl_file: + for item in jsonl_data: + json.dump(item, jsonl_file, ensure_ascii=False) + jsonl_file.write('\n') model_name = "facebook/opt-125m" self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) @@ -54,7 +56,7 @@ def test_json(self): ) autoround.quantize() - def test_txt(self): + def test_jsonl(self): bits, group_size, sym = 4, 128, True autoround = AutoRound( self.model, @@ -63,13 +65,24 @@ def test_txt(self): group_size=group_size, sym=sym, iters=2, - seqlen=5, - dataset=self.text_file, + seqlen=4, + dataset=self.jsonl_file, + ) + autoround.quantize() + + def test_apply_chat_template(self): + model_name = "Qwen/Qwen2.5-0.5B-Instruct" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + dataset = "NeelNanda/pile-10k:apply_chat_template" + bits, group_size, sym = 4, 128, True + autoround = AutoRound( + model, tokenizer, bits=bits, group_size=group_size, sym=sym, iters=2, seqlen=128, dataset=dataset ) autoround.quantize() def test_combine_dataset(self): - dataset = self.text_file + "," + "NeelNanda/pile-10k" + "," + "madao33/new-title-chinese" + "," + "mbpp" + dataset = "NeelNanda/pile-10k" + "," + "madao33/new-title-chinese" + "," + "mbpp" bits, group_size, sym = 4, 128, True autoround = AutoRound( self.model, self.tokenizer, bits=bits, group_size=group_size, sym=sym, iters=2, seqlen=128, dataset=dataset