Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Jul 28, 2024
1 parent 44dc696 commit 576c2f8
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 89 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g
}
}

NodeUnit::NodeUnit(gsl::span<const Node*> dq_nodes, const Node& target_node,
gsl::span<const Node*> q_nodes, Type type,
gsl::span<NodeUnitIODef> inputs, gsl::span<NodeUnitIODef> outputs,
NodeUnit::NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
gsl::span<const Node* const> q_nodes, Type type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges)
: dq_nodes_(dq_nodes.begin(), dq_nodes.end()),
target_node_(target_node),
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class NodeUnit {
public:
explicit NodeUnit(const Node& node);
explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group);
NodeUnit(gsl::span<const Node*> dq_nodes, const Node& target_node,
gsl::span<const Node*> q_nodes, Type type,
gsl::span<NodeUnitIODef> inputs, gsl::span<NodeUnitIODef> outputs,
NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
gsl::span<const Node* const> q_nodes, Type type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges);

Type UnitType() const noexcept { return type_; }
Expand Down
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ class QnnModelWrapper {

~QnnModelWrapper() = default;

const QNN_INTERFACE_VER_TYPE& GetQnnInterface() const { return qnn_interface_; }
const Qnn_BackendHandle_t& GetQnnBackendHandle() const { return backend_handle_; }
const std::unordered_map<std::string, size_t>& GetInputIndexMap() const { return input_index_map_; }
const std::unordered_map<std::string, size_t>& GetOutputIndexMap() const { return output_index_map_; }
const std::unordered_set<std::string>& GetInitializerLookup() const { return initializer_lookup_; }

bool CreateQnnGraph(const Qnn_ContextHandle_t& context,
const std::string& graph_name,
const QnnGraph_Config_t** graph_configs = nullptr);
Expand Down
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/qnn/builder/qnn_node_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <gsl/gsl>
#include <memory>
#include <unordered_map>
#include <vector>
Expand All @@ -18,11 +19,9 @@ class IQnnNodeGroup {
virtual ~IQnnNodeGroup() = default;
virtual Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0;
virtual Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0;
virtual std::vector<const NodeUnit*> GetNodeUnits() const = 0;
virtual gsl::span<const NodeUnit* const> GetNodeUnits() const = 0;
virtual const NodeUnit* GetTargetNodeUnit() const = 0;
virtual std::string_view Type() const = 0;

size_t index_ = 0;
};

Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn_node_groups,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
const size_t num_dqs = dq_node_units.size();
constexpr size_t max_num_dqs = 3;
ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs");
ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == "QuantizeLinear");

std::array<const Node*, max_num_dqs> dq_nodes_buf = {};
for (size_t i = 0; i < num_dqs; i++) {
Expand Down Expand Up @@ -447,64 +448,66 @@ std::unique_ptr<IQnnNodeGroup> TryConvActivationFusion(QnnModelWrapper& qnn_mode
return nullptr;
}

return std::make_unique<conv_act_fusion::QnnNodeGroup>(dq_node_units, conv_node_unit,
*activation_node_unit, *q_node_unit);
return std::make_unique<conv_act_fusion::QnnNodeGroup>(*dq_node_units[0],
*dq_node_units[1],
dq_node_units.size() == 3 ? dq_node_units[2] : nullptr,
conv_node_unit,
*activation_node_unit,
*q_node_unit);
}

namespace conv_act_fusion {
QnnNodeGroup::QnnNodeGroup(gsl::span<const NodeUnit*> dq_node_units,
QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit_0,
const NodeUnit& dq_node_unit_1,
const NodeUnit* dq_node_unit_2,
const NodeUnit& conv_node_unit,
const NodeUnit& activation_node_unit,
const NodeUnit& q_node_unit)
: dq_node_units_{},
conv_node_unit_(conv_node_unit),
activation_node_unit_(activation_node_unit),
q_node_unit_(q_node_unit) {
assert(dq_node_units.size() <= dq_node_units_.size());
std::copy(dq_node_units.begin(), dq_node_units.end(), dq_node_units_.data());
: node_units_{} {
size_t i = 0;
node_units_[i++] = &dq_node_unit_0;
node_units_[i++] = &dq_node_unit_1;
if (dq_node_unit_2 != nullptr) {
node_units_[i++] = dq_node_unit_2;
}
node_units_[i++] = &conv_node_unit;
node_units_[i++] = &activation_node_unit;
node_units_[i++] = &q_node_unit;
assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6));
}

Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2;
gsl::span<const NodeUnit* const> dq_node_units(dq_node_units_.data(), num_dqs);
const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2;
gsl::span<const NodeUnit* const> dq_node_units(node_units_.data(), num_dqs);

