Skip to content

Commit

Permalink
[Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
Browse files Browse the repository at this point in the history
Arguments of a fused TIR PrimFunc generated from a fused relax function do not retain all the buffer attributes from their original PrimFuncs as the buffers are created from the StructInfo of the Relax vars. This patch collects a mapping of relax vars to its corresponding TIR buffers in a fused relax function and uses that info to propagate its buffer attributes such as `axis_separators` and `storage_scope`
  • Loading branch information
quic-sanirudh authored Jun 17, 2024
1 parent 292ecfd commit 5bfca2e
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 20 deletions.
140 changes: 120 additions & 20 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {

namespace relax {

static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
int num_inputs) {
Array<Integer> ret;
int last_idx = num_inputs;
for (auto idx : inplace_indices) {
int i = idx.IntValue();
if (i >= 0) {
ret.push_back(Integer(i));
} else {
CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
ret.push_back(Integer(last_idx));
last_idx++;
}
}

return ret;
}

class RelaxToTIRVarMapCollector : public ExprVisitor {
public:
explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
RelaxToTIRVarMapCollector visitor(mod);
visitor(func->body);
return visitor.relax_to_tir_var_map_;
}

private:
void VisitBinding_(const VarBindingNode* binding) final {
current_var_ = binding->var;
ExprVisitor::VisitBinding_(binding);
}

void VisitExpr_(const CallNode* call) {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");

ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
<< "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
<< GetRef<Expr>(call);
CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
}

void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
const auto& buffer_map = prim_func_->buffer_map;
const auto& tir_args = prim_func_->params;

const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;

Array<Expr> relax_results;
if (lhs_var->IsInstance<TupleNode>()) {
relax_results = Downcast<Tuple>(lhs_var)->fields;
} else {
CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
relax_results = {Downcast<Var>(lhs_var)};
}

size_t num_inputs = relax_args.size();
size_t num_outputs = relax_results.size();

Array<Integer> output_idxs;
if (in_place) {
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
} else {
for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
output_idxs.push_back(i);
}
}

// If the `expr` is already seen (present in the map), validate whether the mapped buffer is
// structurally equal to the `new_buf` passed
auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
ICHECK(StructuralEqual()((*it).second, new_buf))
<< "Inconsistent buffers " << (*it).second << " and " << new_buf
<< " mapped to the same relax var: " << expr;
}
};
for (size_t i = 0; i < tir_args.size(); ++i) {
const auto& tir_var = tir_args[i];
if (auto tir_buffer = buffer_map.Get(tir_var)) {
if (i < num_inputs) {
const auto& relax_var = relax_args[i];
ValidateBufferCompatibility(tir_buffer.value(), relax_var);
relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
it != output_idxs.end()) {
int result_idx = it - output_idxs.begin();
const auto& relax_var = relax_results[result_idx];
ValidateBufferCompatibility(tir_buffer.value(), relax_var);
relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
}
}
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
Map<Expr, tir::Buffer> relax_to_tir_var_map_;
Var current_var_;
};

class FusedTIRConstructor : public ExprVisitor {
public:
/*!
Expand Down Expand Up @@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
: mod_(mod), func_name_(func_name) {}

void VisitExpr_(const FunctionNode* func) final {
auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
for (const Var& relax_param : func->params) {
size_t size_before = prim_func_params.size();
CollectPrimFuncParams(relax_param, &prim_func_params);
CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));

auto param_buffers = [&]() -> Array<tir::Buffer> {
Array<tir::Buffer> out;
Expand Down Expand Up @@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
MapArgsToBuffer(arg_list, buffer_list);
}

static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
int num_inputs) {
Array<Integer> ret;
int last_idx = num_inputs;
for (auto idx : inplace_indices) {
int i = idx.IntValue();
if (i >= 0) {
ret.push_back(Integer(i));
} else {
ret.push_back(Integer(last_idx));
last_idx++;
}
}

return ret;
}

static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
const Array<Integer>& output_indices) {
size_t n = func->params.size();
Expand Down Expand Up @@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
* \param out The vector into which to collect the params/buffers
*/
static void CollectPrimFuncParams(const Var& relax_param,
std::vector<Variant<tir::Var, tir::Buffer>>* out) {
std::vector<Variant<tir::Var, tir::Buffer>>* out,
const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
auto struct_info = GetStructInfo(relax_param);

CHECK(!struct_info.as<TupleStructInfoNode>())
Expand All @@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
DataType dtype = tensor->dtype;
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
tir::Buffer buffer;
if (tir_buffer_param.defined()) {
buffer =
tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
tir_buffer_param.value()->axis_separators);
} else {
buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
}
out->push_back(std::move(buffer));

} else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
Expand Down
128 changes: 128 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.testing
from tvm import relax, topi
Expand Down Expand Up @@ -2314,5 +2316,131 @@ def take(
_check(Before, Before)


def test_fuse_with_axis_separators():
@I.ir_module
class Before:
@T.prim_func(private=True)
def add(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])

for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] + B[i, j]

@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Before
with R.dataflow():
w = R.call_tir(
cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
out = R.call_tir(
cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x, y, z)
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
T.func_attr({"tir.noalias": True})
X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
for iters in T.grid(*X.shape):
with T.block("compute_Y"):
i, j = T.axis.remap("SS", iters)
Temp[i, j] = X[i, j] + Y[i, j]

for iters in T.grid(*X.shape):
with T.block("compute_Z"):
i, j = T.axis.remap("SS", iters)
C[i, j] = Temp[i, j] + Z[i, j]

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused_function,
[x, y, z],
out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
)
R.output(gv)
return gv

_check(Before, Expected)


def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
@I.ir_module
class Before:
@T.prim_func(private=True)
def mul(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])

for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] * B[i, j]

@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Before
with R.dataflow():
out = R.call_tir(
cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x)
R.output(gv)
return gv

with pytest.raises(
tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
):
relax.transform.FuseTIR()(Before)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 5bfca2e

Please sign in to comment.