diff --git a/onnxoptimizer/pass_registry.h b/onnxoptimizer/pass_registry.h index 51a2c1611..da4d451f8 100644 --- a/onnxoptimizer/pass_registry.h +++ b/onnxoptimizer/pass_registry.h @@ -60,6 +60,7 @@ #include "onnxoptimizer/passes/fuse_consecutive_unsqueezes.h" #include "onnxoptimizer/passes/eliminate_nop_with_unit.h" #include "onnxoptimizer/passes/rewrite_input_dtype.h" +#include "onnxoptimizer/passes/rewrite_where.h" namespace ONNX_NAMESPACE { namespace optimization { @@ -118,6 +119,7 @@ struct GlobalPassRegistry { registerPass<EliminateDuplicateInitializer>(); registerPass<AdjustSliceAndMatmul>(); registerPass<RewriteInputDtype>(); + registerPass<RewriteWhere>(); } ~GlobalPassRegistry() { diff --git a/onnxoptimizer/passes/rewrite_where.h b/onnxoptimizer/passes/rewrite_where.h new file mode 100644 index 000000000..61b9abcb9 --- /dev/null +++ b/onnxoptimizer/passes/rewrite_where.h @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +// ATTENTION: The code in this file is highly EXPERIMENTAL. +// Adventurous users should note that the APIs will probably change. + +#pragma once + +#include "onnxoptimizer/pass.h" +#include "onnxoptimizer/passes/pass_util.h" + +namespace ONNX_NAMESPACE { +namespace optimization { + +// where(not(b), x, y) -> where(b, y, x) +// https://github.com/microsoft/onnxruntime/blob/v1.15.1/onnxruntime/core/optimizer/not_where_fusion.h +struct RewriteWhere final : public PredicateBasedPass { + explicit RewriteWhere() + : PredicateBasedPass(PassType::Nop, PassEfficiency::Partial, + PassOptimizationType::Compute) {} + + std::string getPassName() const override { + return "rewrite_where"; + } + + bool patternMatchPredicate(Node* node) override { + bool isWhere = CheckKind(node, Symbol("Where")); + if (isWhere) { + return CheckKind(node->inputs()[0]->node(), Symbol("Not")); + } + return false; + } + bool runTransform(Node* node, Graph& graph, + NodeDestroyType& destroy_current) override { + destroy_current = NodeDestroyType::DestroyZero; + Node* previous_node = node->input(0)->node(); + if (previous_node->output()->uses().size() == 1) { + const bool replacing_success = + tryReplacingAllUsesWith(node->input(0), previous_node->input(0)); + if (!replacing_success) { + return false; + } + auto x = node->inputs()[1]; + auto y = node->inputs()[2]; + node->replaceInput(1, y); + node->replaceInput(2, x); + previous_node->destroy(); + return true; + } + return false; + } +}; + +} // namespace optimization +} // namespace ONNX_NAMESPACE diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 5cd6b32fd..e591b41ae 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -4597,6 +4597,32 @@ def test_eliminate_consecutive_idempotent_op(self): assert optimized_model.graph.node[0].op_type == "Constant" assert optimized_model.graph.node[1].op_type == "Reshape" + def test_rewrite_where(self): + model = parser.parse_model(""" + < + ir_version: 7, + opset_import:["": 11] + > + agraph (bool[4] A, float[4] X, float[4] Y) => (float[4] F, float[4] H) + { + B = Not(A) + Z = Where(B, X, Y) + F = Sign(Z) + M = And(A,A) + G = Where(M, X, Y) + H = Sign(G) + } + """) + + optimized_model = self._optimized( + model,["rewrite_where"], True) + + assert len(optimized_model.graph.node) == 5 + assert set([i.op_type for i in optimized_model.graph.node]) == {'Where', 'And', 'Sign'} + assert optimized_model.graph.node[0].input == ['A', 'Y', 'X'] + assert optimized_model.graph.node[3].input == ['M', 'X', 'Y'] + + if __name__ == "__main__": unittest.main()