return QnnConvActivationFusionAdd(qmw,
dq_node_units,
&conv_node_unit_,
&q_node_unit_,
node_units_[num_dqs], // Conv
node_units_[num_dqs + 2], // Q
logger,
/*validate*/ true);
}

Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2;
gsl::span<const NodeUnit* const> dq_node_units(dq_node_units_.data(), num_dqs);
const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2;
gsl::span<const NodeUnit* const> dq_node_units(node_units_.data(), num_dqs);

return QnnConvActivationFusionAdd(qmw,
dq_node_units,
&conv_node_unit_,
&q_node_unit_,
node_units_[num_dqs], // Conv
node_units_[num_dqs + 2], // Q
logger,
/*validate*/ false);
}

std::vector<const NodeUnit*> QnnNodeGroup::GetNodeUnits() const {
const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2;

std::vector<const NodeUnit*> node_units;
node_units.reserve(6);
for (size_t i = 0; i < num_dqs; i++) {
node_units.push_back(dq_node_units_[i]);
}
node_units.push_back(&conv_node_unit_);
node_units.push_back(&activation_node_unit_);
node_units.push_back(&q_node_unit_);

return node_units;
gsl::span<const NodeUnit* const> QnnNodeGroup::GetNodeUnits() const {
const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5;
return gsl::make_span<const NodeUnit* const>(node_units_.data(), num_node_units);
}

const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const {
return &conv_node_unit_;
const size_t conv_index = node_units_.back() != nullptr ? 3 : 2;
return node_units_[conv_index];
}

} // namespace conv_act_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,22 @@ namespace conv_act_fusion {

class QnnNodeGroup : public IQnnNodeGroup {
public:
QnnNodeGroup(gsl::span<const NodeUnit*> dq_node_units,
QnnNodeGroup(const NodeUnit& dq_node_unit_0,
const NodeUnit& dq_node_unit_1,
const NodeUnit* dq_node_unit_2,
const NodeUnit& conv_node_unit,
const NodeUnit& activation_node_unit,
const NodeUnit& q_node_unit);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup);

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
std::vector<const NodeUnit*> GetNodeUnits() const override;
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
const NodeUnit* GetTargetNodeUnit() const override;
std::string_view Type() const override { return "ConvActivationFusion"; }

private:
std::array<const NodeUnit*, 3> dq_node_units_; // Last DQ is nullptr if bias is missing.
const NodeUnit& conv_node_unit_;
const NodeUnit& activation_node_unit_;
const NodeUnit& q_node_unit_;
std::array<const NodeUnit*, 6> node_units_; // Last elem is nullptr if bias DQ is missing.
};

} // namespace conv_act_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,23 @@ std::unique_ptr<IQnnNodeGroup> TryDQQFusion(

namespace dq_q_fusion {
QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit)
: dq_node_unit_(dq_node_unit), q_node_unit_(q_node_unit) {
: node_units_{&dq_node_unit, &q_node_unit} {
}

Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ true);
return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true);
}

Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ false);
return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false);
}

std::vector<const NodeUnit*> QnnNodeGroup::GetNodeUnits() const {
return std::vector<const NodeUnit*>{&dq_node_unit_, &q_node_unit_};
gsl::span<const NodeUnit* const> QnnNodeGroup::GetNodeUnits() const {
return node_units_;
}

const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const {
return &dq_node_unit_;
return node_units_[0];
}

} // namespace dq_q_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ class QnnNodeGroup : public IQnnNodeGroup {

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
std::vector<const NodeUnit*> GetNodeUnits() const override;
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
const NodeUnit* GetTargetNodeUnit() const override;
std::string_view Type() const override { return "DQQFusion"; }

private:
const NodeUnit& dq_node_unit_;
const NodeUnit& q_node_unit_;
std::array<const NodeUnit*, 2> node_units_;
};

} // namespace dq_q_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,23 @@ std::unique_ptr<IQnnNodeGroup> TryHardSigmoidMulFusion(
namespace hs_mul_fusion {

QnnNodeGroup::QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit)
: hardsigmoid_node_unit_(hardsigmoid_node_unit), mul_node_unit_(mul_node_unit) {
: node_units_{&hardsigmoid_node_unit, &mul_node_unit} {
}

Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ true);
return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true);
}

Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ false);
return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false);
}

std::vector<const NodeUnit*> QnnNodeGroup::GetNodeUnits() const {
return std::vector<const NodeUnit*>{&hardsigmoid_node_unit_, &mul_node_unit_};
gsl::span<const NodeUnit* const> QnnNodeGroup::GetNodeUnits() const {
return node_units_;
}

const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const {
return &hardsigmoid_node_unit_;
return node_units_[0];
}

} // namespace hs_mul_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,12 @@ class QnnNodeGroup : public IQnnNodeGroup {

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
std::vector<const NodeUnit*> GetNodeUnits() const override;
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
const NodeUnit* GetTargetNodeUnit() const override;
std::string_view Type() const override { return "HardSigmoidMulFusion"; }

