Skip to content

Commit

Permalink
[TIR] Validate tir::Buffer axis_separators on construction (#17219)
Browse files Browse the repository at this point in the history
* [TIR] Validate tir::Buffer axis_separators on construction

Prior to this commit, the `axis_separators` field of a TIR buffer
wasn't validated until the `tir.FlattenBuffer` legalization pass.
Delaying the error until this point makes it difficult to determine
where it invalid `axis_separators` were initially defined.

This commit updates the `tir::Buffer` constructor to validate the
`axis_separators` field immediately, allowing these invalid values to
be caught on construction.

Closes #17215

* Update metaschedule primitive to only set axis_separators of alloc

* Allow axis separators to be increasing, rather than strictly increasing
  • Loading branch information
Lunderberg authored Aug 5, 2024
1 parent cd09ab6 commit bd7f1f8
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 25 deletions.
45 changes: 30 additions & 15 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,24 +334,37 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index,
return offsets;
}

Buffer Buffer::GetFlattenedBuffer() const {
auto self = operator->();

static void ValidateAxisSeparators(const Array<IntImm>& axis_separators, size_t buffer_dim) {
// These checks ensure that all output axes contain at least one
// input axis.
for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
auto sep = self->axis_separators[i]->value;
auto next_sep = self->axis_separators[i + 1]->value;
ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
}
if (self->axis_separators.size()) {
auto first_sep = self->axis_separators[0]->value;
ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
<< "so that first output axis contains at least one input axis";
auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
ICHECK_LT(last_sep, self->shape.size())
<< "Last output axis must contain at least one input axis.";
for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
auto sep = axis_separators[i]->value;
auto next_sep = axis_separators[i + 1]->value;
CHECK_LE(sep, next_sep) << "ValueError: "
<< "Axis separators must be in increasing order, "
<< "but axis_separators[" << i << "] = " << sep
<< " is greater than or equal to axis_separators[" << (i + 1)
<< "] = " << next_sep << ".";
}
if (axis_separators.size()) {
auto first_sep = axis_separators[0]->value;
CHECK_GE(first_sep, 0) << "ValueError: "
<< "All axis separators must be non-negative. "
<< "However, the axis_separators[0] = " << first_sep;
auto last_sep = axis_separators[axis_separators.size() - 1]->value;
CHECK_LE(last_sep, buffer_dim)
<< "ValueError: "
<< "All axis separators must be within the range "
<< "0 <= sep <= buffer_dim. "
<< "However, the last axis_separators[" << (axis_separators.size() - 1)
<< "] = " << last_sep << " is greater than the buffer's dimensionality of " << buffer_dim;
}
}

Buffer Buffer::GetFlattenedBuffer() const {
auto self = operator->();

ValidateAxisSeparators(self->axis_separators, self->shape.size());

Array<PrimExpr> output_shape;
if (self->strides.size()) {
Expand Down Expand Up @@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr>
ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
<< "Variable " << data->name_hint << " does not point to a primitive.";

ValidateAxisSeparators(axis_separators, shape.size());

auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
Expand Down
15 changes: 10 additions & 5 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
if (it != buffer_var_map_.end()) {
const Buffer& new_source_buffer = it->second;
Buffer new_target_buffer = match_buffer->buffer;
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
LOG(WARNING)
<< "Target buffer in match_buffer doesn't have the same dimensionality as its source "
"buffer. `axis_separators` for the target buffer might be incorrect.";

if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) {
new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
} else {
new_target_buffer.CopyOnWrite()->axis_separators =
Array<IntImm>(new_source_buffer->axis_separators.size(), IntImm(DataType::Int(32), 0));
LOG(WARNING) << "Buffer view " << new_target_buffer
<< " has different dimensionality than backing buffer " << new_source_buffer
<< ". The `axis_separators` for " << new_target_buffer << "."
<< "`axis_separators` for the view might be incorrect.";
}
buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
return MatchBufferRegion(new_target_buffer,
Expand Down
12 changes: 9 additions & 3 deletions tests/python/tir-base/test_tir_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod():
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))

def assert_simplified_equal(index_simplified, index_direct):
tvm.ir.assert_structural_equal(
index_simplified, index_direct
), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct)
(
tvm.ir.assert_structural_equal(index_simplified, index_direct),
"index_simplified=%s, index_direct=%s" % (index_simplified, index_direct),
)

idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
Expand Down Expand Up @@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators():
tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32])


def test_invalid_axis_separators_raises_exception():
with pytest.raises(ValueError):
tvm.tir.decl_buffer([1], axis_separators=[1, 2])


if __name__ == "__main__":
tvm.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer((128, 128), "flo
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
B_subregion0[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[0])
C[vi, vj] = B_subregion1[()] + T.float32(1)


Expand Down

0 comments on commit bd7f1f8

Please sign in to comment.