Skip to content

Commit

Permalink
[AOT] Correctly calculate workspace for vector types
Browse files Browse the repository at this point in the history
When calculating the size of the workspace for a given prim func, the
lanes of the data type were not being considered, meaning that sizes
calculated for dtypes such as "float32x4" were smaller than they
should be. This commit also takes lanes into account in the calculation.

Change-Id: I23a1329ad3c7910784a046e7007a104676ad3664
  • Loading branch information
lhutton1 committed Jun 10, 2024
1 parent 5d077c5 commit ad9a613
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tir/usmp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ Map<String, PoolAllocation> GetIOPoolAllocations(
}

static Integer CalculateExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents) {
size_t element_size_bytes = dtype.bytes();
if (dtype.is_scalable_vector()) {
// We cannot statically calculate workspace for scalable types
return Integer();
}
size_t element_size_bytes = dtype.bytes() * dtype.lanes();
size_t num_elements = 1;
for (const auto& ext : extents) {
if (ext->IsInstance<IntImmNode>()) {
Expand Down
17 changes: 17 additions & 0 deletions tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl
# fmt: on


# TIR fixture used by test_vector_type: the two matched buffers (A, B) and the
# locally declared buffer C each hold 4 elements of the vector dtype "float32x4".
@T.prim_func
def prim_func_decl_vector_type(a: T.handle, b: T.handle):
T.func_attr({"tir.noalias": True})
# A and B are bound to the function's handle parameters.
A = T.match_buffer(a, (4,), "float32x4")
B = T.match_buffer(b, (4,), "float32x4")
# C is declared inside the function rather than bound to a parameter.
# Presumably this is the buffer counted as workspace (test_vector_type
# expects 64 bytes) -- TODO confirm against calculate_workspace_bytes.
C = T.decl_buffer((4,), "float32x4")
# NOTE(review): the loop covers indices 0..2 only although the buffers have
# 4 elements -- confirm this is intentional.
for i in range(3):
with T.block("block"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi] + C[vi]


@pytest.mark.parametrize("alignment,size,consts", [(1, 663552, 0), (10, 663560, 0)])
def test_global_allocates(alignment, size, consts):
primfunc = primfunc_global_allocates
Expand All @@ -105,6 +117,11 @@ def test_local_allocates(alignment, size, consts):
assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, alignment) == size


def test_vector_type():
    """Check the workspace size reported for a prim func using "float32x4" buffers.

    With a byte alignment of 1, the reported workspace is expected to be 64 bytes.
    """
    byte_alignment = 1
    expected_workspace_bytes = 64
    workspace_bytes = tvm.tir.analysis.calculate_workspace_bytes(
        prim_func_decl_vector_type, byte_alignment
    )
    assert workspace_bytes == expected_workspace_bytes


if __name__ == "__main__":
    # The tests in this file that take (alignment, size, consts) are driven by
    # pytest.mark.parametrize and cannot be called without arguments, so hand
    # the file to pytest instead of invoking the test functions directly. This
    # also runs every test in the file, including test_vector_type.
    raise SystemExit(pytest.main([__file__]))

0 comments on commit ad9a613

Please sign in to comment.