Skip to content

Commit

Permalink
[Relax] Remove segfault in R.call_tir_inplace validation (#17242)
Browse files Browse the repository at this point in the history
Prior to this commit, the error message produced when validating
`R.call_tir_inplace` included the shape of the argument that will be
mutated in-place.  This correctly caught and raised an error when the argument is a
tensor with known shape that is incompatible with the output tensor's
shape.  However, this same error message could also be reached if
the input does not have `TensorStructInfo` at all, which would trigger
a segfault.

This commit updates the validation to print the argument's
`StructInfo` directly, rather than a field from the struct info.  This
correctly raises an error for the cases where the argument is not a
tensor, or is a tensor with unknown dimensionality, while still
printing the explicit shape of the mismatched tensor when available.
  • Loading branch information
Lunderberg authored Aug 6, 2024
1 parent 5f22be4 commit 591cf1e
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 75 deletions.
80 changes: 39 additions & 41 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,19 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
// may result in an error if performed before normalization.
call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));

Array<StructInfo> sinfo_outputs = [&]() -> Array<StructInfo> {
auto out_sinfo = call->sinfo_args[0];
if (auto* tuple_output = out_sinfo.as<TupleStructInfoNode>()) {
return tuple_output->fields;
} else {
return {out_sinfo};
}
}();

// there must be an inplace index for each output
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
size_t num_outputs = 1U;
if (auto* tup_info = call->sinfo_args[0].as<TupleStructInfoNode>()) {
num_outputs = tup_info->fields.size();
}
if (attrs->inplace_indices.size() != num_outputs) {
ICHECK(attrs);
if (attrs->inplace_indices.size() != sinfo_outputs.size()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "There must be an in-place index specified for each output");
}
Expand Down Expand Up @@ -459,45 +465,37 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
// input shape
// TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
Tuple call_args = Downcast<Tuple>(call->args[1]);
if (attrs->inplace_indices.size() == 1) {
auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");

for (size_t i_output = 0; i_output < attrs->inplace_indices.size(); i_output++) {
auto i_input = attrs->inplace_indices[i_output].IntValue();
if (i_input == -1) {
continue;
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[0].IntValue()]);
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {

auto sinfo_output = sinfo_outputs[i_output];
auto tinfo_output = sinfo_output.as<TensorStructInfoNode>();

if (!tinfo_output || !tinfo_output->shape.defined() || tinfo_output->IsUnknownDtype()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "The shape of output 0 must match input "
<< attrs->inplace_indices[0].IntValue() << ", whereas we have "
<< out_sinfo->shape.value() << " in output 0 versus "
<< input_sinfo->shape.value() << " in input "
<< attrs->inplace_indices[0].IntValue());
<< "The output struct info for an in-place mutation must be a tensor "
<< "with a defined shape and dtype, "
<< "but output " << i_output << " has struct info " << sinfo_output);
}
} else {
auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
if (attrs->inplace_indices[i].IntValue() == -1) {
continue;
}
auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[i].IntValue()]);
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "The shape of output " << i << " must match that of input "
<< attrs->inplace_indices[i].IntValue() << ", whereas we have "
<< out_sinfo->shape.value() << " in output " << i << " versus "
<< input_sinfo->shape.value() << " in input "
<< attrs->inplace_indices[i].IntValue());
}

auto sinfo_input = GetStructInfo(call_args->fields[i_input]);
auto tinfo_input = sinfo_input.as<TensorStructInfoNode>();

if (!tinfo_input ||
(tinfo_output->IsUnknownDtype() || tinfo_output->dtype != tinfo_input->dtype) ||
(!tinfo_input->shape.defined() ||
!CanProveShapeEqual(tinfo_input->shape.value(), tinfo_output->shape.value(),
ctx->GetAnalyzer()))) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "The input used for an in-place mutation must be "
<< "a tensor with identical shape and dtype as the output. "
<< "However, output " << i_output << " with struct info " << sinfo_output
<< " is specified as an in-place mutation of input " << i_input
<< " with struct info " << sinfo_input);
}
}

Expand Down
197 changes: 163 additions & 34 deletions tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tvm import relax

import tvm.script
from tvm.script import tir as T, relax as R
from tvm.script import ir as I, tir as T, relax as R