private:
const NodeUnit& hardsigmoid_node_unit_;
const NodeUnit& mul_node_unit_;
std::array<const NodeUnit*, 2> node_units_;
};

} // namespace hs_mul_fusion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "core/providers/qnn/builder/qnn_node_group.h"

#include <gsl/gsl>
#include <limits>
#include <memory>
#include <string>
Expand All @@ -24,35 +25,35 @@ namespace qnn {

class QnnNodeUnitWrapper : public IQnnNodeGroup {
public:
QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(node_unit) {}
QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {}
ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper);

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override {
const std::string& op_type = node_unit_.OpType();
const std::string& op_type = node_unit_->OpType();
const auto* op_builder = qnn::GetOpBuilder(op_type);
ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type,
"` are not supported by QNN EP.", op_type, " node `",
node_unit_.Name(), "` will not be assigned to QNN EP.");
node_unit_->Name(), "` will not be assigned to QNN EP.");

return op_builder->IsOpSupported(qmw, node_unit_, logger);
return op_builder->IsOpSupported(qmw, *node_unit_, logger);
}

Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override {
const std::string& op_type = node_unit_.OpType();
const std::string& op_type = node_unit_->OpType();
const auto* op_builder = qnn::GetOpBuilder(op_type);
ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type);
return op_builder->AddToModelBuilder(qmw, node_unit_, logger, /*do_op_validation*/ false);
return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false);
}

std::vector<const NodeUnit*> GetNodeUnits() const override {
return std::vector<const NodeUnit*>{&node_unit_};
gsl::span<const NodeUnit* const> GetNodeUnits() const override {
return gsl::span<const NodeUnit* const>{&node_unit_, 1ULL};
}

const NodeUnit* GetTargetNodeUnit() const override { return &node_unit_; }
const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; }
std::string_view Type() const override { return "NodeUnitWrapper"; }

private:
const NodeUnit& node_unit_;
const NodeUnit* node_unit_;
};

using FusionFunc = std::unique_ptr<IQnnNodeGroup> (*)(
Expand Down Expand Up @@ -106,6 +107,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn

{
std::unordered_map<const NodeUnit*, const IQnnNodeGroup*> node_unit_to_qnn_node_group;
std::unordered_map<const IQnnNodeGroup*, size_t> fused_qnn_node_group_indices;
std::vector<gsl::not_null<const NodeUnit*>> sorted_node_units;
sorted_node_units.reserve(num_node_units);

Expand Down Expand Up @@ -135,7 +137,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn

if (fused_node_group) {
const size_t index = tmp_qnn_node_groups.size();
fused_node_group->index_ = index;
fused_qnn_node_group_indices[fused_node_group.get()] = index;

for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) {
assert(fused_node_unit != nullptr);
Expand All @@ -151,19 +153,18 @@ Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn
const auto it = node_unit_to_qnn_node_group.find(node_unit);
if (it != node_unit_to_qnn_node_group.end()) {
// Already handled this NodeUnit.
gsl::not_null<const IQnnNodeGroup*> qnn_node_group = it->second;
if (node_unit == qnn_node_group->GetTargetNodeUnit()) {
sorted_qnn_node_group_indices.push_back(qnn_node_group->index_);
gsl::not_null<const IQnnNodeGroup*> fused_qnn_node_group = it->second;
if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) {
sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]);
}
continue;
}

const size_t index = tmp_qnn_node_groups.size();
auto fused_node_group = std::make_unique<QnnNodeUnitWrapper>(*node_unit);
fused_node_group->index_ = index;
tmp_qnn_node_groups.push_back(std::move(fused_node_group));
auto qnn_node_group = std::make_unique<QnnNodeUnitWrapper>(*node_unit);

node_unit_to_qnn_node_group.insert({node_unit, fused_node_group.get()});
node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()});
tmp_qnn_node_groups.push_back(std::move(qnn_node_group));
sorted_qnn_node_group_indices.push_back(index);
}

Expand All @@ -176,7 +177,6 @@ Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn
for (auto index : sorted_qnn_node_group_indices) {
assert(index < tmp_qnn_node_groups.size());
std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::move(tmp_qnn_node_groups[index]);
qnn_node_group->index_ = qnn_node_groups.size();
qnn_node_groups.push_back(std::move(qnn_node_group));
}

Expand Down

0 comments on commit 576c2f8

Please sign in to comment.