Skip to content

Commit

Permalink
Remove empty (DQ -> Q -> graph output) sequence in TransposeOptimizer (
Browse files Browse the repository at this point in the history
…#22172)

### Description
Updates the TransposeOptimizer to also remove empty (DQ -> Q) sequences
that occur at a graph output. An empty DQ->Q sequence results from a
Transpose being optimized out.

Consider the following example model:

![image](https://github.com/user-attachments/assets/4e7bc4eb-ea8a-463b-9672-c4ec5ef779b2)

The TransposeOptimizer removes the final Transpose and leaves an empty
DQ->Q->output_0 sequence. This PR ensures that the final DQ->Q is also
removed.

### Motivation and Context
Models with quantized output can run on QNN EP. The inference latency of
a customer model is impacted by the unnecessary DQ->Q sequence at the
output.

---------

Co-authored-by: Scott McKay <[email protected]>
  • Loading branch information
adrianlizarraga and skottmckay authored Sep 25, 2024
1 parent ee6a915 commit a47254e
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2749,7 +2749,9 @@ static bool CanModifyNode(const OptimizerCtx& ctx, const api::NodeRef& node) {

/// <summary>
/// Try to remove empty DQ -> Q pair that results from moving a Transpose downstream or a Transpose being canceled out.
/// (DQ -> Q -> consumer node) => consumer node
/// Handles the following scenarios:
/// - (DQ -> Q -> consumer node) => consumer node
/// - (parent node -> DQ -> Q -> graph output) => parent node -> graph output
/// </summary>
/// <param name="ctx">Optimizer context</param>
/// <param name="q_node">QuantizeLinear node</param>
Expand All @@ -2764,12 +2766,27 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) {
}

auto& dq_node = *input_node;
std::unique_ptr<api::NodeRef> single_consumer_node;

// remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp.
if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) &&
OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) &&
CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
// DQ should have a single consumer (the Q)
std::unique_ptr<api::NodeRef> dq_consumer_node;
if (!OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, dq_consumer_node)) {
return false;
}

// The DQ and Q should have matching types, scale and zp.
if (!CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) {
return false;
}

std::string_view q_output = q_node.Outputs()[0];
auto q_consumers = ctx.graph.GetValueConsumers(q_output);
const size_t num_q_consumers = q_consumers->nodes.size();
const bool q_has_single_consumer = q_consumers->comprehensive && (num_q_consumers == 1);

// (DQ -> Q -> consumer node) => consumer node
if (q_has_single_consumer) {
std::unique_ptr<api::NodeRef> single_consumer_node = std::move(q_consumers->nodes[0]);

// connect Q consumer to DQ input
for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) {
if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) {
Expand All @@ -2787,6 +2804,40 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) {
return true;
}

// (parent node -> DQ -> Q -> graph output) => (parent node -> graph output)
if (num_q_consumers == 0 && ctx.graph.IsGraphOutput(q_output)) {
// Get the DQ's parent node.
std::string_view dq_input = dq_node.Inputs()[0];
auto dq_parent_node = ctx.graph.GetNodeProducingOutput(dq_input);
if (!dq_parent_node) {
return false; // Don't handle DQ that consumes a graph input.
}

// Find index of output from DQ's parent node
auto dq_parent_outputs = dq_parent_node->Outputs();
size_t dq_parent_output_index = 0;
for (dq_parent_output_index = 0; dq_parent_output_index < dq_parent_outputs.size(); ++dq_parent_output_index) {
if (dq_parent_outputs[dq_parent_output_index] == dq_input) break;
}

// The DQ's parent should only have a single consumer (i.e., the DQ itself).
std::unique_ptr<api::NodeRef> dq_parent_consumer;
if (!OutputValueHasSingleConsumerNode(ctx.graph, *dq_parent_node, dq_parent_output_index, dq_parent_consumer)) {
return false;
}

// Move Q's output to come out of DQ's parent node so the graph output value name is maintained.
dq_node.SetInput(0, ""); // Disconnect DQ from its parent first.
ctx.graph.MoveOutput(q_node, 0, *dq_parent_node, dq_parent_output_index);

// Disconnect Q and remove both DQ and Q from the graph.
q_node.SetInput(0, "");
ctx.graph.RemoveNode(dq_node);
ctx.graph.RemoveNode(q_node);

return true;
}

return false;
}

Expand Down
84 changes: 84 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4964,6 +4964,90 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerAxisDQUnsqueezeTranspose) {
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}

// Test that the TransposeOptimizer's qdq-fixup pass converts the sequence (Op -> DQ -> Q -> GRAPH_OUTPUT) to
// (Op -> GRAPH_OUTPUT).
TEST(TransposeOptimizerTests, RemoveEmptyDQQAtGraphOutput) {
auto model_uri = ORT_TSTR("testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx");

RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 3, 4, 4};
std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);

auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
OrtValue input0;
CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);

NameMLValMap feeds{{"input0", input0}};

std::vector<std::string> output_names{"output0"};
std::vector<OrtValue> fetches_orig;
std::vector<OrtValue> fetches;

SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
so.graph_optimization_level = TransformerLevel::Default; // off

// get results with no modifications to the model
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}

{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));

Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;

namespace alias_oto = onnx_transpose_optimization;
auto api_graph = MakeApiGraph(graph,
TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);

alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());
ASSERT_STATUS_OK(graph.Resolve());

// Use this hack to save model for viewing if needed
// ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
// ToPathString("updated_model_empty_dqq_graph_output.onnx")));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Transpose"], 0) << "2 pre-existing Transposes at the I/O cancel. ";

// Check that the graph ends in the sequence (Mul -> Q -> GRAPH_OUTPUT)
Node* mul_node = nullptr;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Mul") {
mul_node = &node;
break;
}
}

// Mul should be followed by a Q node.
ASSERT_TRUE(mul_node != nullptr);
const auto& last_q_node = *(mul_node->OutputNodesBegin());
EXPECT_EQ(last_q_node.OpType(), "QuantizeLinear");

// The Q node should generate the graph's output.
const std::string& q_out_name = last_q_node.OutputDefs()[0]->Name();
const std::string& graph_out_name = graph.GetOutputs()[0]->Name();
EXPECT_EQ(q_out_name, graph_out_name);

// Run optimized model.
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}

ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<uint8_t>(),
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<uint8_t>()));
}

// Tests the in-place unsqueeze and transpose of a constant consumed by a per-axis DQ.
TEST(TransposeOptimizerTests, InPlaceUnsqueezeTransposePerAxisDQ) {
// Model contains a Mul with a constant/broadcastable/per-axis DQ input[1].
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import onnx


def make_model(model_path: str):
"""
Creates a QDQ model with a (DQ -> Transpose -> Q -> GRAPH OUTPUT) sequence. The Transpose is optimized out
and the TransposeOptimizer should also remove the empty (DQ -> Q) sequence.
"""
input0_shape = (1, 3, 4, 4)

inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)]
outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.UINT8, None)]

mul_weight_scale_data = np.array(1.0, dtype=np.float32)
mul_weight_zp_data = np.array(0, dtype=np.int8)

initializers = [
onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"),
onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"),
onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"),
onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"),
onnx.numpy_helper.from_array(mul_weight_scale_data, "mul_weight_scale"),
onnx.numpy_helper.from_array(mul_weight_zp_data, "mul_weight_zp"),
]
nodes = []

# Transpose to channel-last
tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1))
nodes.append(tp0_node)

# Q_0
q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node")
nodes.append(q0_node)

# DQ_0
dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node")
nodes.append(dq0_node)

# Sigmoid
sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node")
nodes.append(sigmoid_node)

# Q_1
q1_node = onnx.helper.make_node(
"QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node"
)
nodes.append(q1_node)

# DQ_1
dq1_node = onnx.helper.make_node(
"DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node"
)
nodes.append(dq1_node)

# DQ for mul input[1]
mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8)
mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight")
initializers.append(mul_weight)

nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["mul_weight", "mul_weight_scale", "mul_weight_zp"],
["mul_input_1"],
name="dq_mul_input_1",
)
)

# Mul
mul_node = onnx.helper.make_node("Mul", ["dq1_out", "mul_input_1"], ["mul_out"], name="mul_node")
nodes.append(mul_node)

# Q_2
q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node")
nodes.append(q2_node)

# DQ_2
dq2_node = onnx.helper.make_node(
"DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node"
)
nodes.append(dq2_node)

# Transpose to channel-first
tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["tp1_out"], name="tp1_node", perm=(0, 3, 1, 2))
nodes.append(tp1_node)

# Q_3 to graph output
nodes.append(
onnx.helper.make_node("QuantizeLinear", ["tp1_out", "scale_inv_255", "zp_0"], ["output0"], name="q3_node")
)

graph = onnx.helper.make_graph(
nodes,
"transpose_opt_empty_dqq_graph_output",
inputs,
outputs,
initializer=initializers,
)
opset_imports = [
onnx.helper.make_opsetid("", 19),
]
qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

print("[INFO]: Running onnx.checker on qdq model")
qdq_model = onnx.shape_inference.infer_shapes(qdq_model)
onnx.checker.check_model(qdq_model, True)

print(f"[INFO]: Saving {model_path}")
onnx.save_model(qdq_model, model_path)


if __name__ == "__main__":
make_model("transpose_optimizer_empty_dq_q_at_graph_output.onnx")
Binary file not shown.

0 comments on commit a47254e

Please sign in to comment.