def test_to_non_dataflow():
Expand Down Expand Up @@ -446,45 +446,174 @@ def foo(
tvm.ir.assert_structural_equal(Expected["foo"], new_mod["foo"], map_free_vars=True)


@pytest.mark.xfail()
def test_call_tir_inplace_repeated_input():
@tvm.script.ir_module
class Input:
@T.prim_func
def func(
A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
):
T.evaluate(0)
with pytest.raises(tvm.error.DiagnosticError):

@tvm.script.ir_module
class Input:
@T.prim_func
def func(
A: T.Buffer((2, 3), "int32"),
B: T.Buffer((2, 3), "int32"),
C: T.Buffer((2, 3), "int32"),
):
T.evaluate(0)

@R.function
def foo(
x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32"), z: R.Tensor((2, 3), "int32")
) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
R.func_attr({"relax.force_pure": True})
gv0 = R.call_tir_inplace(
Input.func,
(x, y, z),
# repeated 0 -> that's an error
[0, 0],
[R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
)
return gv0
@R.function
def foo(
x: R.Tensor((2, 3), "int32"),
y: R.Tensor((2, 3), "int32"),
z: R.Tensor((2, 3), "int32"),
) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
R.func_attr({"relax.force_pure": True})
gv0 = R.call_tir_inplace(
Input.func,
(x, y, z),
# repeated 0 -> that's an error
[0, 0],
[R.Tensor((2, 3), dtype="int32"), R.Tensor((2, 3), dtype="int32")],
)
return gv0


@pytest.mark.xfail()
def test_call_tir_inplace_all_new():
@tvm.script.ir_module
class Input:
@T.prim_func
def func(A: T.Buffer((2, 3), "int32")):
T.evaluate(0)
with pytest.raises(tvm.error.DiagnosticError):

@R.function
def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
R.func_attr({"relax.force_pure": True})
# cannot make the only output a fresh one
gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
return gv0
@tvm.script.ir_module
class Input:
@T.prim_func
def func(A: T.Buffer((2, 3), "int32")):
T.evaluate(0)

@R.function
def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
R.func_attr({"relax.force_pure": True})
# cannot make the only output a fresh one
gv0 = R.call_tir_inplace(Input.func, x, -1, R.Tensor((2, 3), dtype="int32"))
return gv0


def test_inplace_mutation_with_tuple_argument_raises_error():
"""TIR PrimFuncs do not support Tuple arguments
The `R.call_tir_inplace` operator must receive an in-line tuple of
arguments, where each argument in the tuple may be expressed in
TIR. Here, `[[A]]` specifies a tuple of arguments, where the
first argument is itself a tuple. Since PrimFuncs do not support
Tuple arguments, this is invalid.
This is a regression test. In previous implementations, this
triggered a segfault rather than raising an exception.
"""
with pytest.raises(tvm.error.DiagnosticError):

@I.ir_module
class Module:
@R.function
def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
cls = Module
gv1 = R.call_tir_inplace(
cls.multiply_by_two,
[[A]],
out_sinfo=R.Tensor((16,), dtype="float32"),
inplace_indices=[0],
)
return gv1

@T.prim_func(private=True)
def multiply_by_two(A: T.Buffer((16,), "float32")):
for i in range(16):
A[i] = A[i] * T.float32(2)


def test_inplace_mutation_with_non_tensor_argument_raises_error():
"""In-place argument must be a tensor
The `R.call_tir_inplace` operator must receive an in-line tuple of
arguments, where each argument in the tuple may be expressed in
TIR. Here, the argument `A` is not a tensor.
This is a regression test. In previous implementations, this
triggered a segfault rather than raising an exception.
"""
with pytest.raises(tvm.error.DiagnosticError):

@I.ir_module
class Module:
@R.function
def main(A: R.Object):
gv1 = R.call_tir_inplace(
Module.multiply_by_two,
[A],
out_sinfo=R.Tensor((16,), dtype="float32"),
inplace_indices=[0],
)
return gv1

@T.prim_func(private=True)
def multiply_by_two(A: T.Buffer((16,), "float32")):
for i in range(16):
A[i] = A[i] * T.float32(2)


def test_inplace_mutation_with_incompatible_tensor_shape_raises_error():
"""In-place argument must have compatible shape
The `R.call_tir_inplace` operator must receive an in-line tuple of
arguments, where the shape of each in-place argument is compatible
with the corresponding output. Here, the shape of argument `A` is
different than the output's shape (`[32]` as opposed to `[16]`).
"""
with pytest.raises(tvm.error.DiagnosticError):

@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([32], dtype="float32")):
gv1 = R.call_tir_inplace(
Module.multiply_by_two,
[A],
out_sinfo=R.Tensor((16,), dtype="float32"),
inplace_indices=[0],
)
return gv1

@T.prim_func(private=True)
def multiply_by_two(A: T.Buffer((16,), "float32")):
for i in range(16):
A[i] = A[i] * T.float32(2)


def test_inplace_mutation_with_incompatible_tensor_dtype_raises_error():
"""In-place argument must have compatible dtype
The `R.call_tir_inplace` operator must receive an in-line tuple of
arguments, where the shape of each in-place argument is compatible
with the corresponding output. Here, the dtype of argument `A` is
different than the output's dtype (`int32` as opposed to `float32`).
"""
with pytest.raises(tvm.error.DiagnosticError):

@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16], dtype="int32")):
gv1 = R.call_tir_inplace(
Module.multiply_by_two,
[A],
out_sinfo=R.Tensor((16,), dtype="float32"),
inplace_indices=[0],
)
return gv1

@T.prim_func(private=True)
def multiply_by_two(A: T.Buffer((16,), "float32")):
for i in range(16):
A[i] = A[i] * T.float32(2)


if __name__ == "__main__":
Expand Down

0 comments on commit 591cf1e

Please sign in to comment.