diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index 4ade34f75..45fafcb34 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -18,6 +18,15 @@ "batch_size": 1, "algorithm": "RTN" }, + "llama-2-7b-rtn-with-past-qdq": { + "model_name": "meta-llama/Llama-2-7b-hf", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21", + "main_script": "main.py", + "batch_size": 1, + "algorithm": "RTN" + }, "llama-2-7b-awq": { "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", @@ -36,6 +45,15 @@ "batch_size": 1, "algorithm": "AWQ" }, + "llama-2-7b-awq-with-past-qdq": { + "model_name": "meta-llama/Llama-2-7b-hf", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21", + "main_script": "main.py", + "batch_size": 1, + "algorithm": "AWQ" + }, "llama-2-7b-gptq": { "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", @@ -54,6 +72,15 @@ "batch_size": 1, "algorithm": "GPTQ" }, + "llama-2-7b-gptq-with-past-qdq": { + "model_name": "meta-llama/Llama-2-7b-hf", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21", + "main_script": "main.py", + "batch_size": 1, + "algorithm": "GPTQ" + }, "llama-2-7b-woq_tune": { "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md index 79b17c73f..6bbd8234f 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md @@ -41,13 +41,21 @@ python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \ Set `algorithm=WOQ_TUNE` to tune weight-only quantization algorithm or specify algorithm to `RTN` or `GPTQ` or `AWQ`. +`quant_format=QDQ` works only when: +- onnxruntime >= 1.19.0 +- opset version of the model >= 21 +- quantized bits is in [4, 8] + +otherwise it will execute QOperator automatically. + ```bash bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model --output_model=/path/to/model_tune \ # folder path to save onnx model --batch_size=batch_size # optional \ --dataset=NeelNanda/pile-10k \ --tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer - --algorithm=WOQ_TUNE # support WOQ_TUNE, RTN, AWQ, GPTQ + --algorithm=WOQ_TUNE # support WOQ_TUNE, RTN, AWQ, GPTQ \ + --quant_format=QDQ # support QOperator and QDQ ``` ## 2. 
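For reference, the three QDQ prerequisites listed in the README hunk above can be expressed as a small check. This helper is purely illustrative and not part of the patch; it only restates the documented conditions (onnxruntime >= 1.19.0, model opset >= 21, 4- or 8-bit weights):

```python
# Illustrative only: restates the README's quant_format=QDQ prerequisites.
# The helper name and structure are assumptions, not code from this patch.
import onnx
import onnxruntime as ort
from packaging import version


def qdq_supported(model_path: str, weight_bits: int) -> bool:
    """Return True when the QDQ weight-only format can be used; otherwise QOperator is used."""
    model = onnx.load(model_path, load_external_data=False)
    opset = max(
        (op.version for op in model.opset_import if op.domain in ("", "ai.onnx")),
        default=0,
    )
    return (
        version.Version(ort.__version__) >= version.Version("1.19.0")
        and opset >= 21
        and weight_bits in (4, 8)
    )
```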
Benchmark diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index e327aa827..196d7d4d1 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -34,7 +34,7 @@ from torch.utils import data from onnx_neural_compressor import data_reader -from onnx_neural_compressor.quantization import config, matmul_nbits_quantizer, tuning +from onnx_neural_compressor.quantization import QuantFormat, config, matmul_nbits_quantizer, tuning logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.WARN @@ -74,7 +74,8 @@ parser.add_argument( "--tasks", nargs="+", - default=[ + default=["lambada_openai"], + choices=[ "winogrande", "copa", "piqa", @@ -105,6 +106,7 @@ default=[], help="nodes that will not be quantized. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'", ) +parser.add_argument("--quant_format", type=str, default="QDQ", choices=["QOperator", "QDQ"]) args = parser.parse_args() if args.tune and not os.path.exists(args.output_model): @@ -347,8 +349,11 @@ def rewind(self): nodes_to_exclude = ["/lm_head/MatMul"] if not args.quantize_lm_head else [] nodes_to_exclude = list(set(args.nodes_to_exclude + nodes_to_exclude)) + quant_format = QuantFormat.QOperator if args.quant_format == "QOperator" else QuantFormat.QDQ if args.algorithm.upper() == "RTN": - algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=args.layer_wise) + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig( + layer_wise_quant=args.layer_wise, quant_format=quant_format + ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, n_bits=4, @@ -363,7 +368,9 @@ def rewind(self): elif args.algorithm.upper() == "AWQ": calibration_data_reader = AWQDataloader(model_path, pad_max=args.pad_max, batch_size=1) algo_config = matmul_nbits_quantizer.AWQWeightOnlyQuantConfig( - calibration_data_reader=calibration_data_reader, enable_mse_search=False + calibration_data_reader=calibration_data_reader, + enable_mse_search=False, + quant_format=quant_format, ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, @@ -379,7 +386,9 @@ def rewind(self): elif args.algorithm.upper() == "GPTQ": calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1) algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig( - calibration_data_reader=calibration_data_reader, layer_wise_quant=args.layer_wise + calibration_data_reader=calibration_data_reader, + layer_wise_quant=args.layer_wise, + quant_format=quant_format, ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, @@ -395,7 +404,9 @@ def rewind(self): elif args.algorithm.upper() == "WOQ_TUNE": calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1) # set tolerable_loss to 0.5% for test, default is 1% - custom_tune_config = tuning.TuningConfig(config_set=config.get_woq_tuning_config(), tolerable_loss=0.005) + custom_tune_config = tuning.TuningConfig( + config_set=config.get_woq_tuning_config(quant_format=quant_format), tolerable_loss=0.005 + ) best_model = tuning.autotune( model_input=model_path, tune_config=custom_tune_config, diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh 
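The main.py hunks above map the new `--quant_format` flag onto the `QuantFormat` enum and thread it through each algorithm config. Below is a condensed sketch of the RTN path; the model path is hypothetical, and the `process()` call is an assumption about the `matmul_nbits_quantizer` API since the call site is outside the hunks shown:

```python
# Condensed sketch of the RTN path with the new quant_format plumbing.
from onnx_neural_compressor.quantization import QuantFormat, matmul_nbits_quantizer

args_quant_format = "QDQ"  # value of the new --quant_format CLI flag
quant_format = (
    QuantFormat.QOperator if args_quant_format == "QOperator" else QuantFormat.QDQ
)

algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(
    layer_wise_quant=False,
    quant_format=quant_format,
)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
    "Llama-2-7b-hf-onnx/model.onnx",  # hypothetical model path
    n_bits=4,
    algo_config=algo_config,
)
quant.process()  # assumed entry point; result handling is not shown in these hunks
```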
b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh index 72348427c..fc8e60c87 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh @@ -14,19 +14,19 @@ function init_params { do case $var in --input_model=*) - input_model=$(echo $var |cut -f2 -d=) + input_model=$(echo "$var" |cut -f2 -d=) ;; --batch_size=*) - batch_size=$(echo $var |cut -f2 -d=) + batch_size=$(echo "$var" |cut -f2 -d=) ;; --tokenizer=*) - tokenizer=$(echo $var |cut -f2 -d=) + tokenizer=$(echo "$var" |cut -f2 -d=) ;; --mode=*) - mode=$(echo $var |cut -f2 -d=) + mode=$(echo "$var" |cut -f2 -d=) ;; --intra_op_num_threads=*) - intra_op_num_threads=$(echo $var |cut -f2 -d=) + intra_op_num_threads=$(echo "$var" |cut -f2 -d=) ;; esac done @@ -42,19 +42,27 @@ function run_benchmark { input_model=$(dirname "$input_model") fi + extra_cmd="" + if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then - extra_cmd="--trust_remote_code True" + extra_cmd=$extra_cmd"--trust_remote_code True " + fi + + if [ "${batch_size}" ]; then + extra_cmd=$extra_cmd"--batch_size ${batch_size} " + fi + if [ "${tokenizer}" ]; then + extra_cmd=$extra_cmd"--tokenizer ${tokenizer} " + fi + if [ "${tasks}" ]; then + extra_cmd=$extra_cmd"--tasks ${tasks} " + fi + if [ "${intra_op_num_threads}" ]; then + extra_cmd=$extra_cmd"--intra_op_num_threads ${intra_op_num_threads} " fi - eval "python main.py \ - --model_path ${input_model} \ - --batch_size=${batch_size-1} \ - --tokenizer=${tokenizer-meta-llama/Llama-2-7b-hf} \ - --tasks=${tasks-lambada_openai} \ - --mode=${mode} \ - --intra_op_num_threads=${intra_op_num_threads-24} \ - --benchmark \ - ${extra_cmd}" + extra_cmd=$extra_cmd"--benchmark" + eval "python main.py --model_path ${input_model} --mode ${mode} ${extra_cmd}" } diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh index 4198da9a8..1c8a84681 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh @@ -12,22 +12,25 @@ function init_params { do case $var in --input_model=*) - input_model=$(echo $var |cut -f2 -d=) + input_model=$(echo "$var" |cut -f2 -d=) ;; --output_model=*) - output_model=$(echo $var |cut -f2 -d=) + output_model=$(echo "$var" |cut -f2 -d=) ;; --batch_size=*) - batch_size=$(echo $var |cut -f2 -d=) + batch_size=$(echo "$var" |cut -f2 -d=) ;; --dataset=*) - dataset=$(echo $var |cut -f2 -d=) + dataset=$(echo "$var" |cut -f2 -d=) ;; --tokenizer=*) - tokenizer=$(echo $var |cut -f2 -d=) + tokenizer=$(echo "$var" |cut -f2 -d=) ;; --algorithm=*) - algorithm=$(echo $var |cut -f2 -d=) + algorithm=$(echo "$var" |cut -f2 -d=) + ;; + --quant_format=*) + quant_format=$(echo "$var" |cut -f2 -d=) ;; esac done @@ -56,30 +59,42 @@ function run_tuning { echo "Created directory $output_model" fi + extra_cmd="" + if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then nodes_to_exclude="/model/layers.*/self_attn/qkv_proj/MatMul /model/layers.*/mlp/down_proj/MatMul" - extra_cmd="--nodes_to_exclude ${nodes_to_exclude} --trust_remote_code True" + extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} --trust_remote_code True " fi if [[ "${tokenizer}" =~ "Llama-3-8B" ]]; then nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul" - 
extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" + extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} " fi if [[ "${tokenizer}" =~ "Qwen2-7B" ]]; then nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul /model/layers.*/mlp/up_proj/MatMul" - extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" + extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} " + fi + + if [ "${tokenizer}" ]; then + extra_cmd=$extra_cmd"--tokenizer ${tokenizer} " + fi + if [ "${batch_size}" ]; then + extra_cmd=$extra_cmd"--batch_size ${batch_size} " + fi + if [ "${dataset}" ]; then + extra_cmd=$extra_cmd"--dataset ${dataset} " + fi + if [ "${algorithm}" ]; then + extra_cmd=$extra_cmd"--algorithm ${algorithm} " + fi + if [ "${tasks}" ]; then + extra_cmd=$extra_cmd"--tasks ${tasks} " + fi + if [ "${quant_format}" ]; then + extra_cmd=$extra_cmd"--quant_format ${quant_format} " fi - eval "python main.py \ - --model_path ${input_model} \ - --tokenizer ${tokenizer-meta-llama/Llama-2-7b-hf} \ - --output_model ${output_model} \ - --batch_size ${batch_size-1} \ - --dataset ${dataset-NeelNanda/pile-10k} \ - --algorithm ${algorithm-WOQ_TUNE} \ - --tasks ${tasks-lambada_openai} \ - --layer_wise \ - --tune \ - ${extra_cmd}" + extra_cmd=$extra_cmd"--layer_wise --tune" + eval "python main.py --model_path ${input_model} --output_model ${output_model} ${extra_cmd}" } main "$@" diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 80077a9be..b2a7211b9 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -18,6 +18,7 @@ import copy import os import pathlib +import tempfile import onnx import onnxruntime as ort @@ -60,7 +61,7 @@ def layer_wise_quant( model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False) origin_model = copy.deepcopy(model) - + tmp_file = tempfile.TemporaryDirectory() providers = kwargs.get("providers", ["CPUExecutionProvider"]) # get and check split nodes @@ -97,7 +98,7 @@ def layer_wise_quant( # split model with given split node split_model_part_1, split_model_part_2 = split_model.split_model_with_node( - split_node.name, model.model_path, save_both_split_models + split_node.name, model.model_path, save_both_split_models, save_path=tmp_file.name ) if not save_both_split_models: @@ -201,6 +202,8 @@ def layer_wise_quant( onnx.external_data_helper.load_external_data_for_model( quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path) ) + + tmp_file.cleanup() return quantized_model_merged diff --git a/onnx_neural_compressor/algorithms/utility.py b/onnx_neural_compressor/algorithms/utility.py index 45c99c207..7d080ad8f 100644 --- a/onnx_neural_compressor/algorithms/utility.py +++ b/onnx_neural_compressor/algorithms/utility.py @@ -95,6 +95,17 @@ def attribute_to_kwarg(attribute): } +ONNX_TENSOR_TYPE = { + "bfloat16": getattr(onnx.TensorProto, "BFLOAT16", 16), + "float32": getattr(onnx.TensorProto, "FLOAT", 1), + "float16": getattr(onnx.TensorProto, "FLOAT16", 10), + "int4": getattr(onnx.TensorProto, "INT4", 22), + "uint4": getattr(onnx.TensorProto, "UNT4", 21), + "int8": getattr(onnx.TensorProto, "INT8", 3), + "uint8": getattr(onnx.TensorProto, "UINT8", 2), +} + + def _qType_to_np_type(qType): if isinstance(qType, int): return onnx.helper.tensor_dtype_to_np_dtype(qType) @@ -215,7 +226,7 @@ def calculate_scale_zp(rmin, rmax, qType, sym, reduce_range=False): rmin = -max_range rmax = max_range scale = (rmax - rmin) / 
(qmax - qmin) - scale[scale < np.finfo(rmax.dtype).tiny] = 1 + scale[abs(scale) < np.finfo(rmax.dtype).tiny] = 1 zero_point = ( np.multiply(np.ones(rmax.shape), np.round((qmax + qmin) / 2.0)).astype(dtype) if sym @@ -254,8 +265,8 @@ def quantize_data(data, qType, sym, reduce_range=False, ratio=1.0, axis=None): axis (int, optional): process data along a specific axis. Default is None (process the whole data) """ quantize_range = get_qmin_qmax_for_qType(qType, reduce_range, sym) - rmin = np.min(np.min(data), 0) if axis is None else np.min(data, axis=1, keepdims=True) - rmax = np.max(np.max(data), 0) if axis is None else np.max(data, axis=1, keepdims=True) + rmin = np.min(np.min(data), 0) if axis is None else np.min(data, axis=axis, keepdims=True) + rmax = np.max(np.max(data), 0) if axis is None else np.max(data, axis=axis, keepdims=True) rmin *= ratio rmax *= ratio @@ -298,6 +309,95 @@ def _get_blob_size(group_size, has_zp): # pragma: no cover return blob_size +def make_weight_only_dequant_node( + node: onnx.NodeProto, + weight_shape: tuple, + block_size: int, + num_bits: int, + dtype: str, + q_weight: np.array, + scale: np.array, + zero_point: np.array, + axis: int = 1, +): + """Build DequantizeLinear node. + Args: + node: original matmul node + weight_shape (tuple): original weight shape + block_size (int): how many elements share one scale/zp + num_bits (int): num_bits + dtype (str): use uint or int + q_weight (array): quantized weight + scale (array): scale + zero_point (array): zero point + axis (int): the axis of the dequantizing dimension of the input tensor + + Returns: + weight_only_dequant_node: DequantizeLinear node for weight dequantization + new_inits: initializers of the new node + """ + new_inits = [] + input_names = [] + kwargs = {"block_size": block_size, "axis": axis} + + q_weight = q_weight.reshape((weight_shape[-1], -1)).T + if num_bits == 4: + q_weight = ((q_weight[:, ::2] & 0xF | q_weight[:, 1::2] << 4) & 0xFF).astype("uint8") + + qtype = ONNX_TENSOR_TYPE.get(dtype + str(num_bits), None) + + if qtype is None: + raise ValueError( + "Unsupported qtype {}, only support {}".format(dtype + str(num_bits), list(ONNX_TENSOR_TYPE.keys())) + ) + + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(block_size)), + data_type=qtype, + dims=weight_shape, + vals=q_weight.flatten().tobytes(), + raw=True, + ) + new_inits.append(q_weight_tensor) + input_names.append(q_weight_tensor.name) + + scale = scale.reshape((weight_shape[-1], -1)).T + scale_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_scale", + data_type=onnx.helper.np_dtype_to_tensor_dtype(scale.dtype), + dims=scale.shape, + vals=scale.tobytes(), + raw=True, + ) + input_names.append(scale_tensor.name) + new_inits.append(scale_tensor) + + # build zero_point tensor + zero_point = zero_point.reshape((weight_shape[-1], -1)).T + if num_bits == 4: + zero_point = ((zero_point[:, ::2] & 0xF | zero_point[:, 1::2] << 4) & 0xFF).astype("uint8") + + zp_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_zp", + data_type=qtype, + dims=scale.shape, + vals=zero_point.flatten().tobytes(), + raw=True, + ) + input_names.append(zp_tensor.name) + new_inits.append(zp_tensor) + + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + inputs=input_names, + outputs=[q_weight_tensor.name + "_dequant"], + name=node.name + "_woq_dequant", + **kwargs, + ) + node.input[1] = dequant_node.output[0] + return dequant_node, new_inits + + def make_matmul_weight_only_node( node: 
onnx.NodeProto, weight_shape: tuple, @@ -436,6 +536,92 @@ def make_matmul_weight_only_node( return matmul_weight_only_node, new_inits +def quant_matmul_weight_only( + node, + weight, + dtype, + num_bits, + sym, + group_size, + ratio=1, + quant_format=None, + accuracy_level=0, +): + new_nodes = [] + new_inits = [] + remove_nodes = [] + + org_w_shape = weight.shape # ic, oc + group_size = group_size if group_size != -1 else org_w_shape[0] + k_blocks = (org_w_shape[0] - 1) // group_size + 1 + weight = pad_tensor(weight, group_size, k_blocks) + + if quant_format == 1: + _, _, zp, scale, q_weight = quantize_data( + weight.T.reshape((-1, group_size)), + dtype + str(num_bits), + sym, + ratio=ratio, + axis=1, + ) + dequant_node, inits = make_weight_only_dequant_node( + node=node, + weight_shape=org_w_shape, + num_bits=num_bits, + dtype=dtype, + q_weight=q_weight, + scale=scale.astype(weight.dtype), + axis=0, + block_size=group_size, + zero_point=zp, + ) + new_nodes.append(dequant_node) + new_inits.extend(inits) + elif quant_format == 0: + _, _, zp, scale, q_weight = quantize_data( + weight.T.reshape((-1, group_size)), + dtype + str(num_bits), + sym, + ratio=ratio, + axis=1, + ) + q_matmul_node, inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_w_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight, + scale=scale.astype(weight.dtype), + zero_point=zp if not sym else None, + accuracy_level=accuracy_level, + ) + new_nodes.append(q_matmul_node) + new_inits.extend(inits) + remove_nodes.append(node) + else: + q_weight = qdq_data( + weight.T.reshape((-1, group_size)), + dtype + str(num_bits), + sym, + ratio=ratio, + axis=1, + ) + q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) + q_weight = np.transpose(q_weight) + q_weight = q_weight[: org_w_shape[0], :].astype(weight.dtype) + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), + data_type=onnx.helper.np_dtype_to_tensor_dtype(q_weight.dtype), + dims=weight.shape, + vals=q_weight.tobytes(), + raw=True, + ) + node.input[1] = q_weight_tensor.name + new_inits.append(q_weight_tensor) + return new_nodes, new_inits, remove_nodes + + def prepare_inputs(model, data_reader, providers): """Prepare inputs for weight only quantization. 
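`make_weight_only_dequant_node` above packs two 4-bit values into each uint8 byte before building the INT4/UINT4 initializer consumed by `DequantizeLinear`. A standalone numpy sketch of that layout (toy unsigned values, pack plus a round-trip check only; at runtime ONNX Runtime interprets the packed tensor itself):

```python
# Sketch of the nibble packing used in make_weight_only_dequant_node:
# two 4-bit values share one uint8 byte, low nibble first.
import numpy as np

q = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)  # toy 4-bit values

packed = ((q[:, ::2] & 0xF) | (q[:, 1::2] << 4)).astype("uint8")

# unpack only to verify the layout
low = packed & 0xF
high = packed >> 4
unpacked = np.empty_like(q)
unpacked[:, ::2] = low
unpacked[:, 1::2] = high
assert np.array_equal(unpacked, q)
```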
@@ -494,33 +680,35 @@ def pad_tensor(weight, group_size, k_blocks): return weight -def dump_woq_stats(model, quantize_config): +def dump_woq_stats(model, quantize_config, white_list=["MatMul"]): res = {} dtype_set = set() for node in model.graph.node: - if node.name.split("_Q")[0] not in quantize_config: - continue if node.op_type in ["MatMulFpQ4", "MatMulNBits"]: optype = "MatMul" else: optype = node.op_type + if optype not in white_list and optype != "DequantizeLinear": + continue + if optype not in res: res[optype] = {} - if re.fullmatch("^.*_Q\d*G\d*", node.input[1]): - search_out = re.search("_Q\d*", node.input[1]) - dtype = "A32W{}G{}".format( - node.input[1][search_out.start() + 2 : search_out.end()], node.input[1][search_out.end() + 1 :] - ) - else: - dtype = "FP32" - dtype_set.add(dtype) - if dtype in res[optype]: - res[optype][dtype] += 1 - else: - res[optype][dtype] = 1 + dtype = "FP32" + for inp in node.input: + if re.match("^.*_Q\d*G\d*", inp): + Q_position = re.search("_Q\d*", inp) + full_position = re.search("_Q\d*G\d*", inp) + dtype = "A32W{}G{}".format( + inp[Q_position.start() + 2 : Q_position.end()], + inp[Q_position.end() + 1 : full_position.end()], + ) + dtype_set.add(dtype) + break + + res[optype][dtype] = res[optype].get(dtype, 0) + 1 dtype_list = list(dtype_set) for dtype in dtype_list: diff --git a/onnx_neural_compressor/algorithms/weight_only/awq.py b/onnx_neural_compressor/algorithms/weight_only/awq.py index 81d896288..55a1b2801 100644 --- a/onnx_neural_compressor/algorithms/weight_only/awq.py +++ b/onnx_neural_compressor/algorithms/weight_only/awq.py @@ -63,6 +63,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts): weight = [] org_out = [] + weight_dtype = weight_config[nodes[0].name].get("weight_dtype", "int") num_bits = weight_config[nodes[0].name].get("weight_bits", 4) group_size = weight_config[nodes[0].name].get("weight_group_size", 32) sym = weight_config[nodes[0].name].get("weight_sym", True) @@ -70,6 +71,7 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts): # use same params for all children of one parent for node in nodes: + weight_config.setdefault(node.name, {}).update({"weight_dtype": weight_dtype}) weight_config.setdefault(node.name, {}).update({"weight_bits": num_bits}) weight_config.setdefault(node.name, {}).update({"weight_group_size": group_size}) weight_config.setdefault(node.name, {}).update({"weight_sym": sym}) @@ -98,24 +100,11 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts): weight = weight.T * scales weight = quant_utils.pad_tensor(weight.T, group_size, (org_w_shape[0] + group_size - 1) // group_size) - if (version.Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4) or ( - version.Version(ort.__version__) >= constants.ONNXRT116_VERSION - and num_bits == 4 - and group_size == 32 - ): - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1 - q_weight = quant_utils.qdq_data( - weight.reshape((-1, group_size)), - "uint" + str(num_bits), - sym, - ).reshape(weight.shape) - else: - q_weight = quant_utils.qdq_data( - weight.reshape((-1, group_size)), - "int" + str(num_bits), - sym, - ).reshape(weight.shape) + q_weight = quant_utils.qdq_data( + weight.reshape((-1, group_size)), + weight_dtype + str(num_bits), + sym, + ).reshape(weight.shape) q_weight = q_weight[: org_w_shape[0], :] / np.expand_dims(scales, axis=-1) out = np.matmul(inp, q_weight) @@ -237,6 
+226,7 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts): inp = np.concatenate(output_dicts[nodes[0].input[0]], axis=0) for node in nodes: + weight_dtype = weight_config[node.name].get("weight_dtype", "int") num_bits = weight_config[node.name].get("weight_bits", 4) group_size = weight_config[node.name].get("weight_group_size", 32) sym = weight_config[node.name].get("weight_sym", True) @@ -256,26 +246,12 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts): for i_s in range(10): ratio = 1 - i_s / 100 weight = copy.deepcopy(org_weight) - if (version.Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4) or ( - version.Version(ort.__version__) >= constants.ONNXRT116_VERSION - and num_bits == 4 - and group_size == 32 - ): - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1 - weight = quant_utils.qdq_data( - weight.reshape((-1, group_size)), - "uint" + str(num_bits), - sym, - ratio=ratio, - ).reshape(org_weight.shape) - else: - weight = quant_utils.qdq_data( - weight.reshape((-1, group_size)), - "int" + str(num_bits), - sym, - ratio=ratio, - ).reshape(org_weight.shape) + weight = quant_utils.qdq_data( + weight.reshape((-1, group_size)), + weight_dtype + str(num_bits), + sym, + ratio=ratio, + ).reshape(org_weight.shape) cur_out = np.matmul(inp, weight[:, : org_w_shape[0]].T) loss = np.mean(np.power((org_out - cur_out), 2)) @@ -336,12 +312,12 @@ def awq_quantize( output_names = [] for node in model.nodes(): # check op_type of node is MatMul + # check op_name in quantization config # check dim 1 of input is weight tensor - # check weight_type is not "fp32" if ( node.op_type in ["MatMul"] + and node.name in weight_config and model.get_initializer(node.input[1]) is not None - and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): output_names.append(node.input[0]) output_names = list(set(output_names)) @@ -371,12 +347,12 @@ def awq_quantize( for node in input_name_to_nodes[input_name]: # check op_type of node is MatMul + # check op_name in quantization config # check dim 1 of input is weight tensor - # check weight_type is not "fp32" if ( node.op_type in ["MatMul"] + and node.name in weight_config and model.get_initializer(node.input[1]) is not None - and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): dump_pairs[parent].append(model.get_node(node.name)) @@ -408,7 +384,12 @@ def awq_quantize( model.remove_tensors_from_outputs(output_names) model.model.graph.output.MergeFrom(org_output) - model = rtn.rtn_quantize(model, weight_config, full_ratio, providers) + model = rtn.rtn_quantize( + model=model, + weight_config=weight_config, + ratios=full_ratio, + providers=providers, + ) return model diff --git a/onnx_neural_compressor/algorithms/weight_only/gptq.py b/onnx_neural_compressor/algorithms/weight_only/gptq.py index 4a7b35b31..06d184b19 100644 --- a/onnx_neural_compressor/algorithms/weight_only/gptq.py +++ b/onnx_neural_compressor/algorithms/weight_only/gptq.py @@ -22,11 +22,11 @@ import numpy as np import onnx import onnxruntime as ort -from packaging.version import Version from onnx_neural_compressor import constants, data_reader, onnx_model, utility from onnx_neural_compressor.algorithms import utility as quant_utils from onnx_neural_compressor.algorithms.layer_wise import core +from onnx_neural_compressor.algorithms.weight_only import rtn from onnx_neural_compressor.quantization import 
config from typing import List, Union # isort: skip @@ -228,12 +228,12 @@ def gptq_quantize( output_names = [] for node in model.nodes(): # check op_type of node is MatMul + # check op_name in quantization config # check dim 1 of input is weight tensor - # check weight_type is not "fp32" if ( node.op_type in ["MatMul"] + and node.name in weight_config and model.get_initializer(node.input[1]) is not None - and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): output_names.append(node.input[0]) output_names = list(set(output_names)) @@ -262,12 +262,12 @@ def gptq_quantize( for node in input_name_to_nodes[input_name]: # check op_type of node is MatMul + # check op_name in quantization config # check dim 1 of input is weight tensor - # check weight_type is not "fp32" if ( node.op_type in ["MatMul"] + and node.name in weight_config and model.get_initializer(node.input[1]) is not None - and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): weight = onnx.numpy_helper.to_array( model.get_initializer(model.get_node(node.name).input[1]), base_dir @@ -300,10 +300,14 @@ def gptq_quantize( num_bits = weight_config[node.name].get("weight_bits", 4) group_size = weight_config[node.name].get("weight_group_size", 32) sym = weight_config[node.name].get("weight_sym", True) + dtype = weight_config[node.name].get("weight_dtype", "int") accuracy_level = weight_config[node.name].get("accuracy_level", 0) - group_size = group_size if group_size != -1 else weight.shape[0] - dtype = weight.dtype + quant_format = getattr(weight_config[node.name].get("quant_format", None), "value", None) + weight_tensor = model.get_initializer(node.input[1]) + init_share_num = model.get_initializer_share_num(node.input[1]) + + # weight -> quant -> dequant -> q_weight q_weight = _gptq( weight, H, @@ -316,60 +320,25 @@ def gptq_quantize( mse=mse, perchannel=perchannel, ) - - weight_tensor = model.get_initializer(node.input[1]) - init_share_num = model.get_initializer_share_num(node.input[1]) - - satisfy_MatMulNBits_condition = Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4 - satisfy_MatMulFpQ4_condition = ( - Version(ort.__version__) >= constants.ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + new_nodes, new_inits, remove_nodes = quant_utils.quant_matmul_weight_only( + node=node, + weight=weight, + dtype=dtype, + num_bits=num_bits, + sym=sym, + group_size=group_size, + quant_format=quant_format, + accuracy_level=accuracy_level, ) - if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( - "CUDAExecutionProvider" not in providers - and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) - ): - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP - org_shape = weight.shape - k_blocks = (org_shape[0] + group_size - 1) // group_size - q_weight = quant_utils.pad_tensor(q_weight, group_size, k_blocks) - _, _, zp, scale, q_weight = quant_utils.quantize_data( - q_weight.T.reshape((-1, group_size)), - "uint" + str(num_bits), - sym, - axis=1, - ) - q_matmul_node, new_inits = quant_utils.make_matmul_weight_only_node( - node=node, - weight_shape=org_shape, - num_bits=num_bits, - group_size=group_size, - k_blocks=k_blocks, - q_weight=q_weight, - scale=scale.astype(dtype), - zero_point=zp if not sym else None, - accuracy_level=accuracy_level, - ) - - 
model.add_initializers(new_inits) - model.remove_node(node) - model.add_node(q_matmul_node) - else: - q_weight_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), - data_type=onnx.helper.np_dtype_to_tensor_dtype(dtype), - dims=q_weight.shape, - vals=q_weight.astype(dtype).tobytes(), - raw=True, - ) - model.add_initializer(q_weight_tensor) - node.input[1] = q_weight_tensor.name + model.add_initializers(new_inits) + model.add_nodes(new_nodes) + model.remove_nodes(remove_nodes) + if init_share_num == 1: model.remove_initializer(weight_tensor) model.remove_tensors_from_outputs(output_names) model.model.graph.output.MergeFrom(org_output) - model.topological_sort() # reload external data to prevent external data file path errors diff --git a/onnx_neural_compressor/algorithms/weight_only/rtn.py b/onnx_neural_compressor/algorithms/weight_only/rtn.py index d4ca7e55e..72f061554 100644 --- a/onnx_neural_compressor/algorithms/weight_only/rtn.py +++ b/onnx_neural_compressor/algorithms/weight_only/rtn.py @@ -21,7 +21,6 @@ import numpy as np import onnx import onnxruntime as ort -from packaging import version from onnx_neural_compressor import constants, onnx_model, utility from onnx_neural_compressor.algorithms import utility as quant_utils @@ -63,8 +62,8 @@ def rtn_quantize( if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model) base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" - new_nodes = [] - remove_nodes = [] + new_nodes_all = [] + remove_nodes_all = [] total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]]) curr_id = 0 for node in model.nodes(): @@ -73,91 +72,46 @@ def rtn_quantize( utility.simple_progress_bar(total_num, curr_id) # check op_type of node is MatMul + # check op_name in quantization config # check dim 1 of input is weight tensor - # check weight_type is not "fp32" if ( - node.op_type in ["MatMul"] # check op_type of node is MatMul + node.op_type in ["MatMul"] + and node.name in weight_config and model.get_initializer(node.input[1]) is not None - and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): weight_tensor = model.get_initializer(node.input[1]) weight = onnx.numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy() if len(weight.shape) != 2: continue - dtype = weight.dtype + dtype = weight_config[node.name].get("weight_dtype", "int") num_bits = weight_config[node.name].get("weight_bits", 4) group_size = weight_config[node.name].get("weight_group_size", 32) sym = weight_config[node.name].get("weight_sym", True) accuracy_level = weight_config[node.name].get("accuracy_level", 0) + quant_format = getattr(weight_config[node.name].get("quant_format", None), "value", None) - org_w_shape = weight.shape # ic, oc - group_size = group_size if group_size != -1 else org_w_shape[0] - - k_blocks = (org_w_shape[0] - 1) // group_size + 1 init_share_num = model.get_initializer_share_num(node.input[1]) - weight = quant_utils.pad_tensor(weight, group_size, k_blocks) - - satisfy_MatMulNBits_condition = ( - version.Version(ort.__version__) > constants.ONNXRT1161_VERSION and num_bits == 4 - ) - satisfy_MatMulFpQ4_condition = ( - version.Version(ort.__version__) >= constants.ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + new_nodes, new_inits, remove_nodes = quant_utils.quant_matmul_weight_only( + node=node, + weight=weight, + dtype=dtype, + num_bits=num_bits, + sym=sym, + group_size=group_size, + ratio=ratios.get(node.input[1], 
1), + quant_format=quant_format, + accuracy_level=accuracy_level, ) - if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( - "CUDAExecutionProvider" not in providers - and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) - ): # pragma: no cover - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP - _, _, zp, scale, q_weight = quant_utils.quantize_data( - weight.T.reshape((-1, group_size)), - "uint" + str(num_bits), - sym, - ratio=ratios.get(node.input[1], 1), - axis=1, - ) - q_matmul_node, new_inits = quant_utils.make_matmul_weight_only_node( - node=node, - weight_shape=org_w_shape, - num_bits=num_bits, - group_size=group_size, - k_blocks=k_blocks, - q_weight=q_weight, - scale=scale.astype(dtype), - zero_point=zp if not sym else None, - accuracy_level=accuracy_level, - ) - - model.add_initializers(new_inits) - remove_nodes.append(node) - new_nodes.append(q_matmul_node) - else: - q_weight = quant_utils.qdq_data( - weight.T.reshape((-1, group_size)), - "int" + str(num_bits), - sym, - ratio=ratios.get(node.input[1], 1), - axis=1, - ) - q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) - q_weight = np.transpose(q_weight) - q_weight = q_weight[: org_w_shape[0], :].astype(dtype) - q_weight_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), - data_type=onnx.helper.np_dtype_to_tensor_dtype(dtype), - dims=weight.shape, - vals=q_weight.tobytes(), - raw=True, - ) - model.add_initializer(q_weight_tensor) - node.input[1] = q_weight_tensor.name + model.add_initializers(new_inits) + new_nodes_all.extend(new_nodes) + remove_nodes_all.extend(remove_nodes) if init_share_num == 1: model.remove_initializer(weight_tensor) - model.add_nodes(new_nodes) - model.remove_nodes(remove_nodes) + model.add_nodes(new_nodes_all) + model.remove_nodes(remove_nodes_all) model.topological_sort() # reload external data to prevent external data file path errors diff --git a/onnx_neural_compressor/constants.py b/onnx_neural_compressor/constants.py index 71caf2a49..54889bda0 100644 --- a/onnx_neural_compressor/constants.py +++ b/onnx_neural_compressor/constants.py @@ -38,6 +38,7 @@ ONNXRT116_VERSION = version.Version("1.16.0") ONNXRT1161_VERSION = version.Version("1.16.1") +ONNXRT119_VERSION = version.Version("1.19.0") PRIORITY_RTN = 60 PRIORITY_GPTQ = 70 diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 5488615e5..efc9cf9c8 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -136,6 +136,7 @@ def is_large_model(self): """Check the onnx model is over 2GB.""" return self._is_large_model + @property def framework(self): """Return framework.""" return "onnxruntime" @@ -201,10 +202,12 @@ def graph(self): """Return model graph.""" return self._model.graph + @property def ir_version(self): """Return model ir_version.""" return self._model.ir_version + @property def opset_import(self): """Return model opset_import.""" return self._model.opset_import @@ -469,16 +472,6 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_op if node.op_type not in black_optype: ONNXModel.replace_node_output(node, old_output_name, new_output_name) - def remove_duplicate_nodes(self): - """remove duplicate nodes""" - new_nodes = [] - for node in self.nodes(): - if node not in new_nodes: - 
new_nodes.append(node) - self.model.graph.ClearField("node") - self.model.graph.node.extend(new_nodes) - self.update() - def remove_unused_nodes(self): """Remove unused nodes.""" unused_nodes = [] @@ -870,7 +863,9 @@ def _build_input_output_tensor(self, tensor_name, value_info): tensor_type = value_info.get(tensor_name, onnx.TensorProto.FLOAT) return onnx.helper.make_tensor_value_info(tensor_name, tensor_type, None) - def split_model_with_node(self, split_node_name, path_of_model_to_split, save_both_split_models=True): + def split_model_with_node( + self, split_node_name, path_of_model_to_split, save_both_split_models=True, save_path=None + ): """Split model into two parts at a given node. Args: @@ -880,6 +875,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo False means only save the first split model. True means save both the two split models. Default id True. + save_path (str): path to save split models. None means using self.model_path Returns: tuple: the first split model, the second split model @@ -971,7 +967,11 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo dir_of_model_to_split = os.path.dirname(path_of_model_to_split) split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split) - split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") + split_model_part_1_path = ( + os.path.join(save_path, "split_model_part_1.onnx") + if save_path is not None + else os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") + ) split_model_part_1.model_path = split_model_part_1_path split_model_part_1._save_split_model(split_model_part_1_path) split_model_part_1.check_is_large_model() @@ -979,7 +979,11 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo if save_both_split_models: split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) - split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") + split_model_part_2_path = ( + os.path.join(save_path, "split_model_part_2.onnx") + if save_path is not None + else os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") + ) split_model_part_2.model_path = split_model_part_2_path split_model_part_2._save_split_model(split_model_part_2_path) split_model_part_2.check_is_large_model() @@ -996,6 +1000,7 @@ def _save_split_model(self, save_path): """ if os.path.exists(save_path + "_data"): os.remove(save_path + "_data") + self._model_path = save_path onnx.save_model( self.model, save_path, diff --git a/onnx_neural_compressor/quantization/algorithm_entry.py b/onnx_neural_compressor/quantization/algorithm_entry.py index 12689fa7e..703e250e2 100644 --- a/onnx_neural_compressor/quantization/algorithm_entry.py +++ b/onnx_neural_compressor/quantization/algorithm_entry.py @@ -18,12 +18,15 @@ import onnx import onnxruntime as ort +from packaging import version from onnx_neural_compressor import constants, data_reader, logger, utility from onnx_neural_compressor.algorithms.post_training_quant import calibrate, quantizer from onnx_neural_compressor.algorithms.smoother import core from onnx_neural_compressor.algorithms.weight_only import awq, gptq, rtn -from onnx_neural_compressor.quantization import config +from onnx_neural_compressor.quantization import QuantFormat, config + +ort_version = version.Version(ort.__version__) ###################### RTN Algo Entry ################################## @@ -32,15 +35,14 @@ def rtn_quantize_entry( model: 
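With the new `save_path` argument, layer_wise/core.py now writes the split halves into a `tempfile.TemporaryDirectory` and removes them once quantization finishes. A sketch of the same pattern; the model path and split node name are hypothetical, and `ONNXModel` accepting a path directly is an assumption:

```python
import tempfile

from onnx_neural_compressor import onnx_model

# load without external data, as layer_wise/core.py does
model = onnx_model.ONNXModel("model/model.onnx", load_external_data=False)

tmp_dir = tempfile.TemporaryDirectory()
part_1, part_2 = model.split_model_with_node(
    "/model/layers.15/Add",        # hypothetical split node name
    model.model_path,
    save_both_split_models=True,
    save_path=tmp_dir.name,        # split_model_part_*.onnx are written here now
)
# ... quantize both halves, merge, reload external data ...
tmp_dir.cleanup()                  # remove the temporary split files
```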
Union[pathlib.Path, str], quant_config: config.RTNConfig, *args, **kwargs ) -> onnx.ModelProto: """The main entry to apply rtn quantization.""" - if len(quant_config.config_mapping) == 0: - # map config to each op - model_info = config.RTNConfig.get_model_info(model=model) - config_mapping = quant_config.to_config_mapping(model_info=model_info) - logger.debug(config_mapping) - else: - config_mapping = quant_config.config_mapping - quant_kwargs = {} - quant_kwargs = {key: getattr(quant_config, key) for key in config.RTNConfig.model_params_list} + config_mapping = quant_config.to_config_mapping(model=model) + + quant_kwargs = dict( + zip( + quant_config.model_params_list, + [getattr(quant_config, key, None) for key in quant_config.model_params_list], + ) + ) model = rtn.apply_rtn_on_model(model, config_mapping, **quant_kwargs) return model @@ -60,15 +62,13 @@ def gptq_quantize_entry( calibration_data_reader, data_reader.CalibrationDataReader ), "Please follow onnx_neural_compressor/data_reader.py to implement calibration_data_reader" - if len(quant_config.config_mapping) == 0: - # map config to each op - model_info = config.GPTQConfig.get_model_info(model=model) - config_mapping = quant_config.to_config_mapping(model_info=model_info) - logger.debug(config_mapping) - else: - config_mapping = quant_config.config_mapping - quant_kwargs = {} - quant_kwargs = {key: getattr(quant_config, key) for key in config.GPTQConfig.model_params_list} + config_mapping = quant_config.to_config_mapping(model=model) + quant_kwargs = dict( + zip( + quant_config.model_params_list, + [getattr(quant_config, key, None) for key in quant_config.model_params_list], + ) + ) # regenerate to ensure data exists calibration_data_reader.rewind() @@ -91,15 +91,13 @@ def awq_quantize_entry( calibration_data_reader, data_reader.CalibrationDataReader ), "Please follow onnx_neural_compressor/data_reader.py to implement calibration_data_reader" - if len(quant_config.config_mapping) == 0: - # map config to each op - model_info = config.AWQConfig.get_model_info(model=model) - config_mapping = quant_config.to_config_mapping(model_info=model_info) - logger.debug(config_mapping) - else: - config_mapping = quant_config.config_mapping - quant_kwargs = {} - quant_kwargs = {key: getattr(quant_config, key) for key in config.AWQConfig.model_params_list} + config_mapping = quant_config.to_config_mapping(model=model) + quant_kwargs = dict( + zip( + quant_config.model_params_list, + [getattr(quant_config, key, None) for key in quant_config.model_params_list], + ) + ) # regenerate to ensure data exists calibration_data_reader.rewind() @@ -126,13 +124,7 @@ def static_quantize_entry( calibration_data_reader, data_reader.CalibrationDataReader ), "Please follow onnx_neural_compressor/quantization/calibrate.py to implement calibration_data_reader" - if len(quant_config.config_mapping) == 0: - # map config to each op - model_info = config.StaticQuantConfig.get_model_info(model=model) - config_mapping = quant_config.to_config_mapping(model_info=model_info) - logger.debug(config_mapping) - else: - config_mapping = quant_config.config_mapping + config_mapping = quant_config.to_config_mapping(model=model) calibration_data_reader.rewind() augment = calibrate.ONNXRTAugment( @@ -184,7 +176,7 @@ def smooth_quant_entry( calibration_data_reader, execution_provider=getattr(quant_config, "execution_provider", "CPUExecutionProvider"), ) - smoothed_model = smoother.transform(**quant_config.to_dict()) + smoothed_model = 
smoother.transform(**quant_config.get_model_params_dict()) with tempfile.TemporaryDirectory(prefix="ort.quant.") as tmp_dir: # ORT quant API requires str input onnx.save_model( @@ -227,13 +219,7 @@ def dynamic_quantize_entry( logger.warning("No candidate op type to do quantization, exit.") exit(0) - if len(quant_config.config_mapping) == 0: - # map config to each op - model_info = config.DynamicQuantConfig.get_model_info(model=model) - config_mapping = quant_config.to_config_mapping(model_info=model_info) - logger.debug(config_mapping) - else: - config_mapping = quant_config.config_mapping + config_mapping = quant_config.to_config_mapping(model=model) _quantizer = quantizer.DynamicQuantizer( model, diff --git a/onnx_neural_compressor/quantization/config.py b/onnx_neural_compressor/quantization/config.py index f4fe2672e..bc761dcce 100644 --- a/onnx_neural_compressor/quantization/config.py +++ b/onnx_neural_compressor/quantization/config.py @@ -27,8 +27,10 @@ import numpy as np import onnx +import onnxruntime as ort import pydantic from onnxruntime import quantization as ort_quant +from packaging import version from typing_extensions import Self from onnx_neural_compressor import constants, data_reader, logger, quantization, utility @@ -36,6 +38,8 @@ from collections import OrderedDict # isort: skip from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union, _GenericAlias # isort: skip +ort_version = version.Version(ort.__version__) + class ParamLevel(enum.Enum): OP_LEVEL = enum.auto() @@ -201,6 +205,17 @@ class ExampleAlgorithmConfig: return config_registry.register_config_impl(algo_name=algo_name, priority=priority) +class Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, quantization.QuantType): + return getattr(o, "tensor_type") + if isinstance(o, quantization.QuantFormat): + return getattr(o, "value") + if isinstance(o, quantization.CalibrationMethod): + return getattr(o, "name") + return super().default(o) + + class BaseConfig(ABC): """The base config for all algorithm configs.""" @@ -210,22 +225,23 @@ class BaseConfig(ABC): def __init__( self, - white_list: Optional[Union[Union[str, Callable], List[Union[str, Callable]]]] = constants.DEFAULT_WHITE_LIST, + white_list: Optional[List[str]] = constants.DEFAULT_WHITE_LIST, ) -> None: self._global_config: Optional[BaseConfig] = None # local config is the collections of operator_type configs and operator configs self._local_config: Dict[str, Optional[BaseConfig]] = {} self._white_list = white_list self._config_mapping = OrderedDict() + self._post_init() def _post_init(self): if self.white_list == constants.DEFAULT_WHITE_LIST: global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) + self._global_config = self.__class__(**global_config, white_list=constants.EMPTY_WHITE_LIST) elif isinstance(self.white_list, list) and len(self.white_list) > 0: for op_name_or_type in self.white_list: global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) + tmp_config = self.__class__(**global_config, white_list=constants.EMPTY_WHITE_LIST) self.set_local(op_name_or_type, tmp_config) elif self.white_list == constants.EMPTY_WHITE_LIST: return @@ -296,6 +312,19 @@ def get_init_args(self): result[param] = value return result + @staticmethod + def get_model_info(model) -> list: + """Get (node_name, optype) pairs of the model.""" + if not isinstance(model, onnx.ModelProto): + model = onnx.load(model, 
load_external_data=False) + + ops = [] + for node in model.graph.node: + pair = (node.name, node.op_type) + ops.append(pair) + logger.debug(f"Get model info: {ops}") + return ops + def __getitem__(self, key): if hasattr(self, key): return getattr(self, key) @@ -323,7 +352,7 @@ def from_dict(cls, config_dict): operator_config = config_dict.get(constants.LOCAL, {}) if operator_config: for op_name, op_config in operator_config.items(): - config.set_local(op_name, cls(**op_config, white_list=None)) + config.set_local(op_name, cls(**op_config, white_list=constants.EMPTY_WHITE_LIST)) return config def get_diff_dict(self, config) -> Dict[str, Any]: @@ -348,7 +377,7 @@ def from_json_file(cls, filename): def to_json_file(self, filename): config_dict = self.to_dict() with open(filename, "w", encoding="utf-8") as file: - json.dump(config_dict, file, indent=4) + json.dump(config_dict, file, indent=4, cls=Encoder) logger.info("Dump the config into %s.", filename) def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: @@ -367,7 +396,7 @@ def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: else: config_dict = self.to_dict() try: - return json.dumps(config_dict, indent=2) + "\n" + return json.dumps(config_dict, indent=2, cls=Encoder) + "\n" except Exception as e: logger.error("Failed to serialize the config to JSON string: %s", e) return config_dict @@ -534,13 +563,19 @@ def _get_op_name_op_type_config(self): return op_type_config_dict, op_name_config_dict def to_config_mapping( - self, config_list: Optional[List[BaseConfig]] = None, model_info: List[Tuple[str, str]] = None + self, + model: Union[onnx.ModelProto, str], + config_list: Optional[List[BaseConfig]] = None, ) -> OrderedDict[Tuple[str, str], OrderedDict[str, BaseConfig]]: if config_list is None: config_list = [self] + model_info = self.get_model_info(model) for config in config_list: + global_config = config.get_params_dict() op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() for op_name, op_type in model_info: + if global_config is not None: + self._config_mapping[op_name] = global_config if op_type in op_type_config_dict: self._config_mapping[op_name] = op_name_config_dict[op_type] for op_name_pattern in op_name_config_dict: @@ -597,14 +632,15 @@ def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[st return config def to_json_string(self, use_diff: bool = False) -> str: - return json.dumps(self.to_dict(), indent=2) + "\n" + return json.dumps(self.to_dict(), indent=2, cls=Encoder) + "\n" def __repr__(self) -> str: return f"{self.__class__.__name__} {self.to_json_string()}" def to_config_mapping( - self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None + self, model: Union[onnx.ModelProto, str], config_list: List[BaseConfig] = None ) -> OrderedDict[str, BaseConfig]: + model_info = self.get_model_info(model) for config in self.config_list: op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() single_config_model_info = model_info.get(config.name, None) @@ -711,11 +747,128 @@ class _OperatorConfig(NamedTuple): valid_func_list: List[Callable] = [] +class BaseWeightOnlyConfig(BaseConfig): + """Base config class for weight-only quantization.""" + + def __init__( + self, + weight_dtype: bool = "int", + weight_bits: int = 4, + weight_group_size: int = 32, + weight_sym: bool = True, + act_dtype: str = "fp32", + accuracy_level: int = 0, + providers: List[str] = ["CPUExecutionProvider"], + quant_last_matmul: bool = True, 
+ quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, + nodes_to_exclude: list = [], + white_list: List[Union[str, Callable]] = constants.DEFAULT_WHITE_LIST, + ): + """Initialize weight-only quantization config. + + Args: + weight_dtype (str, optional): Data type for weights, support "uint" and "int", default is "int". + weight_bits (int, optional): Number of bits used to represent weights, default is 4. + weight_group_size (int, optional): Size of weight groups, default is 32. + weight_sym (bool, optional): Indicates whether weights are symmetric, default is True. + act_dtype (str, optional): Data type for activations, default is "fp32". + accuracy_level (int, optional): accuracy level. Support 0 (unset), 1(fp32 compute type of jblas kernel), + 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), + 4 (int8 compute type of jblas kernel). Defaults to 0. + ratios (dict, optional): percentile of clip. Defaults to {}. + providers (list, optional): execution providers to use. Defaults to ["CPUExecutionProvider"]. + quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True. + quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator. + nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization. + white_list (list, optional): op in white_list will be applied current config. + Defaults to constants.DEFAULT_WHITE_LIST. + """ + self.weight_bits = weight_bits + self.weight_dtype = weight_dtype + self.weight_group_size = weight_group_size + self.weight_sym = weight_sym + self.act_dtype = act_dtype + self.accuracy_level = accuracy_level + self.providers = providers + self.quant_last_matmul = quant_last_matmul + self.quant_format = quant_format + self.nodes_to_exclude = nodes_to_exclude + super().__init__(white_list=white_list) + + def get_model_params_dict(self): + result = dict() + for param in self.model_params_list: + result[param] = getattr(self, param) + return result + + def to_config_mapping(self, model: Union[onnx.ModelProto, str], config_list: List[BaseConfig] = None): + if isinstance(model, str): + model = onnx.load(model, load_external_data=False) + + model_info = self.get_model_info(model) + if config_list is None: + config_list = [self] + for config in config_list: + # update model level setting + self._config_mapping.update(config.get_model_params_dict()) + + # update node level setting + last_matmul = None + global_config = config.get_params_dict() + op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() + for op_name, op_type in model_info: + if op_type not in self.white_list: + continue + + # skip excluded op + if any([re.match(exclude_name, op_name) for exclude_name in self.nodes_to_exclude]): + continue + + if op_type == "MatMul": + last_matmul = op_name + + if global_config is not None: + self._config_mapping[op_name] = global_config + + if op_type in op_type_config_dict: + self._config_mapping[op_name] = op_type_config_dict[op_type] + + for op_name_pattern in op_name_config_dict: + if re.match(op_name_pattern, op_name): + self._config_mapping[op_name] = op_name_config_dict[op_name_pattern] + + # convert config to dict + if op_name in self._config_mapping and hasattr(self._config_mapping[op_name], "to_dict"): + self._config_mapping[op_name] = self._config_mapping[op_name].to_dict() + + # update quant_format + if ( + ort_version < constants.ONNXRT119_VERSION + or model.opset_import[0].version < 
21 + or self._config_mapping[op_name].get("weight_bits", 4) not in [4, 8] + ): + self._config_mapping[op_name].update({"quant_format": quantization.QuantFormat.QOperator}) + if ( + self._config_mapping[op_name].get("weight_bits", 4) != 4 + or ort_version < constants.ONNXRT116_VERSION + or ( + ort_version <= constants.ONNXRT1161_VERSION + and self._config_mapping[op_name].get("weight_group_size", 32) != 32 + ) + ): + # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions + # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1 + del self._config_mapping[op_name]["quant_format"] + if not self.quant_last_matmul and last_matmul is not None and last_matmul in self._config_mapping: + del self._config_mapping[last_matmul] + return self._config_mapping + + ######################## RNT Config ############################### @register_config(algo_name=constants.RTN, priority=constants.PRIORITY_RTN) -class RTNConfig(BaseConfig): +class RTNConfig(BaseWeightOnlyConfig): """Config class for round-to-nearest weight-only quantization.""" supported_configs: List[_OperatorConfig] = [] @@ -727,6 +880,7 @@ class RTNConfig(BaseConfig): "act_dtype", "accuracy_level", "ratios", + "quant_format", ] model_params_list: List[str] = [ "providers", @@ -746,12 +900,14 @@ def __init__( providers: List[str] = ["CPUExecutionProvider"], layer_wise_quant: bool = False, quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.RTN_OP_LIST, + quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, + nodes_to_exclude: List[str] = [], + white_list: List[str] = constants.RTN_OP_LIST, ): """Init RTN weight-only quantization config. Args: - weight_dtype (str, optional): Data type for weights, default is "int". + weight_dtype (str, optional): Data type for weights, support "uint" and "int", default is "int". weight_bits (int, optional): Number of bits used to represent weights, default is 4. weight_group_size (int, optional): Size of weight groups, default is 32. weight_sym (bool, optional): Indicates whether weights are symmetric, default is True. @@ -766,39 +922,27 @@ def __init__( https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md, default is False. quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True. + quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator. + nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization. white_list (list, optional): op in white_list will be applied current config. - Defaults to constants.DEFAULT_WHITE_LIST. 
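The per-op `quant_format` resolution added to `BaseWeightOnlyConfig.to_config_mapping` earlier in this hunk can be summarised as below. This is a restatement for clarity, not library code; returning `None` stands for deleting the `quant_format` key, which routes `quant_matmul_weight_only` to its plain fake-quant branch:

```python
# Restatement of the quant_format fallback applied per MatMul in to_config_mapping.
from packaging import version

ONNXRT116 = version.Version("1.16.0")
ONNXRT1161 = version.Version("1.16.1")
ONNXRT119 = version.Version("1.19.0")


def resolve_quant_format(requested, ort_ver, opset, weight_bits, group_size):
    """Return the effective quant_format for one MatMul, or None for fake-quant fallback."""
    fmt = requested
    # QDQ needs ort >= 1.19, model opset >= 21 and 4- or 8-bit weights
    if ort_ver < ONNXRT119 or opset < 21 or weight_bits not in (4, 8):
        fmt = "QOperator"
    # MatMulFpQ4 supports 4 bits with group_size 32 on ort 1.16.0/1.16.1;
    # MatMulNBits supports 4 bits with 2^n group_size on ort > 1.16.1.
    # Anything else falls back to storing a fake-quantized fp32 weight.
    if weight_bits != 4 or ort_ver < ONNXRT116 or (ort_ver <= ONNXRT1161 and group_size != 32):
        fmt = None
    return fmt
```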
""" - super().__init__(white_list=white_list) - self.weight_bits = weight_bits - self.weight_dtype = weight_dtype - self.weight_group_size = weight_group_size - self.weight_sym = weight_sym - self.act_dtype = act_dtype - self.accuracy_level = accuracy_level - self.ratios = ratios - self.providers = providers self.layer_wise_quant = layer_wise_quant - self.quant_last_matmul = quant_last_matmul - self._post_init() - - def _post_init(self): - if self.white_list == constants.RTN_OP_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return + self.ratios = ratios - def get_model_params_dict(self): - result = dict() - for param in self.model_params_list: - result[param] = getattr(self, param) - return result + super().__init__( + weight_bits=weight_bits, + weight_dtype=weight_dtype, + weight_group_size=weight_group_size, + weight_sym=weight_sym, + act_dtype=act_dtype, + accuracy_level=accuracy_level, + providers=providers, + quant_last_matmul=quant_last_matmul, + quant_format=quant_format, + nodes_to_exclude=nodes_to_exclude, + white_list=white_list if white_list != constants.RTN_OP_LIST else constants.DEFAULT_WHITE_LIST, + ) + self.white_list = white_list @classmethod def register_supported_configs(cls) -> None: @@ -814,46 +958,6 @@ def register_supported_configs(cls) -> None: supported_configs.append(_OperatorConfig(config=linear_rtn_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: List[BaseConfig] = None, model_info: list = None): - if config_list is None: - config_list = [self] - for config in config_list: - # update model level setting - self._config_mapping.update(config.get_model_params_dict()) - - # update node level setting - last_matmul = None - global_config = config.get_params_dict() - op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() - for op_name, op_type in model_info: - if op_type == "MatMul": - last_matmul = op_name - if global_config is not None: - self._config_mapping[op_name] = global_config - if op_type in op_type_config_dict: - self._config_mapping[op_name] = op_type_config_dict[op_type] - for op_name_pattern in op_name_config_dict: - if re.match(op_name_pattern, op_name): - self._config_mapping[op_name] = op_name_config_dict[op_name_pattern] - if op_name in self._config_mapping and hasattr(self._config_mapping[op_name], "to_dict"): - self._config_mapping[op_name] = self._config_mapping[op_name].to_dict() - if not self.quant_last_matmul and last_matmul is not None and last_matmul in self._config_mapping: - del self._config_mapping[last_matmul] - return self._config_mapping - - @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str], white_list=constants.RTN_OP_LIST) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result - @classmethod def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", 
List["RTNConfig"]]: # pragma: no cover return RTNConfig(weight_bits=[4, 8], weight_sym=[True, False]) @@ -872,7 +976,7 @@ def get_default_rtn_config() -> RTNConfig: @register_config(algo_name=constants.GPTQ, priority=constants.PRIORITY_GPTQ) -class GPTQConfig(BaseConfig): +class GPTQConfig(BaseWeightOnlyConfig): """Config class for gptq weight-only quantization.""" supported_configs: List[_OperatorConfig] = [] @@ -883,6 +987,7 @@ class GPTQConfig(BaseConfig): "weight_sym", "act_dtype", "accuracy_level", + "quant_format", ] model_params_list: List[Union[str, TuningParam]] = [ "percdamp", @@ -911,7 +1016,9 @@ def __init__( providers: List[str] = ["CPUExecutionProvider"], layer_wise_quant: bool = False, quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.GPTQ_OP_LIST, + quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, + nodes_to_exclude: List[str] = [], + white_list: List[str] = constants.GPTQ_OP_LIST, ): """Init GPTQ weight-only quantization config. @@ -937,43 +1044,31 @@ def __init__( https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md, default is False. quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True. + quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator. + nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization. white_list (list, optional): op in white_list will be applied current config. - Defaults to constants.DEFAULT_WHITE_LIST. """ - super().__init__(white_list=white_list) - self.weight_bits = weight_bits - self.weight_dtype = weight_dtype - self.weight_group_size = weight_group_size - self.weight_sym = weight_sym - self.act_dtype = act_dtype - self.accuracy_level = accuracy_level self.percdamp = percdamp self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel - self.providers = providers self.layer_wise_quant = layer_wise_quant - self.quant_last_matmul = quant_last_matmul - self._post_init() - def _post_init(self): - if self.white_list == constants.GPTQ_OP_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return - - def get_model_params_dict(self): - result = dict() - for param in self.model_params_list: - result[param] = getattr(self, param) - return result + super().__init__( + weight_bits=weight_bits, + weight_dtype=weight_dtype, + weight_group_size=weight_group_size, + weight_sym=weight_sym, + act_dtype=act_dtype, + accuracy_level=accuracy_level, + providers=providers, + quant_last_matmul=quant_last_matmul, + quant_format=quant_format, + nodes_to_exclude=nodes_to_exclude, + white_list=white_list if white_list != constants.GPTQ_OP_LIST else constants.DEFAULT_WHITE_LIST, + ) + self.white_list = white_list @classmethod def register_supported_configs(cls) -> None: @@ -992,46 +1087,6 @@ def register_supported_configs(cls) -> None: supported_configs.append(_OperatorConfig(config=linear_gptq_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: list = None, model_info: 
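With the shared BaseWeightOnlyConfig in place, format selection and node exclusion become plain constructor arguments on RTNConfig (GPTQConfig and AWQConfig gain the same two arguments below), replacing the old pattern of attaching a second fp32 config via set_local. A usage sketch, assuming a hypothetical model.onnx path and reusing the node name that appears elsewhere in this patch:

import onnx

from onnx_neural_compressor.quantization import QuantFormat, config

fp32_model = onnx.load("model.onnx")  # hypothetical path

rtn_config = config.RTNConfig(
    weight_bits=4,
    weight_sym=False,
    quant_format=QuantFormat.QDQ,          # downgraded per node to QOperator when the checks above fail
    nodes_to_exclude=["/lm_head/MatMul"],  # regex-matched node names are skipped outright
)
# to_config_mapping now takes the model (proto or path) directly instead of a precomputed model_info.
config_mapping = rtn_config.to_config_mapping(model=fp32_model)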
list = None) -> OrderedDict: - if config_list is None: - config_list = [self] - for config in config_list: - # update model level setting - self._config_mapping.update(config.get_model_params_dict()) - - # update node level setting - last_matmul = None - global_config = config.get_params_dict() - op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() - for op_name, op_type in model_info: - if op_type == "MatMul": - last_matmul = op_name - if global_config is not None: - self._config_mapping[op_name] = global_config - if op_type in op_type_config_dict: - self._config_mapping[op_name] = op_type_config_dict[op_type] - for op_name_pattern in op_name_config_dict: - if re.match(op_name_pattern, op_name): - self._config_mapping[op_name] = op_name_config_dict[op_name_pattern] - if op_name in self._config_mapping and hasattr(self._config_mapping[op_name], "to_dict"): - self._config_mapping[op_name] = self._config_mapping[op_name].to_dict() - if not self.quant_last_matmul and last_matmul is not None and last_matmul in self._config_mapping: - del self._config_mapping[last_matmul] - return self._config_mapping - - @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str], white_list=constants.GPTQ_OP_LIST) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result - @classmethod def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]: # pragma: no cover return GPTQConfig( @@ -1056,7 +1111,7 @@ def get_default_gptq_config() -> GPTQConfig: @register_config(algo_name=constants.AWQ, priority=constants.PRIORITY_AWQ) -class AWQConfig(BaseConfig): +class AWQConfig(BaseWeightOnlyConfig): """Config class for awq weight-only quantization.""" supported_configs: List[_OperatorConfig] = [] @@ -1067,6 +1122,7 @@ class AWQConfig(BaseConfig): "weight_sym", "act_dtype", "accuracy_level", + "quant_format", ] model_params_list: List[str] = [ "enable_auto_scale", @@ -1087,7 +1143,9 @@ def __init__( enable_mse_search: bool = True, providers: List[str] = ["CPUExecutionProvider"], quant_last_matmul: bool = True, - white_list: List[Union[str, Callable]] = constants.AWQ_OP_LIST, + quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, + nodes_to_exclude: List[str] = [], + white_list: List[str] = constants.AWQ_OP_LIST, ): """Init AWQ weight-only quantization config. @@ -1106,39 +1164,27 @@ def __init__( [0.91, 1.0, 0.01]. Defaults to True. providers (list, optional): execution providers to use. Defaults to ["CPUExecutionProvider"]. quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True. + quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator. + nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization. white_list (list, optional): op in white_list will be applied current config. - Defaults to constants.DEFAULT_WHITE_LIST. 
""" - super().__init__(white_list=white_list) - self.weight_bits = weight_bits - self.weight_dtype = weight_dtype - self.weight_group_size = weight_group_size - self.weight_sym = weight_sym - self.act_dtype = act_dtype - self.accuracy_level = accuracy_level self.enable_auto_scale = enable_auto_scale self.enable_mse_search = enable_mse_search - self.providers = providers - self.quant_last_matmul = quant_last_matmul - self._post_init() - def _post_init(self): - if self.white_list == constants.GPTQ_OP_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return - - def get_model_params_dict(self): - result = dict() - for param in self.model_params_list: - result[param] = getattr(self, param) - return result + super().__init__( + weight_bits=weight_bits, + weight_dtype=weight_dtype, + weight_group_size=weight_group_size, + weight_sym=weight_sym, + act_dtype=act_dtype, + accuracy_level=accuracy_level, + providers=providers, + quant_last_matmul=quant_last_matmul, + quant_format=quant_format, + nodes_to_exclude=nodes_to_exclude, + white_list=white_list if white_list != constants.AWQ_OP_LIST else constants.DEFAULT_WHITE_LIST, + ) + self.white_list = white_list @classmethod def register_supported_configs(cls) -> List[_OperatorConfig]: @@ -1156,46 +1202,6 @@ def register_supported_configs(cls) -> List[_OperatorConfig]: supported_configs.append(_OperatorConfig(config=linear_awq_config, operators=operators)) cls.supported_configs = supported_configs - def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: - if config_list is None: - config_list = [self] - for config in config_list: - # update model level setting - self._config_mapping.update(config.get_model_params_dict()) - - # update node level setting - last_matmul = None - global_config = config.get_params_dict() - op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() - for op_name, op_type in model_info: - if op_type == "MatMul": - last_matmul = op_name - if global_config is not None: - self._config_mapping[op_name] = global_config - if op_type in op_type_config_dict: - self._config_mapping[op_name] = op_type_config_dict[op_type] - for op_name_pattern in op_name_config_dict: - if re.match(op_name_pattern, op_name): - self._config_mapping[op_name] = op_name_config_dict[op_name_pattern] - if op_name in self._config_mapping and hasattr(self._config_mapping[op_name], "to_dict"): - self._config_mapping[op_name] = self._config_mapping[op_name].to_dict() - if not self.quant_last_matmul and last_matmul is not None and last_matmul in self._config_mapping: - del self._config_mapping[last_matmul] - return self._config_mapping - - @staticmethod - def get_model_info(model: Union[onnx.ModelProto, pathlib.Path, str], white_list=constants.AWQ_OP_LIST) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result - @classmethod def 
get_config_set_for_tuning(cls) -> Union[None, "AWQConfig", List["AWQConfig"]]: # pragma: no cover return AWQConfig( @@ -1218,17 +1224,17 @@ def get_default_awq_config() -> AWQConfig: ######################## WOQ Tuning Config ############################### -def get_woq_tuning_config() -> list: +def get_woq_tuning_config(quant_format=quantization.QuantFormat.QOperator) -> list: """Generate the config set for WOQ tuning. Returns: the list of WOQ quant config. """ - RTN_G32ASYM = RTNConfig(weight_sym=False) - GPTQ_G32ASYM = GPTQConfig(weight_sym=False) - GPTQ_G32ASYM_DISABLE_LAST_MATMUL = GPTQConfig(weight_sym=False, quant_last_matmul=False) - GPTQ_G128ASYM = GPTQConfig(weight_group_size=128, weight_sym=False) - AWQ_G32ASYM = AWQConfig(weight_sym=False) + RTN_G32ASYM = RTNConfig(weight_sym=False, quant_format=quant_format) + GPTQ_G32ASYM = GPTQConfig(weight_sym=False, quant_format=quant_format) + GPTQ_G32ASYM_DISABLE_LAST_MATMUL = GPTQConfig(weight_sym=False, quant_last_matmul=False, quant_format=quant_format) + GPTQ_G128ASYM = GPTQConfig(weight_group_size=128, weight_sym=False, quant_format=quant_format) + AWQ_G32ASYM = AWQConfig(weight_sym=False, quant_format=quant_format) return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_MATMUL, GPTQ_G128ASYM, AWQ_G32ASYM] @@ -1537,7 +1543,6 @@ def __init__( calibration_sampling_size=100, quant_last_matmul=True, execution_provider=None, - white_list: list = constants.DEFAULT_WHITE_LIST, **kwargs, ): """This is a class for static Quant Configuration. @@ -1607,7 +1612,6 @@ def __init__( else: os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1" - BaseConfig.__init__(self, white_list=self.op_types_to_quantize) self.execution_provider = execution_provider self.quant_last_matmul = quant_last_matmul self.calibration_sampling_size = calibration_sampling_size @@ -1617,21 +1621,7 @@ def __init__( self.optypes_to_exclude_output_quant = _extra_options.OpTypesToExcludeOutputQuantization self.dedicated_qdq_pair = _extra_options.DedicatedQDQPair self.add_qdq_pair_to_weight = _extra_options.AddQDQPairToWeight - self.white_list = white_list - self._post_init() - - @staticmethod - def get_model_info(model, white_list=constants.STATIC_QOPERATOR_CPU_OP_LIST) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result + BaseConfig.__init__(self, white_list=self.op_types_to_quantize) def get_model_params_dict(self): result = dict() @@ -1647,13 +1637,13 @@ def _post_init(self): for valid_func in STATIC_CHECK_FUNC_LIST: op_config = valid_func(op_config, op_name_or_type, self.execution_provider, self.quant_format) self.set_local(op_name_or_type, op_config) - if isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: + def to_config_mapping(self, model: Union[onnx.ModelProto, str], config_list: List[BaseConfig] = None): + if isinstance(model, str): + model = onnx.load(model, load_external_data=False) + + model_info = self.get_model_info(model) + if config_list is None: config_list = [self] for config in config_list: @@ 
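Because get_woq_tuning_config now threads quant_format through every candidate, a QDQ preference survives the whole autotune search. A short sketch of building that search space (the tolerable_loss value is only an example):

from onnx_neural_compressor.quantization import QuantFormat, config, tuning

# RTN/GPTQ/AWQ candidates all carry quant_format=QDQ; each node still falls back
# to QOperator when the runtime/opset/bit-width checks are not met.
config_set = config.get_woq_tuning_config(quant_format=QuantFormat.QDQ)
tune_config = tuning.TuningConfig(config_set=config_set, tolerable_loss=0.01)
# tuning.autotune(model_input=..., tune_config=tune_config, ...) then consumes this search space.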
-1859,34 +1849,6 @@ def register_supported_configs(cls) -> None: ) cls.supported_configs = supported_configs - def to_dict(self): - result = {} - for key, val in self.__dict__.items(): - if key in ["_global_config", "_config_mapping"]: - continue - if key == "_local_config": - local_result = {} - for name, cfg in val.items(): - local_result[name] = cfg.to_dict() - result[key] = local_result - continue - if not isinstance(val, list): - result[key] = ( - getattr(val, "tensor_type", val) - if isinstance(val, quantization.QuantType) - else getattr(val, "value", val) - ) - else: - result[key] = [ - ( - getattr(item, "tensor_type", item) - if isinstance(item, quantization.QuantType) - else getattr(item, "value", item) - ) - for item in val - ] - return result - ######################## SmoohQuant Config ############################### @@ -1922,7 +1884,6 @@ def __init__( calib_iter: int = 100, scales_per_op: bool = True, auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"}, - white_list: list = None, **kwargs, ): """Init smooth quant config. @@ -1942,7 +1903,7 @@ def __init__( kwargs (dict): kwargs in below link are supported except calibration_data_reader: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/quantize.py#L78 """ - super().__init__(white_list=white_list, **kwargs) + super().__init__(**kwargs) self.alpha = alpha self.folding = folding self.op_types = op_types @@ -1958,19 +1919,6 @@ def register_supported_configs(cls) -> List[_OperatorConfig]: supported_configs.append(_OperatorConfig(config=smooth_quant_config, operators=operators)) cls.supported_configs = supported_configs - @staticmethod - def get_model_info(model, white_list=["Gemm", "Conv", "MatMul", "FusedConv"]) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result - @classmethod def get_config_set_for_tuning( cls, @@ -2022,7 +1970,6 @@ def __init__( extra_options: dict = None, quant_last_matmul: bool = True, execution_provider: str = None, - white_list: list = constants.DEFAULT_WHITE_LIST, **kwargs, ): if execution_provider is None: @@ -2044,28 +1991,13 @@ def __init__( use_external_data_format=use_external_data_format, extra_options=extra_options, ) - BaseConfig.__init__(self, white_list=op_types_to_quantize) self.execution_provider = execution_provider self.quant_last_matmul = quant_last_matmul self.activation_type = quantization.QuantType.QUInt8 _extra_options = ExtraOptions(**self.extra_options) self.weight_sym = _extra_options.WeightSymmetric self.activation_sym = _extra_options.ActivationSymmetric - self.white_list = white_list - self._post_init() - - @staticmethod - def get_model_info(model, white_list=constants.DYNAMIC_CPU_OP_LIST) -> list: - if not isinstance(model, onnx.ModelProto): - model = onnx.load(model, load_external_data=False) - - filter_result = [] - for node in model.graph.node: - if node.op_type in white_list: - pair = (node.name, node.op_type) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result + BaseConfig.__init__(self, white_list=op_types_to_quantize) def get_model_params_dict(self): result = dict() @@ -2080,13 +2012,13 @@ def _post_init(self): for valid_func in DYNAMIC_CHECK_FUNC_LIST: 
op_config = valid_func(op_config, op_name_or_type, self.execution_provider) self.set_local(op_name_or_type, op_config) - if isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: + def to_config_mapping(self, model: Union[onnx.ModelProto, str], config_list: List[BaseConfig] = None): + if isinstance(model, str): + model = onnx.load(model, load_external_data=False) + + model_info = self.get_model_info(model) + if config_list is None: config_list = [self] for config in config_list: @@ -2221,34 +2153,6 @@ def register_supported_configs(cls) -> None: ) cls.supported_configs = supported_configs - def to_dict(self): - result = {} - for key, val in self.__dict__.items(): - if key in ["_global_config", "_config_mapping"]: - continue - if key == "_local_config": - local_result = {} - for name, cfg in val.items(): - local_result[name] = cfg.to_dict() - result[key] = local_result - continue - if not isinstance(val, list): - result[key] = ( - getattr(val, "tensor_type", val) - if isinstance(val, quantization.QuantType) - else getattr(val, "value", val) - ) - else: - result[key] = [ - ( - getattr(item, "tensor_type", item) - if isinstance(item, quantization.QuantType) - else getattr(item, "value", item) - ) - for item in val - ] - return result - ##################### NC Algo Configs End ################################### diff --git a/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py b/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py index 41c58a29f..78f2cb88a 100644 --- a/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py +++ b/onnx_neural_compressor/quantization/matmul_4bits_quantizer.py @@ -31,6 +31,7 @@ def __init__( model: Union[onnx.ModelProto, str], block_size: int = 128, is_symmetric: bool = False, + is_signed: bool = False, accuracy_level: int = 0, nodes_to_exclude=None, algo_config: matmul_nbits_quantizer.WeightOnlyQuantConfig = None, @@ -41,6 +42,7 @@ def __init__( model=model, block_size=block_size, is_symmetric=is_symmetric, + is_signed=is_signed, accuracy_level=accuracy_level, nodes_to_exclude=nodes_to_exclude, algo_config=algo_config, diff --git a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py index 99bf760e9..11822c0fc 100644 --- a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py +++ b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py @@ -21,12 +21,13 @@ import onnxruntime as ort from onnx_neural_compressor import data_reader, logger, onnx_model, utility +from onnx_neural_compressor.quantization import QuantFormat from onnx_neural_compressor.quantization import algorithm_entry as algos from onnx_neural_compressor.quantization import config class WeightOnlyQuantConfig: - def __init__(self, algorithm): + def __init__(self, algorithm, quant_format=QuantFormat.QOperator): """This is the Base class for Weight Only Quant Configuration. Args: @@ -34,13 +35,15 @@ def __init__(self, algorithm): weight only quantize algorithm name. 
""" self.algorithm = algorithm + self.quant_format = quant_format class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): - def __init__(self, ratios=None, layer_wise_quant=False): + def __init__(self, ratios=None, layer_wise_quant=False, quant_format=QuantFormat.QOperator): super().__init__( algorithm="RTN", + quant_format=quant_format, ) if ratios is None: ratios = {} @@ -59,9 +62,11 @@ def __init__( mse=False, perchannel=True, layer_wise_quant=False, + quant_format=QuantFormat.QOperator, ): super().__init__( algorithm="GPTQ", + quant_format=quant_format, ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp @@ -79,8 +84,9 @@ def __init__( calibration_data_reader: data_reader.CalibrationDataReader, enable_auto_scale=True, enable_mse_search=True, + quant_format=QuantFormat.QOperator, ): - super().__init__(algorithm="AWQ") + super().__init__(algorithm="AWQ", quant_format=quant_format) self.calibration_data_reader = calibration_data_reader self.enable_auto_scale = enable_auto_scale self.enable_mse_search = enable_mse_search @@ -100,6 +106,7 @@ def __init__( model: Union[onnx.ModelProto, str], block_size: int = 128, is_symmetric: bool = False, + is_signed: bool = False, accuracy_level: int = 0, nodes_to_exclude: List[str] = None, algo_config: WeightOnlyQuantConfig = None, @@ -112,6 +119,7 @@ def __init__( self.model = model self.block_size = block_size self.is_symmetric = is_symmetric + self.is_signed = is_signed self.accuracy_level = accuracy_level self.nodes_to_exclude = list(set(nodes_to_exclude)) self.algo_config = algo_config or RTNWeightOnlyQuantConfig() @@ -128,11 +136,14 @@ def __init__( def _generate_nc_config(self): config_class = config.config_registry.get_cls_configs()[self.algorithm.lower()] quant_kwargs = { + "weight_dtype": "int" if self.is_signed else "uint", "weight_bits": self.n_bits, "weight_group_size": self.block_size, "weight_sym": self.is_symmetric, "accuracy_level": self.accuracy_level, "providers": self.providers, + "quant_format": self.algo_config.quant_format, + "nodes_to_exclude": self.nodes_to_exclude, } if self.algorithm == "RTN": quant_kwargs.update( @@ -160,10 +171,6 @@ def _generate_nc_config(self): ) nc_config = config_class(**quant_kwargs) - if len(self.nodes_to_exclude) > 0: - not_quant_kwargs = {"weight_dtype": "fp32", "white_list": self.nodes_to_exclude} - nc_config += config_class(**not_quant_kwargs) - return nc_config def int4_quant_algo(self): diff --git a/onnx_neural_compressor/quantization/quant_utils.py b/onnx_neural_compressor/quantization/quant_utils.py index 2d5518857..348fa8cdb 100644 --- a/onnx_neural_compressor/quantization/quant_utils.py +++ b/onnx_neural_compressor/quantization/quant_utils.py @@ -25,6 +25,8 @@ class QuantType(enum.Enum): # pragma: no cover QInt8 = 0 QUInt8 = 1 + QInt4 = 4 + QUInt4 = 5 @property def tensor_type(self): @@ -32,6 +34,10 @@ def tensor_type(self): return onnx.TensorProto.INT8 if self == QuantType.QUInt8: return onnx.TensorProto.UINT8 + if self == QuantType.QInt8: + return onnx.TensorProto.INT4 + if self == QuantType.QUInt4: + return onnx.TensorProto.UINT4 raise ValueError(f"Unexpected value qtype={self!r}.") diff --git a/onnx_neural_compressor/quantization/tuning.py b/onnx_neural_compressor/quantization/tuning.py index 385ac63c0..862c2f40f 100644 --- a/onnx_neural_compressor/quantization/tuning.py +++ b/onnx_neural_compressor/quantization/tuning.py @@ -529,8 +529,7 @@ def autotune( tuning_logger.tuning_start() for trial_index, quant_config in enumerate(config_loader): # check whether 
config_mapping is verified - model_info = quant_config.__class__.get_model_info(model=model_input) - config_mapping = quant_config.to_config_mapping(model_info=model_info) + config_mapping = quant_config.to_config_mapping(model=model_input) if tuning_monitor.need_skip(config_mapping): continue diff --git a/test/quantization/post_training_quant/test_operators.py b/test/quantization/post_training_quant/test_operators.py index 45c189328..c06759f3c 100644 --- a/test/quantization/post_training_quant/test_operators.py +++ b/test/quantization/post_training_quant/test_operators.py @@ -78,6 +78,9 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): shutil.rmtree("./onnxrt_test", ignore_errors=True) + os.remove("int8.onnx") + os.remove("qdq.onnx") + os.remove("test.onnx") def qlinear_test(self, model, q_config, quantize_params, quantizable_op_types, **kwargs): quant = quantizer.StaticQuantizer( diff --git a/test/quantization/test_algorithm_utility.py b/test/quantization/test_algorithm_utility.py index 4301545c7..5c899c828 100644 --- a/test/quantization/test_algorithm_utility.py +++ b/test/quantization/test_algorithm_utility.py @@ -6,6 +6,7 @@ import numpy as np import onnx +from onnx_neural_compressor import onnx_model from onnx_neural_compressor.algorithms import utility as quant_utils @@ -40,3 +41,49 @@ def test_is_B_transposed(self): beta=0.35, ) self.assertFalse(quant_utils.is_B_transposed(node)) + + def test_make_woq_dq_node(self): + node = onnx.helper.make_node("MatMul", ["input", "weight"], "output", name="Matmul") + with self.assertRaises(ValueError): + quant_utils.make_weight_only_dequant_node( + node=node, + weight_shape=(32, 32), + block_size=16, + num_bits=32, + dtype="int", + q_weight=np.random.randint(0, 10, size=(2, 32), dtype=np.uint8), + scale=np.random.random((2, 32)), + zero_point=np.zeros((2, 32)), + ) + + def test_split_shared_bias(self): + input = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 15, 15]) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 5, 11, 11]) + bias_initializer = onnx.numpy_helper.from_array(np.random.random(5).astype(np.float32), name="bias") + conv1_weight_initializer = onnx.numpy_helper.from_array( + np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), name="conv1_weight" + ) + conv1_node = onnx.helper.make_node("Conv", ["add_out", "conv1_weight", "bias"], ["conv1_output"], name="conv1") + conv2_weight_initializer = onnx.numpy_helper.from_array( + np.random.randint(-1, 2, [5, 5, 3, 3]).astype(np.float32), name="conv2_weight" + ) + conv2_node = onnx.helper.make_node("Conv", ["add_out", "conv2_weight", "bias"], ["conv2_output"], name="conv2") + initializers = [conv1_weight_initializer, conv2_weight_initializer, bias_initializer] + graph = onnx.helper.make_graph([conv1_node, conv2_node], "test", [input], [output], initializer=initializers) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 13)]) + + update_model = quant_utils.split_shared_bias(onnx_model.ONNXModel(model)) + split = any(["_nc_split_" in i.name for i in update_model.initializer()]) + self.assertTrue(split) + + def test_get_qmin_qmax_for_qType(self): + with self.assertRaises(ValueError): + quant_utils.get_qmin_qmax_for_qType(onnx.TensorProto.INT64) + + qmin, qmax = quant_utils.get_qmin_qmax_for_qType(onnx.TensorProto.INT8, reduce_range=True) + self.assertEqual(qmin, -64) + self.assertEqual(qmax, 64) + + +if __name__ == "__main__": + unittest.main() diff --git 
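On the ORT-style MatMulNBitsQuantizer wrapper shown above, the format choice rides on the algorithm config object, and is_signed / nodes_to_exclude now flow straight into the generated config rather than being expressed as an extra fp32 config. A sketch assuming a hypothetical opset-21 model file:

import onnx

from onnx_neural_compressor.quantization import QuantFormat, matmul_nbits_quantizer

model = onnx.load("model-opset21.onnx")  # hypothetical path; QDQ needs opset >= 21

algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(quant_format=QuantFormat.QDQ)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
    model,
    n_bits=4,
    block_size=32,
    is_symmetric=False,
    is_signed=True,                        # becomes weight_dtype="int" in _generate_nc_config
    nodes_to_exclude=["/lm_head/MatMul"],  # forwarded directly to the neural-compressor config
    algo_config=algo_config,
)
quant.process()
qdq_model = quant.model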
a/test/quantization/test_config.py b/test/quantization/test_config.py index 81ccd245d..7b33f9b9d 100644 --- a/test/quantization/test_config.py +++ b/test/quantization/test_config.py @@ -100,8 +100,7 @@ def test_dynamic_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx == 0: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx == 1: @@ -125,8 +124,7 @@ def test_dynamic_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) self.assertTrue("add" not in configs_mapping) self.assertTrue("add2" not in configs_mapping) self.assertTrue("Matmul" not in configs_mapping) @@ -143,8 +141,7 @@ def test_dynamic_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx == 0: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx == 1: @@ -161,8 +158,7 @@ def test_dynamic_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) self.assertTrue("add" not in configs_mapping) self.assertTrue("add2" not in configs_mapping) self.assertTrue("Matmul" not in configs_mapping) @@ -179,8 +175,7 @@ def test_static_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx in [0, 4]: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx in [1, 5]: @@ -202,8 +197,7 @@ def test_static_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) self.assertTrue("add" not in configs_mapping) self.assertTrue("add2" not in configs_mapping) self.assertTrue("Matmul" not in configs_mapping) @@ -218,8 
+212,7 @@ def test_static_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if "Matmul" in configs_mapping: self.assertFalse(configs_mapping["Matmul"]["per_channel"]) self.assertEqual(configs_mapping["Matmul"]["calibrate_method"], "MinMax") @@ -236,8 +229,7 @@ def test_static_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx in [0, 4]: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx in [1, 5]: @@ -262,8 +254,7 @@ def test_static_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx == 0: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx == 1: @@ -281,8 +272,7 @@ def test_static_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) self.assertTrue("add" not in configs_mapping) self.assertTrue("add2" not in configs_mapping) self.assertTrue("Matmul" not in configs_mapping) @@ -299,8 +289,7 @@ def test_static_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) self.assertFalse(configs_mapping["Matmul"]["per_channel"]) self.assertEqual(configs_mapping["add"]["calibrate_method"], "MinMax") self.assertLess(idx, 4) @@ -315,8 +304,7 @@ def test_static_custom_quant_config(self): ) config_loader = tuning.ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) for idx, quant_config in enumerate(config_loader): - model_info = quant_config.get_model_info(model=self.simple_onnx_model) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=self.simple_onnx_model) if idx == 0: self.assertTrue(configs_mapping["Matmul"]["per_channel"]) elif idx == 1: @@ -329,28 +317,6 @@ def test_static_custom_quant_config(self): self.assertLess(idx, 2) def test_config_white_lst(self): - global_config = 
config.RTNConfig(weight_bits=4) - # set operator instance - fc_out_config = config.RTNConfig(weight_dtype="fp32", white_list=["/h.4/mlp/fc_out/MatMul"]) - # get model and quantize - fp32_model = self.gptj - qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config) - self.assertIsNotNone(qmodel) - self.assertEqual(self._count_woq_matmul(qmodel), 29) - self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) - - def test_config_white_lst2(self): - global_config = config.RTNConfig(weight_dtype="fp32") - # set operator instance - fc_out_config = config.RTNConfig(weight_bits=4, white_list=["/h.4/mlp/fc_out/MatMul"]) - # get model and quantize - fp32_model = self.gptj - qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config) - self.assertIsNotNone(qmodel) - self.assertEqual(self._count_woq_matmul(qmodel), 1) - self.assertTrue(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) - - def test_config_white_lst3(self): global_config = config.RTNConfig(weight_bits=4) # set operator instance @@ -360,7 +326,7 @@ def test_config_white_lst3(self): fp32_model = self.gptj model_info = config.RTNConfig.get_model_info(fp32_model) logger.info(quant_config) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=fp32_model) logger.info(configs_mapping) self.assertTrue(configs_mapping["/h.4/mlp/fc_out/MatMul"]["weight_bits"] == 8) self.assertTrue(configs_mapping["/h.4/mlp/fc_in/MatMul"]["weight_bits"] == 4) @@ -433,14 +399,14 @@ def test_config_mapping(self): fp32_model = self.gptj model_info = config.RTNConfig.get_model_info(fp32_model) logger.info(quant_config) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=fp32_model) logger.info(configs_mapping) self.assertTrue(configs_mapping["/h.4/mlp/fc_out/MatMul"]["weight_bits"] == 8) self.assertTrue(configs_mapping["/h.4/mlp/fc_in/MatMul"]["weight_bits"] == 4) # test regular matching fc_config = config.RTNConfig(weight_bits=3) quant_config.set_local("/h.[1-4]/mlp/fc_out/MatMul", fc_config) - configs_mapping = quant_config.to_config_mapping(model_info=model_info) + configs_mapping = quant_config.to_config_mapping(model=fp32_model) logger.info(configs_mapping) self.assertTrue(configs_mapping["/h.4/mlp/fc_out/MatMul"]["weight_bits"] == 3) self.assertTrue(configs_mapping["/h.3/mlp/fc_out/MatMul"]["weight_bits"] == 3) diff --git a/test/quantization/test_smooth_quant.py b/test/quantization/test_smooth_quant.py index 9ad53c148..b27b7f7f0 100644 --- a/test/quantization/test_smooth_quant.py +++ b/test/quantization/test_smooth_quant.py @@ -74,9 +74,10 @@ def tearDownClass(self): os.remove("Optimized_model.onnx") def test_sq_config(self): + model = onnx.load(self.gptj) sq_config = config.SmoothQuantConfig() - model_info = sq_config.get_model_info(model=onnx.load(self.gptj)) - self.assertEqual(len(model_info), 40) + model_info = sq_config.get_model_info(model=model) + self.assertEqual(len(model_info), len(model.graph.node)) def test_sq_from_class_beginner(self): self.data_reader.rewind() diff --git a/test/quantization/weight_only/test_awq.py b/test/quantization/weight_only/test_awq.py index b7def741b..ec2d9d1ea 100644 --- a/test/quantization/weight_only/test_awq.py +++ b/test/quantization/weight_only/test_awq.py @@ -10,8 +10,10 @@ import torch import transformers from optimum.exporters.onnx import main_export +from 
packaging import version from onnx_neural_compressor import data_reader, logger +from onnx_neural_compressor.quantization import QuantFormat from onnx_neural_compressor.quantization import algorithm_entry as algos from onnx_neural_compressor.quantization import config, matmul_4bits_quantizer, matmul_nbits_quantizer @@ -153,6 +155,7 @@ def test_awq_params_combination(self): "weight_sym": [True, False], "act_dtype": ["fp32"], "accuracy_level": [0], + "quant_format": [0, 1], "enable_auto_scale": [True, False], "enable_mse_search": [True, False], } @@ -191,13 +194,12 @@ def test_quantize_awq_from_class_beginner(self): def test_quantize_awq_fallback(self): - fp32_config = config.AWQConfig(weight_dtype="fp32") quant_config = config.AWQConfig( weight_dtype="int", weight_sym=False, weight_group_size=32, + nodes_to_exclude=["/h.4/mlp/fc_out/MatMul"], ) - quant_config.set_local("/h.4/mlp/fc_out/MatMul", fp32_config) qmodel = self._apply_awq(quant_config) self.assertIsNotNone(qmodel) self.assertEqual(self._count_woq_matmul(qmodel), 29) @@ -215,6 +217,28 @@ def test_quantize_awq_fallback(self): self.assertEqual(self._count_woq_matmul(qmodel), 29) self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_awq_with_QDQ_format(self): + + quant_config = config.AWQConfig( + weight_dtype="int", + weight_sym=False, + weight_group_size=32, + weight_bits=4, + quant_format=QuantFormat.QDQ, + ) + + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + qmodel = algos.awq_quantize_entry(op21_model, quant_config, calibration_data_reader=self.matmul_data_reader) + + self.assertIsNotNone(qmodel) + self.assertTrue("MatMul" in [i.op_type for i in qmodel.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in qmodel.graph.node]) + class TestAWQQuantWithORTLikeAPI(TestAWQQuant): @@ -326,6 +350,32 @@ def test_awq_with_specified_matmul(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_awq_with_QDQ_format(self): + + algo_config = matmul_nbits_quantizer.AWQWeightOnlyQuantConfig( + calibration_data_reader=self.matmul_data_reader, quant_format=QuantFormat.QDQ + ) + + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + op21_model, + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + self.assertTrue("MatMul" in [i.op_type for i in quant.model.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in quant.model.graph.node]) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/weight_only/test_gptq.py b/test/quantization/weight_only/test_gptq.py index 7902371c7..789d5f684 100644 --- a/test/quantization/weight_only/test_gptq.py +++ b/test/quantization/weight_only/test_gptq.py @@ -10,8 +10,10 @@ import torch import transformers from optimum.exporters.onnx import main_export +from packaging import version from onnx_neural_compressor import data_reader, logger +from onnx_neural_compressor.quantization 
import QuantFormat from onnx_neural_compressor.quantization import algorithm_entry as algos from onnx_neural_compressor.quantization import config, matmul_4bits_quantizer, matmul_nbits_quantizer @@ -151,6 +153,7 @@ def test_gptq_params_combination(self): "weight_sym": [True, False], "act_dtype": ["fp32"], "accuracy_level": [0], + "quant_format": [0, 1], "percdamp": [0.01], "blocksize": [128], "actorder": [True, False], @@ -188,14 +191,13 @@ def test_quantize_gptq_from_class_beginner(self): self.assertIsNotNone(qmodel) def test_quantize_gptq_fallback(self): - fp32_config = config.GPTQConfig(weight_dtype="fp32") quant_config = config.GPTQConfig( weight_bits=4, weight_dtype="int", weight_sym=False, weight_group_size=32, + nodes_to_exclude=["/h.4/mlp/fc_out/MatMul"], ) - quant_config.set_local("/h.4/mlp/fc_out/MatMul", fp32_config) qmodel = self._apply_gptq(quant_config) self.assertIsNotNone(qmodel) self.assertEqual(self._count_woq_matmul(qmodel), 29) @@ -214,6 +216,26 @@ def test_quantize_gptq_fallback(self): self.assertEqual(self._count_woq_matmul(qmodel), 29) self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_gptq_with_QDQ_format(self): + quant_config = config.GPTQConfig( + weight_bits=4, + weight_dtype="int", + weight_sym=False, + weight_group_size=32, + quant_format=QuantFormat.QDQ, + ) + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + qmodel = algos.gptq_quantize_entry(op21_model, quant_config, calibration_data_reader=self.matmul_data_reader) + + self.assertIsNotNone(qmodel) + self.assertTrue("MatMul" in [i.op_type for i in qmodel.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in qmodel.graph.node]) + class TestGPTQQuantWithORTLikeAPI(TestGPTQQuant): @@ -323,6 +345,32 @@ def test_gptq_with_specified_matmul(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_gptq_with_QDQ_format(self): + + algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig( + calibration_data_reader=self.matmul_data_reader, quant_format=QuantFormat.QDQ + ) + + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + op21_model, + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + self.assertTrue("MatMul" in [i.op_type for i in quant.model.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in quant.model.graph.node]) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/weight_only/test_rtn.py b/test/quantization/weight_only/test_rtn.py index 6f7cea1b8..d294342db 100644 --- a/test/quantization/weight_only/test_rtn.py +++ b/test/quantization/weight_only/test_rtn.py @@ -8,8 +8,10 @@ import onnx import onnxruntime as ort from optimum.exporters.onnx import main_export +from packaging import version from onnx_neural_compressor import logger +from onnx_neural_compressor.quantization import QuantFormat from onnx_neural_compressor.quantization import algorithm_entry as 
algos from onnx_neural_compressor.quantization import config, matmul_4bits_quantizer, matmul_nbits_quantizer @@ -138,14 +140,13 @@ def test_quantize_rtn_from_class_beginner(self): def test_quantize_rtn_fallback(self): - fp32_config = config.RTNConfig(weight_dtype="fp32") quant_config = config.RTNConfig( weight_bits=4, weight_dtype="int", weight_sym=False, weight_group_size=32, + nodes_to_exclude=["/h.4/mlp/fc_out/MatMul"], ) - quant_config.set_local("/h.4/mlp/fc_out/MatMul", fp32_config) qmodel = self._apply_rtn(quant_config) self.assertIsNotNone(qmodel) self.assertEqual(self._count_woq_matmul(qmodel), 29) @@ -164,6 +165,23 @@ def test_quantize_rtn_fallback(self): self.assertEqual(self._count_woq_matmul(qmodel), 29) self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_rtn_with_QDQ_format(self): + + quant_config = config.RTNConfig( + weight_bits=4, weight_dtype="int", weight_sym=False, weight_group_size=32, quant_format=QuantFormat.QDQ + ) + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + qmodel = algos.rtn_quantize_entry(op21_model, quant_config) + + self.assertIsNotNone(qmodel) + self.assertTrue("MatMul" in [i.op_type for i in qmodel.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in qmodel.graph.node]) + class TestRTNQuantWithORTLikeAPI(TestRTNQuant): @@ -267,6 +285,29 @@ def test_rtn_with_specified_matmul(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + @unittest.skipIf( + version.Version(ort.__version__) < version.Version("1.19.0"), + "Please use onnxruntime >= 1.19.0 for QDQ format test", + ) + def test_rtn_with_QDQ_format(self): + + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(quant_format=QuantFormat.QDQ) + op21_model = copy.deepcopy(self.matmul_model) + op21_model.opset_import[0].version = 21 + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + op21_model, + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + self.assertTrue("MatMul" in [i.op_type for i in quant.model.graph.node]) + self.assertTrue("DequantizeLinear" in [i.op_type for i in quant.model.graph.node]) + if __name__ == "__main__": unittest.main() diff --git a/test/utils/test_general.py b/test/utils/test_general.py index b07d73115..ee42c8714 100644 --- a/test/utils/test_general.py +++ b/test/utils/test_general.py @@ -197,6 +197,12 @@ def test_api(self): DEFAULT_WEIGHT_BITS, ) + model = FakeModel() + fake_default_config.set_local("OP1_NAME", FakeAlgoConfig(weight_dtype="uint")) + config_mapping = fake_default_config.to_config_mapping(model) + self.assertEqual(config_mapping["OP1_NAME"]["weight_dtype"], "uint") + self.assertEqual(config_mapping["OP2_NAME"]["weight_dtype"], "int") + def test_config_expand_complex_tunable_type(self): target_op_type_list_options = [["Conv", "Gemm"], ["Conv", "Matmul"]] configs = FakeAlgoConfig(target_op_type_list=target_op_type_list_options) @@ -212,8 +218,7 @@ def test_mixed_two_algos(self): fake_config = FakeAlgoConfig(weight_bits=4, white_list=[OP1_NAME]) fake1_config = FakeAlgoOneConfig(weight_bits=2, white_list=[OP2_NAME]) mixed_config = fake_config + fake1_config - model_info = 
mixed_config.get_model_info(model) - config_mapping = mixed_config.to_config_mapping(model_info=model_info) + config_mapping = mixed_config.to_config_mapping(model=model) self.assertIn(OP1_NAME, config_mapping) self.assertIn(OP2_NAME, config_mapping) diff --git a/test/utils/test_onnx_model.py b/test/utils/test_onnx_model.py index f27f64e1f..999b0985b 100644 --- a/test/utils/test_onnx_model.py +++ b/test/utils/test_onnx_model.py @@ -88,6 +88,7 @@ def tearDownClass(self): shutil.rmtree("./gptj", ignore_errors=True) shutil.rmtree("./large_model", ignore_errors=True) os.remove("matmul_add.onnx") + os.remove("model1.onnx") def setUp(self): # print the test name @@ -102,7 +103,7 @@ def test_model_atrribute(self): # model_path self.assertEqual(model.model_path, self.matmul_add_model) # framework - self.assertEqual(model.framework(), "onnxruntime") + self.assertEqual(model.framework, "onnxruntime") # q_config quant_config = config.RTNConfig() model.q_config = quant_config
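The QDQ tests above all verify the output the same way: the quantized graph keeps plain MatMul nodes and contains DequantizeLinear nodes, rather than only fused QOperator nodes. The same check doubles as a quick sanity test on any model produced with quant_format=QDQ (the output path is hypothetical):

import onnx

qdq_model = onnx.load("model-woq-qdq.onnx")  # hypothetical output path
op_types = [node.op_type for node in qdq_model.graph.node]
assert "MatMul" in op_types
assert "DequantizeLinear" in op_types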