From 576c2f87294ad8e7e9ea7e9ad18e0694bf51d7a6 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 10:12:28 -0700 Subject: [PATCH] Clean up --- onnxruntime/core/framework/node_unit.cc | 6 +- onnxruntime/core/framework/node_unit.h | 6 +- .../providers/qnn/builder/qnn_model_wrapper.h | 6 -- .../providers/qnn/builder/qnn_node_group.h | 5 +- .../qnn_node_group/conv_activation_fusion.cc | 65 ++++++++++--------- .../qnn_node_group/conv_activation_fusion.h | 11 ++-- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 12 ++-- .../qnn/builder/qnn_node_group/dq_q_fusion.h | 5 +- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 12 ++-- .../qnn_node_group/hardsigmoid_mul_fusion.h | 5 +- .../builder/qnn_node_group/qnn_node_group.cc | 38 +++++------ 11 files changed, 82 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 84d6ccb4d7acb..d2930a770c0a0 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,9 +272,9 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, - gsl::span inputs, gsl::span outputs, +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges) : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), target_node_(target_node), diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index c2297c13a41e6..8bc2f79c4a372 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,9 +68,9 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, - gsl::span inputs, gsl::span outputs, + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index fdf6616393ff8..9ab122b7f8e28 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -52,12 +52,6 @@ class QnnModelWrapper { ~QnnModelWrapper() = default; - const QNN_INTERFACE_VER_TYPE& GetQnnInterface() const { return qnn_interface_; } - const Qnn_BackendHandle_t& GetQnnBackendHandle() const { return backend_handle_; } - const std::unordered_map& GetInputIndexMap() const { return input_index_map_; } - const std::unordered_map& GetOutputIndexMap() const { return output_index_map_; } - const std::unordered_set& GetInitializerLookup() const { return initializer_lookup_; } - bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index fb6aa221aac3e..bd2e58c2d3973 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -18,11 +19,9 @@ class IQnnNodeGroup { virtual ~IQnnNodeGroup() = default; virtual Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; virtual Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; - virtual std::vector GetNodeUnits() const = 0; + virtual gsl::span GetNodeUnits() const = 0; virtual const NodeUnit* GetTargetNodeUnit() const = 0; virtual std::string_view Type() const = 0; - - size_t index_ = 0; }; Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 4da4a748f801d..065a2810a2920 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -332,6 +332,7 @@ static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, const size_t num_dqs = dq_node_units.size(); constexpr size_t max_num_dqs = 3; ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == "QuantizeLinear"); std::array dq_nodes_buf = {}; for (size_t i = 0; i < num_dqs; i++) { @@ -447,64 +448,66 @@ std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_mode return nullptr; } - return std::make_unique(dq_node_units, conv_node_unit, - *activation_node_unit, *q_node_unit); + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units.size() == 3 ? dq_node_units[2] : nullptr, + conv_node_unit, + *activation_node_unit, + *q_node_unit); } namespace conv_act_fusion { -QnnNodeGroup::QnnNodeGroup(gsl::span dq_node_units, +QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, const NodeUnit& conv_node_unit, const NodeUnit& activation_node_unit, const NodeUnit& q_node_unit) - : dq_node_units_{}, - conv_node_unit_(conv_node_unit), - activation_node_unit_(activation_node_unit), - q_node_unit_(q_node_unit) { - assert(dq_node_units.size() <= dq_node_units_.size()); - std::copy(dq_node_units.begin(), dq_node_units.end(), dq_node_units_.data()); + : node_units_{} { + size_t i = 0; + node_units_[i++] = &dq_node_unit_0; + node_units_[i++] = &dq_node_unit_1; + if (dq_node_unit_2 != nullptr) { + node_units_[i++] = dq_node_unit_2; + } + node_units_[i++] = &conv_node_unit; + node_units_[i++] = &activation_node_unit; + node_units_[i++] = &q_node_unit; + assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(dq_node_units_.data(), num_dqs); + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); return QnnConvActivationFusionAdd(qmw, dq_node_units, - &conv_node_unit_, - &q_node_unit_, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(dq_node_units_.data(), num_dqs); + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); return QnnConvActivationFusionAdd(qmw, dq_node_units, - &conv_node_unit_, - &q_node_unit_, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - - std::vector node_units; - node_units.reserve(6); - for (size_t i = 0; i < num_dqs; i++) { - node_units.push_back(dq_node_units_[i]); - } - node_units.push_back(&conv_node_unit_); - node_units.push_back(&activation_node_unit_); - node_units.push_back(&q_node_unit_); - - return node_units; +gsl::span QnnNodeGroup::GetNodeUnits() const { + const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5; + return gsl::make_span(node_units_.data(), num_node_units); } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &conv_node_unit_; + const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; + return node_units_[conv_index]; } } // namespace conv_act_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index a195c86d2393a..43a3aa63fe9ea 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -27,7 +27,9 @@ namespace conv_act_fusion { class QnnNodeGroup : public IQnnNodeGroup { public: - QnnNodeGroup(gsl::span dq_node_units, + QnnNodeGroup(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, const NodeUnit& conv_node_unit, const NodeUnit& activation_node_unit, const NodeUnit& q_node_unit); @@ -35,15 +37,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "ConvActivationFusion"; } private: - std::array dq_node_units_; // Last DQ is nullptr if bias is missing. - const NodeUnit& conv_node_unit_; - const NodeUnit& activation_node_unit_; - const NodeUnit& q_node_unit_; + std::array node_units_; // Last elem is nullptr if bias DQ is missing. }; } // namespace conv_act_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index ac782b5b1420a..e31219c8b3b76 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -101,23 +101,23 @@ std::unique_ptr TryDQQFusion( namespace dq_q_fusion { QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) - : dq_node_unit_(dq_node_unit), q_node_unit_(q_node_unit) { + : node_units_{&dq_node_unit, &q_node_unit} { } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ true); + return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ false); + return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - return std::vector{&dq_node_unit_, &q_node_unit_}; +gsl::span QnnNodeGroup::GetNodeUnits() const { + return node_units_; } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &dq_node_unit_; + return node_units_[0]; } } // namespace dq_q_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h index 2e5b612c41a81..c5d779c8234ff 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -43,13 +43,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "DQQFusion"; } private: - const NodeUnit& dq_node_unit_; - const NodeUnit& q_node_unit_; + std::array node_units_; }; } // namespace dq_q_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 817e2190e7825..e77d613d607c6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -116,23 +116,23 @@ std::unique_ptr TryHardSigmoidMulFusion( namespace hs_mul_fusion { QnnNodeGroup::QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) - : hardsigmoid_node_unit_(hardsigmoid_node_unit), mul_node_unit_(mul_node_unit) { + : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ true); + return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ false); + return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - return std::vector{&hardsigmoid_node_unit_, &mul_node_unit_}; +gsl::span QnnNodeGroup::GetNodeUnits() const { + return node_units_; } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &hardsigmoid_node_unit_; + return node_units_[0]; } } // namespace hs_mul_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 1cfb6119e3acc..3b04dccf1f6a5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -44,13 +44,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "HardSigmoidMulFusion"; } private: - const NodeUnit& hardsigmoid_node_unit_; - const NodeUnit& mul_node_unit_; + std::array node_units_; }; } // namespace hs_mul_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 8486d20dd6065..7a5abd6c9c9e2 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -3,6 +3,7 @@ #include "core/providers/qnn/builder/qnn_node_group.h" +#include #include #include #include @@ -24,35 +25,35 @@ namespace qnn { class QnnNodeUnitWrapper : public IQnnNodeGroup { public: - QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(node_unit) {} + QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { - const std::string& op_type = node_unit_.OpType(); + const std::string& op_type = node_unit_->OpType(); const auto* op_builder = qnn::GetOpBuilder(op_type); ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, "` are not supported by QNN EP.", op_type, " node `", - node_unit_.Name(), "` will not be assigned to QNN EP."); + node_unit_->Name(), "` will not be assigned to QNN EP."); - return op_builder->IsOpSupported(qmw, node_unit_, logger); + return op_builder->IsOpSupported(qmw, *node_unit_, logger); } Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { - const std::string& op_type = node_unit_.OpType(); + const std::string& op_type = node_unit_->OpType(); const auto* op_builder = qnn::GetOpBuilder(op_type); ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); - return op_builder->AddToModelBuilder(qmw, node_unit_, logger, /*do_op_validation*/ false); + return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false); } - std::vector GetNodeUnits() const override { - return std::vector{&node_unit_}; + gsl::span GetNodeUnits() const override { + return gsl::span{&node_unit_, 1ULL}; } - const NodeUnit* GetTargetNodeUnit() const override { return &node_unit_; } + const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; } std::string_view Type() const override { return "NodeUnitWrapper"; } private: - const NodeUnit& node_unit_; + const NodeUnit* node_unit_; }; using FusionFunc = std::unique_ptr (*)( @@ -106,6 +107,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn { std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map fused_qnn_node_group_indices; std::vector> sorted_node_units; sorted_node_units.reserve(num_node_units); @@ -135,7 +137,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn if (fused_node_group) { const size_t index = tmp_qnn_node_groups.size(); - fused_node_group->index_ = index; + fused_qnn_node_group_indices[fused_node_group.get()] = index; for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { assert(fused_node_unit != nullptr); @@ -151,19 +153,18 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn const auto it = node_unit_to_qnn_node_group.find(node_unit); if (it != node_unit_to_qnn_node_group.end()) { // Already handled this NodeUnit. - gsl::not_null qnn_node_group = it->second; - if (node_unit == qnn_node_group->GetTargetNodeUnit()) { - sorted_qnn_node_group_indices.push_back(qnn_node_group->index_); + gsl::not_null fused_qnn_node_group = it->second; + if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) { + sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]); } continue; } const size_t index = tmp_qnn_node_groups.size(); - auto fused_node_group = std::make_unique(*node_unit); - fused_node_group->index_ = index; - tmp_qnn_node_groups.push_back(std::move(fused_node_group)); + auto qnn_node_group = std::make_unique(*node_unit); - node_unit_to_qnn_node_group.insert({node_unit, fused_node_group.get()}); + node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()}); + tmp_qnn_node_groups.push_back(std::move(qnn_node_group)); sorted_qnn_node_group_indices.push_back(index); } @@ -176,7 +177,6 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn for (auto index : sorted_qnn_node_group_indices) { assert(index < tmp_qnn_node_groups.size()); std::unique_ptr qnn_node_group = std::move(tmp_qnn_node_groups[index]); - qnn_node_group->index_ = qnn_node_groups.size(); qnn_node_groups.push_back(std::move(qnn_node_group)); }