[Quant Tool] Update QDQ Pad, Slice, Softmax (#22676)
### Description
Updates the Python quantization tool:
- Ensures QDQ Pad has equal quantization parameters across input and
output for certain Pad configurations.
- Ensures QDQ Slice always has equal quantization parameters across
input and output.
- Fixes a bug that occurred when Softmax is _excluded_ from quantization.


### Motivation and Context
QDQ Pad and Slice have lower latency on the QNN EP when their input and output quantization parameters are equal.
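
For context, the resulting graph pattern looks like the minimal sketch below (illustrative only, not part of this commit; tensor names and values are hypothetical). The Pad's input DequantizeLinear and its output QuantizeLinear reference the same scale and zero-point initializers, which lets a backend execute the quantized Pad as pure data movement:

```python
# Sketch of a QDQ Pad whose input DQ and output Q share quantization parameters.
import numpy as np
import onnx

scale = onnx.numpy_helper.from_array(np.array(0.02, dtype=np.float32), "shared_scale")
zero_point = onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "shared_zp")
pads = onnx.numpy_helper.from_array(np.array([0, 2, 0, 0], dtype=np.int64), "pads")

nodes = [
    onnx.helper.make_node("DequantizeLinear", ["input_q", "shared_scale", "shared_zp"], ["input_f"]),
    onnx.helper.make_node("Pad", ["input_f", "pads"], ["padded_f"], mode="edge"),
    # The output Q reuses the *same* scale/zero-point initializers as the input DQ.
    onnx.helper.make_node("QuantizeLinear", ["padded_f", "shared_scale", "shared_zp"], ["output_q"]),
]
graph = onnx.helper.make_graph(
    nodes,
    "qdq_pad_sketch",
    [onnx.helper.make_tensor_value_info("input_q", onnx.TensorProto.UINT8, (3, 2))],
    [onnx.helper.make_tensor_value_info("output_q", onnx.TensorProto.UINT8, (3, 4))],
    initializer=[scale, zero_point, pads],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 21)])
onnx.checker.check_model(model)
```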
adrianlizarraga authored Nov 6, 2024
1 parent 0221693 commit aa0cf1c
Showing 7 changed files with 461 additions and 2 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/quantization/base_quantizer.py
@@ -554,4 +554,6 @@ def adjust_tensor_ranges(self):
                self.tensors_range[node.input[0]] = td
            # Adjust Softmax to range from 0.0 to 1.0
            elif node.op_type == "Softmax":
                if not self.should_quantize_node(node):
                    continue
                self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
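
With this guard, a Softmax that the user excluded no longer has its output range forcibly set to [0.0, 1.0]. A minimal sketch of the user-facing path, assuming a hypothetical float model `model.float.onnx` whose graph contains a Softmax node named `softmax_0` and a single input `input_0` of shape (1, 8):

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomDataReader(CalibrationDataReader):
    """Feeds a couple of random calibration batches (input name/shape are assumptions)."""

    def __init__(self):
        self._feeds = iter([{"input_0": np.random.rand(1, 8).astype(np.float32)} for _ in range(2)])

    def get_next(self):
        return next(self._feeds, None)

quantize_static(
    "model.float.onnx",  # hypothetical input model
    "model.qdq.onnx",
    RandomDataReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    nodes_to_exclude=["softmax_0"],  # the excluded Softmax previously triggered the bug fixed above
)
```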
72 changes: 72 additions & 0 deletions onnxruntime/python/tools/quantization/operators/pad.py
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

from typing import Any

import numpy as np
import onnx

from ..quant_utils import (
@@ -8,6 +17,7 @@
    quantize_nparray,
)
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase


class QPad(QuantOperatorBase):
@@ -98,3 +108,65 @@ def quantize(self):
        node.input[0] = quantized_input_value.q_name
        node.output[0] = quantized_output_value.q_name
        self.quantizer.new_nodes += [node]


class QDQPad(QDQOperatorBase):
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None:
        """
        Returns the Pad's constant padding value. Returns `None` if the padding value is
        not constant (i.e., comes from a dynamic input).
        """
        const_val = None
        onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0])
        if onnx_tensor_type is None:
            return None

        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type)
        if self.quantizer.opset_version < 11:
            const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype)
        elif len(self.node.input) >= 3 and self.node.input[2]:
            const_val = self.quantizer.model.get_constant_value(self.node.input[2])
        else:
            const_val = np.array(0, dtype=np_dtype)

        return const_val

    def _should_quantize_output_same_as_input(self) -> bool:
        """
        Returns true if Pad's output should use the same quantization parameters as input[0].
        """
        attrs_dict = {}
        for attribute in self.node.attribute:
            kv = attribute_to_kwarg(attribute)
            attrs_dict.update(kv)

        pad_mode = attrs_dict.get("mode", b"constant")
        if pad_mode in (b"reflect", b"edge", b"wrap"):
            # These modes pad the output with a value that already exists in the input.
            # So, we can quantize the output the same as the input.
            return True

        # For 'constant' mode, if padding with 0, we can also quantize the output the same as the input
        # because our quantization floating-point range always includes 0.
        if pad_mode == b"constant":
            pad_val = self._get_pad_const_val(attrs_dict)
            if pad_val is not None and pad_val.dtype in (np.float32, np.float16):
                return float(pad_val.item()) == 0

        return False

    def quantize(self):
        assert self.node.op_type == "Pad"

        for input_name in self.node.input:
            if input_name:
                self.quantizer.quantize_activation_tensor(input_name)

        if not self.disable_qdq_for_node_output:
            if self._should_quantize_output_same_as_input():
                self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name)
            else:
                self.quantizer.quantize_activation_tensor(self.node.output[0])
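
The `constant` mode branch above relies on a small arithmetic fact: under affine quantization, q = round(x / scale) + zero_point, so the float value 0.0 always maps exactly to the zero point, and the calibrated float range always includes 0.0. Padding with 0 therefore introduces no value outside the input's quantized range. A quick numeric sketch (example uint8 parameters, not taken from the commit):

```python
import numpy as np

scale, zero_point = np.float32(0.02), 128  # example uint8 quantization parameters

def quantize(x: float) -> int:
    """Affine quantization of a single float value to uint8."""
    return int(np.clip(round(x / scale) + zero_point, 0, 255))

assert quantize(0.0) == zero_point  # 0.0 is exactly representable, so qparams can be shared
```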
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/quantization/registry.py
@@ -14,7 +14,7 @@
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
from .operators.maxpool import QDQMaxPool, QMaxPool
from .operators.norm import QDQNormalization
from .operators.pad import QPad
from .operators.pad import QDQPad, QPad
from .operators.pooling import QLinearPool
from .operators.qdq_base_operator import QDQOperatorBase
from .operators.resize import QDQResize, QResize
@@ -76,6 +76,8 @@
"Resize": QDQResize,
"MaxPool": QDQMaxPool,
"AveragePool": QDQDirect8BitOp,
"Slice": QDQDirect8BitOp,
"Pad": QDQPad,
"MatMul": QDQMatMul,
"Split": QDQSplit,
"Gather": QDQGather,
33 changes: 33 additions & 0 deletions onnxruntime/test/python/quantization/op_test_utils.py
@@ -1,3 +1,10 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import uuid
from pathlib import Path

@@ -661,3 +668,29 @@ def generate_random_initializer(initializer_name, tensor_shape, tensor_dtype, mean, dev):
    tensor = np.random.normal(mean, dev, tensor_shape).astype(tensor_dtype)
    init = onnx.numpy_helper.from_array(tensor, initializer_name)
    return init


def get_tensor_consumers_and_producers(
    model: onnx.ModelProto,
) -> tuple[dict[str, list[onnx.NodeProto]], dict[str, onnx.NodeProto]]:
    """
    Returns a tuple containing the following python dictionaries:
    - consumers: maps a tensor name to the list of nodes that have that tensor as an input.
    - producers: maps a tensor name to the node that generates this tensor as an output.
    """
    consumers: dict[str, list[onnx.NodeProto]] = {}
    producers: dict[str, onnx.NodeProto] = {}
    for node in model.graph.node:
        # Iterate through node's inputs to build the consumers dictionary.
        for input_name in node.input:
            if input_name:
                if input_name not in consumers:
                    consumers[input_name] = []

                consumers[input_name].append(node)

        # Iterate through node's outputs to build the producers dictionary.
        for output_name in node.output:
            producers[output_name] = node

    return (consumers, producers)
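
A short usage example for the helper above (the model path is hypothetical; the new Pad test below uses it the same way):

```python
import onnx

model = onnx.load_model("model.qdq.onnx")  # hypothetical QDQ model
consumers, producers = get_tensor_consumers_and_producers(model)

pad_node = next(node for node in model.graph.node if node.op_type == "Pad")
input_dq = producers.get(pad_node.input[0])        # e.g., a DequantizeLinear node
output_qs = consumers.get(pad_node.output[0], [])  # e.g., [QuantizeLinear]
```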
165 changes: 164 additions & 1 deletion onnxruntime/test/python/quantization/test_op_pad.py
@@ -4,14 +4,23 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import itertools
import os
import tempfile
import unittest

import numpy as np
import onnx
from onnx import TensorProto, helper
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
from op_test_utils import (
    TestDataFeeds,
    check_model_correctness,
    check_op_type_count,
    check_qtype_by_node_type,
    get_tensor_consumers_and_producers,
)

from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static

@@ -519,5 +528,159 @@ def test_pad_with_empty_string_input_name(self):
            self.assertNotEqual(name, "_quantized")


class TestQDQPad(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_")

        # Note: swap with the commented line if you want to see the models in local test dir.
        cls._tmp_dir_path = cls._tmp_model_dir.name
        # cls._tmp_dir_path = "."

    @classmethod
    def tearDownClass(cls):
        cls._tmp_model_dir.cleanup()

    def build_pad_model(
        self,
        mode: str,
        constant_value: float | None = None,
        opset: int = 21,
        float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
    ) -> onnx.ModelProto:
        input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2))
        output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 4))

        initializers = []
        pad_input_names = ["input_0"]
        attrs = {"mode": mode}

        pads_data = np.array([0, 2, 0, 0], dtype=np.int64)  # Pad two vals at beginning of axis 1.
        if opset >= 11:
            initializers.append(onnx.numpy_helper.from_array(pads_data, "pads"))
            pad_input_names.append("pads")
        else:
            attrs["pads"] = pads_data.tolist()

        if mode == "constant" and constant_value is not None:
            if opset >= 11:
                initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value]))
                pad_input_names.append("constant_value")
            else:
                attrs["value"] = float(constant_value)

        pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs)

        graph = onnx.helper.make_graph(
            [pad_node],
            "PadFloat",
            [input_0],
            [output_0],
            initializer=initializers,
        )
        opset_imports = [onnx.helper.make_opsetid("", opset)]
        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
        model = onnx.shape_inference.infer_shapes(model)
        onnx.checker.check_model(model, True)
        return model

    def test_qdq_pad_qparams(self):
        """
        Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations.
        """
        test_configs = [
            # Opset 21
            ("constant", None, 21, onnx.TensorProto.FLOAT),
            ("constant", None, 21, onnx.TensorProto.FLOAT16),
            ("constant", 0, 21, onnx.TensorProto.FLOAT),
            ("constant", 0, 21, onnx.TensorProto.FLOAT16),
            ("constant", 10.0, 21, onnx.TensorProto.FLOAT),
            ("constant", 10.0, 21, onnx.TensorProto.FLOAT16),
            ("reflect", None, 21, onnx.TensorProto.FLOAT),
            ("reflect", None, 21, onnx.TensorProto.FLOAT16),
            ("edge", None, 21, onnx.TensorProto.FLOAT),
            ("edge", None, 21, onnx.TensorProto.FLOAT16),
            ("wrap", None, 21, onnx.TensorProto.FLOAT),
            ("wrap", None, 21, onnx.TensorProto.FLOAT16),
            # Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs.
            # Opset 10 Q/DQ ops don't support float16.
            ("constant", None, 10, onnx.TensorProto.FLOAT),
            ("constant", 0, 10, onnx.TensorProto.FLOAT),
            ("constant", 10.0, 10, onnx.TensorProto.FLOAT),
            ("reflect", None, 10, onnx.TensorProto.FLOAT),
            ("edge", None, 10, onnx.TensorProto.FLOAT),
        ]

        for pad_mode, constant_value, opset, float_type in test_configs:
            with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type):
                label = f"_{pad_mode}_{constant_value}_opset{opset}_{onnx.TensorProto.DataType.Name(float_type)}"
                float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx")
                qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx")

                float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type)
                onnx.save_model(float_model, float_model_path)

                # Create a data reader
                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
                input_data_list = [
                    {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)},
                    {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)},
                ]
                data_reader = TestDataFeeds(input_data_list)

                # quantize model to QDQ
                quantize_static(
                    float_model_path,
                    qdq_model_path,
                    data_reader,
                    quant_format=QuantFormat.QDQ,
                    activation_type=QuantType.QUInt8,
                    weight_type=QuantType.QInt8,
                )

                expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1}
                if constant_value is not None and opset >= 11:
                    expected_op_counts["DequantizeLinear"] += 1  # The constant padding value is quantized.
                check_op_type_count(self, qdq_model_path, **expected_op_counts)

                if pad_mode != "reflect":
                    # Do not check model correctness for 'reflect' mode because ONNX Runtime implementation does
                    # not match the ONNX reference implementation. See the following issue:
                    # https://github.com/microsoft/onnxruntime/issues/20801
                    data_reader.rewind()
                    check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next())

                qdq_model = onnx.load_model(qdq_model_path)
                quant_output_same_as_input = False

                if pad_mode in ("reflect", "edge", "wrap"):
                    quant_output_same_as_input = True

                if pad_mode == "constant" and constant_value in (None, 0):
                    quant_output_same_as_input = True

                pad_node = next((node for node in qdq_model.graph.node if node.op_type == "Pad"), None)
                self.assertNotEqual(pad_node, None)
                self.assertEqual(pad_node.op_type, "Pad")

                # Get the parent and child nodes of the Pad and check that they are DQ/Q.
                consumers, producers = get_tensor_consumers_and_producers(qdq_model)
                input_dq_node = producers.get(pad_node.input[0], None)
                self.assertNotEqual(input_dq_node, None)
                self.assertEqual(input_dq_node.op_type, "DequantizeLinear")

                output_q_node = consumers.get(pad_node.output[0], [None])[0]
                self.assertNotEqual(output_q_node, None)
                self.assertEqual(output_q_node.op_type, "QuantizeLinear")

                # Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q.
                if quant_output_same_as_input:
                    self.assertEqual(input_dq_node.input[1], output_q_node.input[1])  # Same scale
                    self.assertEqual(input_dq_node.input[2], output_q_node.input[2])  # Same zero-point
                else:
                    self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1])
                    self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2])


if __name__ == "__main__":
    unittest.main()