From fbfa92658568428b27c6ee5762ab7fe2f7c0b415 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 17 Mar 2024 17:17:18 -0500 Subject: [PATCH] [Relax] Implement relax.transform.TopologicalSort (#16697) * [Relax] Implement relax.transform.TopologicalSort This commit implements a utility `relax.transform.TopologicalSort`, which can re-order the bindings that occur in a `relax.DataflowBlock`. This is not intended for use in a general-purpose optimization pipeline, but instead as a utility that may be used as needed in specific cases. For example, normalization of unit tests that should not depend on the order of variable binding. * Update docstring according to review comment --- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 23 + src/relax/transform/topological_sort.cc | 377 +++++++++++++++ .../relax/test_transform_topological_sort.py | 457 ++++++++++++++++++ 4 files changed, 858 insertions(+) create mode 100644 src/relax/transform/topological_sort.cc create mode 100644 tests/python/relax/test_transform_topological_sort.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index c3fb0f23be47..7daa36cd2ebc 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -72,6 +72,7 @@ StaticPlanBlockMemory, ToMixedPrecision, ToNonDataflow, + TopologicalSort, UpdateParamStructInfo, UpdateVDevice, VMBuiltinLower, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index e4c66558f5a2..9ef5133b7139 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -233,6 +233,29 @@ def ToNonDataflow() -> tvm.ir.transform.Pass: return _ffi_api.ToNonDataflow() # type: ignore +def TopologicalSort(order="depth-first", direction="from-inputs") -> tvm.ir.transform.Pass: + """Sort bindings in relax.Dataflow blocks in the order specified + + Parameters + ---------- + order: str + + 
The order in which bindings should be emitted. Allowed values + are "depth-first" and "breadth-first". + + direction: str + + The direction in which the sort should be performed. Allowed + values are "from-inputs" and "from-outputs". + + Returns + ------- + ret: tvm.ir.transform.Pass + + """ + return _ffi_api.TopologicalSort(order, direction) # type: ignore + + def RemovePurityChecking() -> tvm.ir.transform.Pass: """Activate relax.force_pure on all pure functions in the module and unwrap all pure override ops into the normal versions. diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc new file mode 100644 index 000000000000..a366ff4d1271 --- /dev/null +++ b/src/relax/transform/topological_sort.cc @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file src/relax/transform/topological_sort.cc + * \brief Perform a topological sort of Dataflow blocks + */ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { +struct InputNode {}; +struct OutputNode {}; + +using DataflowNode = std::variant; + +bool operator==(const DataflowNode& a, const DataflowNode& b) { + if (const tvm::relax::Var* var_a = std::get_if(&a)) { + if (const tvm::relax::Var* var_b = std::get_if(&b)) { + const tvm::relax::VarNode* ptr_a = var_a->get(); + const tvm::relax::VarNode* ptr_b = var_b->get(); + return ptr_a == ptr_b; + } + } + + return a.index() == b.index(); +} + +} // namespace + +template <> +struct std::hash { + std::size_t operator()(const DataflowNode& node) const noexcept { + if (const tvm::relax::Var* var = std::get_if(&node)) { + const tvm::relax::VarNode* ptr = var->get(); + std::hash hasher; + return hasher(ptr); + } else { + auto index = node.index(); + std::hash hasher; + return hasher(index); + } + } +}; + +namespace tvm { +namespace relax { + +namespace { + +enum class TraversalOrder { + DepthFirst, + BreadthFirst, +}; + +enum class StartingLocation { + FromInputs, + FromOutputs, +}; + +struct Dependencies { + std::vector binding_order; + std::unordered_map> downstream_users; + std::unordered_map> upstream_requirements; +}; + +class BindingOrderCollector : ExprVisitor { + public: + static Dependencies Collect(const Expr& expr) { + BindingOrderCollector visitor; + visitor.dependencies_.binding_order.push_back(InputNode()); + visitor(expr); + + // If there is a variable without any inputs (e.g. `R.const(1)`) + // or an unused variable, these must be handled somewhere, to + // ensure they are visited correctly. It's easiest to perform the + // depth/breadth-first search if handled here, with `InputNode` and + // `OutputNode` acting as sentinel nodes, so that the later traversal + // doesn't need to check for this special case. 
+ std::vector zero_input_bindings; + std::vector unused_bindings; + for (const auto& var : visitor.dependencies_.binding_order) { + if (std::holds_alternative(var)) { + if (!visitor.dependencies_.upstream_requirements.count(var)) { + zero_input_bindings.push_back(var); + } + if (!visitor.dependencies_.downstream_users.count(var)) { + unused_bindings.push_back(var); + } + } + } + + for (const auto& var : zero_input_bindings) { + visitor.dependencies_.upstream_requirements[var].push_back(InputNode()); + visitor.dependencies_.downstream_users[InputNode()].push_back(var); + } + for (auto it = unused_bindings.rbegin(); it != unused_bindings.rend(); it++) { + const auto& var = *it; + visitor.dependencies_.upstream_requirements[OutputNode()].push_front(var); + visitor.dependencies_.downstream_users[var].push_front(OutputNode()); + } + + visitor.dependencies_.binding_order.push_back(OutputNode()); + + return visitor.dependencies_; + } + + private: + void VisitVarDef(const Var& var) override { dependencies_.binding_order.push_back(var); } + + void VisitExpr_(const FunctionNode* op) override { + for (const auto& var : op->params) { + dependencies_.downstream_users[InputNode()].push_back(var); + dependencies_.upstream_requirements[var].push_back(InputNode()); + } + VisitExpr(op->body); + } + + void VisitBinding(const Binding& binding) override { + auto cache = current_binding_; + current_binding_ = binding->var; + ExprVisitor::VisitBinding(binding); + current_binding_ = cache; + } + + void VisitExpr_(const VarNode* op) override { + Var upstream_requirement = GetRef(op); + auto downstream_user = current_binding_; + + dependencies_.downstream_users[upstream_requirement].push_back(downstream_user); + dependencies_.upstream_requirements[downstream_user].push_back(upstream_requirement); + } + + DataflowNode current_binding_ = OutputNode(); + Dependencies dependencies_; +}; + +class TopologicalSorter : public ExprMutator { + public: + TopologicalSorter(TraversalOrder order, 
StartingLocation starting_location) + : order_(order), starting_location_(starting_location) {} + + Expr VisitExpr_(const FunctionNode* op) override { + auto cached = dependencies_; + dependencies_ = BindingOrderCollector::Collect(GetRef(op)); + + if (starting_location_ == StartingLocation::FromOutputs) { + std::reverse(dependencies_.binding_order.begin(), dependencies_.binding_order.end()); + } + if (order_ == TraversalOrder::DepthFirst) { + for (auto& [upstream_var, downstream_vars] : dependencies_.downstream_users) { + std::reverse(downstream_vars.begin(), downstream_vars.end()); + } + } + + auto output = ExprMutator::VisitExpr_(op); + dependencies_ = cached; + return output; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + auto block = GetRef(op); + + // A map from not-yet-defined variables to the binding that will + // define the variable. Items are removed from this map as they + // are collected into `new_bindings`. + std::unordered_map to_emit; + for (const auto& binding : block->bindings) { + to_emit.insert({binding->var, binding}); + } + + // A lookup map of `Var -> Var` edges, used to find the bindings + // that may be emitted next. When starting at the function + // inputs, this is the map from variables to the downstream + // variables that depend on them. When starting at the function + // outputs, this is the map from variables to the upstream + // variables that they require. + const auto& forward_edge_lookup = [&]() { + switch (starting_location_) { + case StartingLocation::FromInputs: + return dependencies_.downstream_users; + case StartingLocation::FromOutputs: + return dependencies_.upstream_requirements; + default: + LOG(FATAL) << "Invalid enum value for StartingLocation"; + } + }(); + + // A lookup map of `Var -> Var` edges, used to determine if a + // binding can legally be emitted. When starting at the function + // inputs, this is the map from variables to the upstream + // variables that they require. (i.e. 
A variable may not be + // defined earlier than its last input.) When starting at the + // function outputs, this is the map from variables to the + // downstream variables that depend on them. (i.e. A variable may + // not be defined later than its first usage.) + const auto& backward_edge_lookup = [&]() { + switch (starting_location_) { + case StartingLocation::FromInputs: + return dependencies_.upstream_requirements; + case StartingLocation::FromOutputs: + return dependencies_.downstream_users; + default: + LOG(FATAL) << "Invalid enum value for StartingLocation"; + } + }(); + + // The search state for nodes that must still be visited. When + // doing a depth-first search, this is used as a stack, with + // `push_back` and `pop_back`. When doing a breadth-first search, + // this is used as a queue, with `push_back` and `pop_front`. A + // `std::deque` is used to support these two use cases. + auto deque = [&]() -> std::deque { + switch (starting_location_) { + case StartingLocation::FromInputs: + return {InputNode()}; + case StartingLocation::FromOutputs: + return {OutputNode()}; + default: + LOG(FATAL) << "Invalid enum value for StartingLocation"; + } + }(); + + std::unordered_set visited; + + // Given a variable that has just been defined (or the sentinel + // `InputNode`/`OutputNode`), mark nodes as ready to visit. 
+ auto push_descendents_to_stack = [&](const DataflowNode& var) { + auto it = forward_edge_lookup.find(var); + if (it == forward_edge_lookup.end()) { + return; + } + const auto& adjacent_vars = it->second; + + for (const auto& adjacent_var : adjacent_vars) { + bool legal_to_output = [&]() -> bool { + if (visited.count(adjacent_var)) { + return false; + } + + auto it = backward_edge_lookup.find(adjacent_var); + ICHECK(it != backward_edge_lookup.end()); + const auto& prerequisites = it->second; + return std::all_of(prerequisites.begin(), prerequisites.end(), + [&visited](const auto& var) { return visited.count(var); }); + }(); + + if (legal_to_output) { + deque.push_back(adjacent_var); + } + } + }; + + std::vector new_bindings; + while (deque.size()) { + DataflowNode visiting; + switch (order_) { + case TraversalOrder::DepthFirst: { + visiting = deque.back(); + deque.pop_back(); + break; + } + case TraversalOrder::BreadthFirst: { + visiting = deque.front(); + deque.pop_front(); + break; + } + default: { + LOG(FATAL) << "Invalid value for TraversalOrder: " << static_cast(order_); + } + } + + if (auto var = std::get_if(&visiting)) { + if (auto iter_emit = to_emit.find(*var); iter_emit != to_emit.end()) { + new_bindings.push_back(iter_emit->second); + to_emit.erase(iter_emit); + } + } + visited.insert(visiting); + push_descendents_to_stack(visiting); + } + + ICHECK_EQ(to_emit.size(), 0) << "After visiting all bindings, " + << "no bindings should remain to emit. 
" + << "However, bindings " << + [&]() { + Array arr; + for (const auto& [var, binding] : to_emit) { + arr.push_back(var); + } + return arr; + }() << " still remain after emitting " + << Array(new_bindings.begin(), new_bindings.end()) + .Map([](const Binding& binding) { return binding->var; }); + + if (starting_location_ == StartingLocation::FromOutputs) { + std::reverse(new_bindings.begin(), new_bindings.end()); + } + + block.CopyOnWrite()->bindings = new_bindings; + return ExprMutator::VisitBindingBlock_(block.get()); + } + + private: + TraversalOrder order_; + StartingLocation starting_location_; + Dependencies dependencies_; +}; +} // namespace + +namespace transform { + +Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { + auto pass_func = [=](Function func, IRModule, PassContext) { + TopologicalSorter mutator(order, starting_location); + return Downcast(mutator(func)); + }; + return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.TopologicalSort") + .set_body_typed([](String order_str, String direction_str) -> Pass { + TraversalOrder order = [&]() { + if (order_str == "depth-first") { + return TraversalOrder::DepthFirst; + } else if (order_str == "breadth-first") { + return TraversalOrder::BreadthFirst; + } else { + LOG(FATAL) << "ValueError: " + << "Invalid value for traversal order: \"" << order_str << "\". " + << "Allowed values are \"depth-first\" or \"breadth-first\""; + } + }(); + + StartingLocation starting_location = [&]() { + if (direction_str == "from-inputs") { + return StartingLocation::FromInputs; + } else if (direction_str == "from-outputs") { + return StartingLocation::FromOutputs; + } else { + LOG(FATAL) << "ValueError: " + << "Invalid value for starting location: \"" << direction_str << "\". 
" + << "Allowed values are \"from-inputs\" or \"from-outputs\""; + } + }(); + + return TopologicalSort(order, starting_location); + }); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_topological_sort.py b/tests/python/relax/test_transform_topological_sort.py new file mode 100644 index 000000000000..3f11c081fa02 --- /dev/null +++ b/tests/python/relax/test_transform_topological_sort.py @@ -0,0 +1,457 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + def transform(self): + return tvm.relax.transform.TopologicalSort( + order=self.order, + direction=self.direction, + ) + + +class TestDepthFirstFromInputs(BaseCompare): + """Sort DataflowBlock bindings with DFS, starting from inputs + + Starting with the inputs to the DataflowBlock, sort the variable + bindings according to their occurrence in a depth-first search. 
+ """ + + order = "depth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + B2 = R.add(A, R.const(2)) + C1 = R.add(A, B1) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + +class TestDepthFirstFromInputWithConstant(BaseCompare): + """Topological sort must produce legal ordering. + + Here, both `C1` and `C2` use the input tensor `A`. However, they + also use the tensors `B1` and `B2`. The bindings for `C1` and + `C2` may not be emitted until after all their inputs have been + emitted. + + In addition, the bindings `B1` and `B2` do not require any of the + function inputs to compute. If the DFS only used the function + parameters as the initial search nodes, it would fail to output + these variable bindings. + """ + + order = "depth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.const(1) + B2 = R.const(2) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D2 = R.add(A, C2) + D1 = R.add(A, C1) + E = R.add(D1, D2) + R.output(E) + return E + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.const(1) + C1 = R.add(A, B1) + D1 = R.add(A, C1) + B2 = R.const(2) + C2 = R.add(A, B2) + D2 = R.add(A, C2) + E = R.add(D1, D2) + R.output(E) + return E + + +class TestDepthFirstFromInputWithMultipleInputs(BaseCompare): + """Use parameter order for deterministic sort + + Here, both `C1` and `C2` use the input tensor `A`, as well as + input tensors `B1` and `B2`, respectively. Since `B1` appears + before `B2`, `C1` should be sorted before `C2`. 
+ """ + + order = "depth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor, B1: R.Tensor, B2: R.Tensor): + with R.dataflow(): + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D2 = R.add(A, C2) + D1 = R.add(A, C1) + E = R.add(D1, D2) + R.output(E) + return E + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor, B1: R.Tensor, B2: R.Tensor): + with R.dataflow(): + C1 = R.add(A, B1) + D1 = R.add(A, C1) + C2 = R.add(A, B2) + D2 = R.add(A, C2) + E = R.add(D1, D2) + R.output(E) + return E + + +class TestDepthFirstBreakTiesByExistingOrder(BaseCompare): + """If DFS is ambiguous, provide deterministic output + + Here, both `B1` and `B2` use the input tensor `A`. Since there + are no other inputs for `B1` or `B2`, they remain in the same + relative order as the input function, and `B1` is emitted before + `B2`. The DFS then continues, placing `C1` immediately after + `B1`. + """ + + order = "depth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D = R.add(C1, C2) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + +class TestDepthFirstFromOutput(BaseCompare): + """Sort DataflowBlock bindings with DFS, starting from outputs + + Starting with the outputs to the DataflowBlock, sort the variable + bindings according to their occurrence in a depth-first search. + + Like `TestDepthFirstFromInputs`, but perform the search starting + at the output. 
+ """ + + order = "depth-first" + direction = "from-outputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B2 = R.add(A, R.const(2)) + B1 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D = R.add(C1, C2) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + +class TestDepthFirstFromOutputTupleWithBinding(BaseCompare): + """A dataflow block may produce multiple outputs + + If a dataflow block produces multiple outputs, the result should + be sorted according to the order in which the outputs are used. + Here, `C1` is used before `C2`, so the expressions required to + compute `C1` are moved before the expressions required to compute + `C2`. + """ + + order = "depth-first" + direction = "from-outputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B2 = R.add(A, R.const(2)) + B1 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + R.output(C1, C2) + gv = (C1, C2) + return gv + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + R.output(C1, C2) + gv = (C1, C2) + return gv + + +class TestDepthFirstFromOutputTupleWithoutBinding(BaseCompare): + """A dataflow block may produce multiple outputs + + Like `TestDepthFirstFromOutputTupleWithBinding`, but the + DataflowBlock's outputs are not used as part of a variable + binding. Because in-line tuples are not normalized to variable + bindings, this case must be handled explicitly. 
+ """ + + order = "depth-first" + direction = "from-outputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B2 = R.add(A, R.const(2)) + B1 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + R.output(C1, C2) + return (C1, C2) + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + R.output(C1, C2) + return (C1, C2) + + +class TestDepthFirstFromOutputWithUnusedVariables(BaseCompare): + """Sort DataflowBlock bindings with DFS, starting from outputs + + The variables `D1` and `D2` are unused, but must still appear + within the output DataflowBlock. + + This is analogous to `TestDepthFirstFromInputWithConstant`. + Similar to how a DFS starting from the function inputs can + accidentally skip expressions with no inputs, a DFS starting from + the function outputs can accidentally skip expressions that do not + contribute to the output. + """ + + order = "depth-first" + direction = "from-outputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B2 = R.add(A, R.const(2)) + B1 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D1 = R.add(A, C1) + D2 = R.add(A, C2) + E = R.add(C1, C2) + R.output(E) + return E + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + D1 = R.add(A, C1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D2 = R.add(A, C2) + E = R.add(C1, C2) + R.output(E) + return E + + +class TestDepthFirstFromInputWithUnusedParameters(BaseCompare): + """Sort DataflowBlock bindings with DFS, starting from inputs + + Functions may accept parameters that are not used. 
+ """ + + order = "depth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor, Unused: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + B2 = R.add(A, R.const(2)) + C1 = R.add(A, B1) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor, Unused: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + +class TestBreadthFirst(BaseCompare): + order = "breadth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(2)) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(1)) + B2 = R.add(A, R.const(2)) + C1 = R.add(A, B1) + C2 = R.add(A, B2) + D = R.add(C1, C2) + R.output(D) + return D + + +class TestBreadthFirstBreakTiesByExistingOrder(BaseCompare): + order = "breadth-first" + direction = "from-inputs" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B1 = R.add(A, R.const(2)) + C1 = R.add(A, B1) + B2 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + D = R.add(C2, C1) + R.output(D) + return D + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + with R.dataflow(): + B2 = R.add(A, R.const(2)) + B1 = R.add(A, R.const(1)) + C2 = R.add(A, B2) + C1 = R.add(A, B1) + D = R.add(C1, C2) + R.output(D) + return D + + +if __name__ == "__main__": + tvm.testing.main()