diff --git a/.azure-pipelines/scripts/ut/run_ut_hpu.sh b/.azure-pipelines/scripts/ut/run_ut_hpu.sh
index 77619df9..750562c2 100644
--- a/.azure-pipelines/scripts/ut/run_ut_hpu.sh
+++ b/.azure-pipelines/scripts/ut/run_ut_hpu.sh
@@ -18,9 +18,14 @@ LOG_DIR=/auto-round/log_dir
 mkdir -p ${LOG_DIR}
 ut_log_name=${LOG_DIR}/ut.log
 
-find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run.sh
-cat run.sh
-bash run.sh 2>&1 | tee ${ut_log_name}
+find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_lazy.sh
+find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --mode compile --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_compile.sh
+
+cat run_lazy.sh
+bash run_lazy.sh 2>&1 | tee ${ut_log_name}
+
+cat run_compile.sh
+bash run_compile.sh 2>&1 | tee -a ${ut_log_name}
 
 cp report.html ${LOG_DIR}/
 cp coverage.xml ${LOG_DIR}/
diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index 7785b812..6bcebcde 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -53,9 +53,6 @@ def run_lmms():
     lmms_eval(args)
 
 def switch():
-    # if "--lmms" in sys.argv:
-    #     sys.argv.remove("--lmms")
-    #     run_lmms()
     if "--mllm" in sys.argv:
         sys.argv.remove("--mllm")
         run_mllm()
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 1371dc20..1bfc8288 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -1263,6 +1263,9 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
             self.model.save_pretrained(output_dir)
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(output_dir)
+            processor = kwargs.get("processor", None)
+            if processor is not None:
+                processor.save_pretrained(output_dir)
             return
 
         from auto_round.export import EXPORT_FORMAT
diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index 05034576..ad2cf5dc 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -88,8 +88,8 @@ def pack_layer(name, model, layer_config, backend, pbar):
         in_features = layer.weight.shape[0]
         out_features = layer.weight.shape[1]
 
-    ##bias = layer.bias is not None and torch.any(layer.bias)
-    bias = True ## if using the above, llama3 lambada RTN will be NAN , TODO why?
+    bias = layer.bias is not None
+    ##bias = True  ## if using the above, llama3 lambada RTN will be NaN, TODO why?
     new_layer = QuantLinear(  ##pylint: disable=E1123
         bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
     )
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index f90fb270..ee36b4eb 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -110,7 +110,7 @@ def pack_layer(name, model, layer_config, backend, pbar):
     elif isinstance(layer, transformers.pytorch_utils.Conv1D):
         in_features = layer.weight.shape[0]
         out_features = layer.weight.shape[1]
-    bias = layer.bias is not None and torch.any(layer.bias)
+    bias = layer.bias is not None
 
     if "awq" not in backend:
         new_layer = QuantLinear(  ##pylint: disable=E1123
diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index fa07bd77..ccd1bc7a 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -38,7 +38,7 @@ def _only_text_test(model, tokenizer, device):
         tokenizer.padding_side = 'left'
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-        if device != model.device.type:
+        if device.split(':')[0] != model.device.type:
             model = model.to(device)
         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
         model(**inputs)
@@ -150,19 +150,20 @@ def __init__(
         self.to_quant_block_names = to_quant_block_names
         self.extra_data_dir = extra_data_dir
         self.quant_nontext_module = quant_nontext_module
+        self.processor = processor
         self.image_processor = image_processor
         self.template = template if template is not None else model.config.model_type
         if not isinstance(dataset, torch.utils.data.DataLoader):
             self.template = get_template(
                 self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
-
-        dataset = self.template.default_dataset if dataset is None else dataset
+            dataset = self.template.default_dataset if dataset is None else dataset
 
         from ..calib_dataset import CALIB_DATASETS
         from .mllm_dataset import MLLM_DATASET
         if isinstance(dataset, str):
             if quant_nontext_module or \
-                    (dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer, device)):
+                    (dataset in CALIB_DATASETS.keys() and not \
+                     _only_text_test(model, tokenizer, device)):
                 if quant_nontext_module:
                     logger.warning(f"Text only dataset cannot be used for calibrating non-text modules,"
                                    "switching to liuhaotian/llava_conv_58k")
@@ -372,4 +373,20 @@ def calib(self, nsamples, bs):
                 m = m.to("meta")
         # torch.cuda.empty_cache()
 
 
+    def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
+        """Save the quantized model to the specified output directory in the specified format.
+
+        Args:
+            output_dir (str, optional): The directory to save the quantized model. Defaults to None.
+            format (str, optional): The format in which to save the model. Defaults to "auto_round".
+            inplace (bool, optional): Whether to modify the model in place. Defaults to True.
+            **kwargs: Additional keyword arguments specific to the export format.
+        Returns:
+            object: The compressed model object.
+        """
+        if self.processor is not None and not hasattr(self.processor, "chat_template"):
+            self.processor.chat_template = None
+        compressed_model = super().save_quantized(
+            output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs)
+        return compressed_model
diff --git a/auto_round/mllm/template.py b/auto_round/mllm/template.py
index 970fd1f8..08b4d9eb 100644
--- a/auto_round/mllm/template.py
+++ b/auto_round/mllm/template.py
@@ -118,24 +118,25 @@ def _register_template(
 
 def load_template(path: str):
     """Load template information from a json file."""
-    data = json.load(open(path, "r"))
-    if "model_type" not in data:
-        data["model_type"] = "user_define"
-    if "replace_tokens" in data and data["replace_tokens"] is not None:
-        assert len(data["replace_tokens"]) % 2 == 0, \
-            "the format of replace_tokens should be [old_tag1, replace_tag1, old_tag2, replace_tag2]"
-        temp = []
-        for i in range(0, len(data["replace_tokens"]), 2):
-            temp.append((data["replace_tokens"][i], data["replace_tokens"][i + 1]))
-        data["replace_tokens"] = temp
-    if "processor" in data:
-        assert data["processor"] in PROCESSORS.keys(), \
-            "{} is not supported, current support: {}".format(data["processor"], ",".join(PROCESSORS.keys()))
-        data["processor"] = PROCESSORS[data["processor"]]
-    template = _register_template(
-        **data
-    )
-    return template
+    with open(path, "r") as file:
+        data = json.load(file)
+        if "model_type" not in data:
+            data["model_type"] = "user_define"
+        if "replace_tokens" in data and data["replace_tokens"] is not None:
+            assert len(data["replace_tokens"]) % 2 == 0, \
+                "the format of replace_tokens should be [old_tag1, replace_tag1, old_tag2, replace_tag2]"
+            temp = []
+            for i in range(0, len(data["replace_tokens"]), 2):
+                temp.append((data["replace_tokens"][i], data["replace_tokens"][i + 1]))
+            data["replace_tokens"] = temp
+        if "processor" in data:
+            assert data["processor"] in PROCESSORS.keys(), \
+                "{} is not supported, current support: {}".format(data["processor"], ",".join(PROCESSORS.keys()))
+            data["processor"] = PROCESSORS[data["processor"]]
+        template = _register_template(
+            **data
+        )
+        return template
 
 
 def _load_preset_template():
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 3c88a5b4..89e5b23a 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -418,13 +418,11 @@ def tune(args):
     inplace = False if len(format_list) > 1 else True
     for format_ in format_list:
         eval_folder = f'{export_dir}-{format_}'
-        if processor is not None and not hasattr(processor, "chat_template"):
-            processor.chat_template = None
         safe_serialization = True
         if "phi3_v" in model_type:
             safe_serialization = False
         autoround.save_quantized(
-            eval_folder, format=format_, inplace=inplace, processor=processor, safe_serialization=safe_serialization)
+            eval_folder, format=format_, inplace=inplace, safe_serialization=safe_serialization)
 
 
 def eval(args):
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 8864d5ed..92f0c0a1 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -333,7 +333,7 @@ def extract_block_names_to_str(quant_block_list):
     prefixes = [get_common_prefix(blocks) for blocks in quant_block_list]
     # Join prefixes into a single string
     return ','.join(prefixes)
-    
+
 
 def find_matching_blocks(model, all_blocks, to_quant_block_names):
     """
@@ -966,20 +966,36 @@ def torch_version_at_least(version_string):
 TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
 TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
 
+# Note on HPU usage:
+# There are two modes available for enabling auto-round on HPU:
+# 1. Compile Mode
+#    1) Use PyTorch version ≥ 2.4 (Intel® Gaudi® v1.18 or later)
+#    2) Set `PT_HPU_LAZY_MODE=0` and `PT_ENABLE_INT64_SUPPORT=1`
+#    Compile mode can speed up the quantization process, but it is still experimental.
+# 2. Lazy Mode (default)
 
-def check_hpu_compile_mode():
+
+def _check_hpu_compile_mode():
     assert (
-        os.getenv("PT_HPU_LAZY_MODE") == "0"
+            os.getenv("PT_HPU_LAZY_MODE") == "0"
     ), "Please set `PT_HPU_LAZY_MODE=0` to use HPU compile mode"
     # Note: this is a temporary solution, will be removed in the future
     assert (
-        os.getenv("PT_ENABLE_INT64_SUPPORT") == "1"
+            os.getenv("PT_ENABLE_INT64_SUPPORT") == "1"
    ), "Please set `PT_ENABLE_INT64_SUPPORT=1` to use HPU compile mode"
 
 
+def is_hpu_lazy_mode():
+    return os.getenv("PT_HPU_LAZY_MODE") != "0"
+
+
+def _use_hpu_compile_mode():
+    return TORCH_VERSION_AT_LEAST_2_4 and not is_hpu_lazy_mode()
+
+
 def compile_func_on_hpu(func):
-    if TORCH_VERSION_AT_LEAST_2_4:
-        check_hpu_compile_mode()
+    if _use_hpu_compile_mode():
+        _check_hpu_compile_mode()
         return torch.compile(func, backend="hpu_backend")
     return func
 
@@ -1097,4 +1113,3 @@ def get_fp_layer_names(model, fp_layers):
             not_to_quantized_layers.append(name)
 
     return not_to_quantized_layers
-
diff --git a/test/_test_helpers.py b/test/_test_helpers.py
new file mode 100644
index 00000000..ac753a2e
--- /dev/null
+++ b/test/_test_helpers.py
@@ -0,0 +1,9 @@
+import pytest
+
+
+def is_pytest_mode_compile():
+    return pytest.mode == "compile"
+
+
+def is_pytest_mode_lazy():
+    return pytest.mode == "lazy"
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000..f4e9675b
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,34 @@
+import os
+from typing import Mapping
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--mode",
+        action="store",
+        default="lazy",
+        help="{compile|lazy}, default lazy. Choose the mode to run tests in",
+    )
+
+
+backup_env = pytest.StashKey[Mapping]()
+
+
+def pytest_configure(config):
+    pytest.mode = config.getoption("--mode")
+    assert pytest.mode.lower() in ["lazy", "compile"]
+
+    config.stash[backup_env] = os.environ.copy()  # snapshot so pytest_unconfigure can restore it
+
+    if pytest.mode == "lazy":
+        os.environ["PT_HPU_LAZY_MODE"] = "1"
+    elif pytest.mode == "compile":
+        os.environ["PT_HPU_LAZY_MODE"] = "0"
+        os.environ["PT_ENABLE_INT64_SUPPORT"] = "1"
+
+
+def pytest_unconfigure(config):
+    os.environ.clear()
+    os.environ.update(config.stash[backup_env])
diff --git a/test/test_auto_round_hpu_only.py b/test/test_auto_round_hpu_only.py
index 0e9e0680..8942e40b 100644
--- a/test/test_auto_round_hpu_only.py
+++ b/test/test_auto_round_hpu_only.py
@@ -1,3 +1,46 @@
+import pytest
+import torch
+from auto_round.utils import is_hpu_supported
+
+from _test_helpers import is_pytest_mode_compile, is_pytest_mode_lazy
+
+
+def run_opt_125m_on_hpu():
+    from auto_round import AutoRound
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    model_name = "facebook/opt-125m"
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    bits, group_size, sym = 4, 128, False
+    autoround = AutoRound(
+        model,
+        tokenizer,
+        bits=bits,
+        group_size=group_size,
+        sym=sym,
+        iters=2,
+        seqlen=2,
+    )
+    q_model, qconfig = autoround.quantize()
+    assert q_model is not None, "Expected q_model to not be None"
+
+
+@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported")
+@pytest.mark.skipif(not is_pytest_mode_lazy(), reason="Only for lazy mode")
+def test_opt_125m_lazy_mode():
+    run_opt_125m_on_hpu()
+
+
+@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported")
+@pytest.mark.skipif(not is_pytest_mode_compile(), reason="Only for compile mode")
+def test_opt_125m_compile_mode():
+    torch._dynamo.reset()
+    run_opt_125m_on_hpu()
+
+
 def test_import():
     from auto_round import AutoRound
-    from auto_round.export.export_to_itrex.export import save_quantized_as_itrex, WeightOnlyLinear
\ No newline at end of file
+    from auto_round.export.export_to_itrex.export import (
+        WeightOnlyLinear, save_quantized_as_itrex)
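
For reference, a minimal sketch of how the two HPU test modes wired up above can be invoked locally. This assumes a Gaudi machine with the HPU PyTorch bridge installed and that the commands are run from the test/ directory; it simply mirrors the `--mode` option added in test/conftest.py and the pytest invocation used by .azure-pipelines/scripts/ut/run_ut_hpu.sh:

    # Lazy mode (the default): conftest.py sets PT_HPU_LAZY_MODE=1 for the session
    python -m pytest -vs --disable-warnings test_auto_round_hpu_only.py

    # Compile mode (experimental): conftest.py sets PT_HPU_LAZY_MODE=0 and PT_ENABLE_INT64_SUPPORT=1
    python -m pytest --mode compile -vs --disable-warnings test_auto_round_hpu_only.py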