From bac24e101039e680f847cc02ed544ba13354727f Mon Sep 17 00:00:00 2001
From: Anirudh Sundar
Date: Mon, 10 Jun 2024 10:18:49 +0530
Subject: [PATCH] [Transform] Modify FuseTIR pass to propagate buffer attributes

Arguments of a fused TIR PrimFunc generated from a fused relax function
do not retain all the buffer attributes of their original PrimFuncs,
because the buffers are recreated from the StructInfo of the relax vars.
This patch collects a mapping from relax vars to their corresponding TIR
buffers in a fused relax function and uses that mapping to propagate
buffer attributes such as `axis_separators` and `storage_scope`.
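For example, when two calls to a PrimFunc whose buffers carry
`axis_separators=[1]` are fused, the parameter buffers of the fused
PrimFunc used to lose the attribute. A before/after sketch of one
match_buffer in the fused kernel (buffer and var names follow the test
case added below):

    # Before this patch: buffer rebuilt from StructInfo, attribute dropped
    X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32")

    # After this patch: attribute propagated from the original PrimFunc
    X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])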
---
 src/relax/transform/fuse_tir.cc               | 140 +++++++++++++++---
 tests/python/relax/test_transform_fuse_tir.py | 128 ++++++++++++++++
 2 files changed, 248 insertions(+), 20 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index e712b5022a7d..b203b322ab96 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding) final {
+    current_var_ = binding->var;
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+    ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+        << "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
+        << GetRef<Call>(call);
+    CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
+  }
+
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+
+    // If the `expr` is already seen (present in the map), validate whether the mapped buffer is
+    // structurally equal to the `new_buf` passed
+    auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
+      if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
+        ICHECK(StructuralEqual()((*it).second, new_buf))
+            << "Inconsistent buffers " << (*it).second << " and " << new_buf
+            << " mapped to the same relax var: " << expr;
+      }
+    };
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = tir_args[i];
+      if (auto tir_buffer = buffer_map.Get(tir_var)) {
+        if (i < num_inputs) {
+          const auto& relax_var = relax_args[i];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+        if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
+            it != output_idxs.end()) {
+          int result_idx = it - output_idxs.begin();
+          const auto& relax_var = relax_results[result_idx];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  Map<Expr, tir::Buffer> relax_to_tir_var_map_;
+  Var current_var_;
+};
+
 class FusedTIRConstructor : public ExprVisitor {
  public:
   /*!
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
     std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
       size_t size_before = prim_func_params.size();
-      CollectPrimFuncParams(relax_param, &prim_func_params);
+      CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));
 
       auto param_buffers = [&]() -> Array<tir::Buffer> {
         Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
     MapArgsToBuffer(arg_list, buffer_list);
   }
 
-  static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
-                                                int num_inputs) {
-    Array<Integer> ret;
-    int last_idx = num_inputs;
-    for (auto idx : inplace_indices) {
-      int i = idx.IntValue();
-      if (i >= 0) {
-        ret.push_back(Integer(i));
-      } else {
-        ret.push_back(Integer(last_idx));
-        last_idx++;
-      }
-    }
-
-    return ret;
-  }
-
   static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
                                                  const Array<Integer>& output_indices) {
     size_t n = func->params.size();
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
    * \param out The vector into which to collect the params/buffers
    */
   static void CollectPrimFuncParams(const Var& relax_param,
-                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out,
+                                    const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
     auto struct_info = GetStructInfo(relax_param);
 
     CHECK(!struct_info.as<TupleStructInfoNode>())
@@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
       ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      tir::Buffer buffer;
+      if (tir_buffer_param.defined()) {
+        buffer =
+            tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
+                             tir_buffer_param.value()->axis_separators);
+      } else {
+        buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      }
       out->push_back(std::move(buffer));
 
     } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
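Note: the last hunk above is the core of the change: when the collector
recorded a TIR buffer for a relax param, its storage scope and
axis_separators are forwarded to tir::decl_buffer. A rough Python
analogue of the two branches (a sketch only; "global.vtcm" is an
illustrative scope, not taken from this patch):

    from tvm import tir

    # No recorded buffer: defaults, as FuseTIR previously always did
    plain = tir.decl_buffer((16, 32), "float32", "x")

    # Recorded buffer: propagate its scope and axis separators
    prop = tir.decl_buffer((16, 32), "float32", "x", scope="global.vtcm", axis_separators=[1])
    assert prop.scope() == "global.vtcm"
    assert [int(s) for s in prop.axis_separators] == [1]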
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 90baeaad04bb..99e7a5d2b737 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax, topi
@@ -2314,5 +2316,131 @@ def take(
     _check(Before, Before)
 
 
+def test_fuse_with_axis_separators():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                w = R.call_tir(
+                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                out = R.call_tir(
+                    cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x, y, z)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
+            T.func_attr({"tir.noalias": True})
+            X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Y"):
+                    i, j = T.axis.remap("SS", iters)
+                    Temp[i, j] = X[i, j] + Y[i, j]
+
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Z"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = Temp[i, j] + Z[i, j]
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.fused_function,
+                    [x, y, z],
+                    out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def mul(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] * B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                out = R.call_tir(
+                    cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(
+        tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
+    ):
+        relax.transform.FuseTIR()(Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
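A quick way to observe the propagation outside of _check (a minimal
sketch reusing the `Before` module from test_fuse_with_axis_separators
above; it assumes the fused PrimFunc keeps the name "fused_function", as
in the Expected module):

    mod = relax.transform.FuseTIR()(Before)
    fused = mod["fused_function"]
    for param in fused.params:
        # Every parameter buffer of the fused kernel should keep axis_separators=[1]
        buf = fused.buffer_map[param]
        assert [int(s) for s in buf.axis_separators] == [1]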