diff --git a/onnxoptimizer/pass_registry.h b/onnxoptimizer/pass_registry.h
index 51a2c1611..da4d451f8 100644
--- a/onnxoptimizer/pass_registry.h
+++ b/onnxoptimizer/pass_registry.h
@@ -60,6 +60,7 @@
 #include "onnxoptimizer/passes/fuse_consecutive_unsqueezes.h"
 #include "onnxoptimizer/passes/eliminate_nop_with_unit.h"
 #include "onnxoptimizer/passes/rewrite_input_dtype.h"
+#include "onnxoptimizer/passes/rewrite_where.h"
 
 namespace ONNX_NAMESPACE {
 namespace optimization {
@@ -118,6 +119,7 @@ struct GlobalPassRegistry {
     registerPass<EliminateDuplicateInitializer>();
     registerPass<AdjustSliceAndMatmul>();
     registerPass<RewriteInputDtype>();
+    registerPass<RewriteWhere>();
   }
 
   ~GlobalPassRegistry() {
diff --git a/onnxoptimizer/passes/rewrite_where.h b/onnxoptimizer/passes/rewrite_where.h
new file mode 100644
index 000000000..61b9abcb9
--- /dev/null
+++ b/onnxoptimizer/passes/rewrite_where.h
@@ -0,0 +1,56 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// ATTENTION: The code in this file is highly EXPERIMENTAL.
+// Adventurous users should note that the APIs will probably change.
+
+#pragma once
+
+#include "onnxoptimizer/pass.h"
+#include "onnxoptimizer/passes/pass_util.h"
+
+namespace ONNX_NAMESPACE {
+namespace optimization {
+
+// where(not(b), x, y) -> where(b, y, x)
+// https://github.com/microsoft/onnxruntime/blob/v1.15.1/onnxruntime/core/optimizer/not_where_fusion.h
+struct RewriteWhere final : public PredicateBasedPass {
+  explicit RewriteWhere()
+      : PredicateBasedPass(PassType::Nop, PassEfficiency::Partial,
+                           PassOptimizationType::Compute) {}
+
+  std::string getPassName() const override {
+    return "rewrite_where";
+  }
+
+  bool patternMatchPredicate(Node* node) override {
+    bool isWhere = CheckKind(node, Symbol("Where"));
+    if (isWhere) {
+        return CheckKind(node->inputs()[0]->node(), Symbol("Not"));
+    }
+    return false;
+  }
+  bool runTransform(Node* node, Graph& graph,
+                    NodeDestroyType& destroy_current) override {
+        destroy_current = NodeDestroyType::DestroyZero;
+        Node* previous_node = node->input(0)->node();
+        if (previous_node->output()->uses().size() == 1) {
+          const bool replacing_success =
+              tryReplacingAllUsesWith(node->input(0), previous_node->input(0));
+          if (!replacing_success) {
+            return false;
+          }
+          auto x = node->inputs()[1];
+          auto y = node->inputs()[2];
+          node->replaceInput(1, y);
+          node->replaceInput(2, x);
+          previous_node->destroy();
+          return true;
+        }
+        return false;
+    }
+};
+
+}  // namespace optimization
+}  // namespace ONNX_NAMESPACE
diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py
index 5cd6b32fd..e591b41ae 100644
--- a/onnxoptimizer/test/optimizer_test.py
+++ b/onnxoptimizer/test/optimizer_test.py
@@ -4597,6 +4597,32 @@ def test_eliminate_consecutive_idempotent_op(self):
         assert optimized_model.graph.node[0].op_type == "Constant"
         assert optimized_model.graph.node[1].op_type == "Reshape"
 
+    def test_rewrite_where(self):
+        model = parser.parse_model("""
+                <
+                    ir_version: 7,
+                    opset_import:["": 11]
+                >
+               agraph (bool[4] A, float[4] X, float[4] Y) => (float[4] F, float[4] H)
+               {
+                  B = Not(A)
+                  Z = Where(B, X, Y)
+                  F = Sign(Z)
+                  M = And(A,A)
+                  G = Where(M, X, Y)
+                  H = Sign(G)
+               }
+            """)
+
+        optimized_model = self._optimized(
+            model,["rewrite_where"], True)
+
+        assert len(optimized_model.graph.node) == 5
+        assert set([i.op_type for i in optimized_model.graph.node]) == {'Where', 'And', 'Sign'}
+        assert optimized_model.graph.node[0].input == ['A', 'Y', 'X']
+        assert optimized_model.graph.node[3].input == ['M', 'X', 'Y']
+
+
 
 if __name__ == "__main__":
     unittest.main()