diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 88a6496859a5..d640e9fa073e 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -181,7 +181,11 @@ Map GetIOPoolAllocations( } static Integer CalculateExtentsSize(const DataType& dtype, const Array& extents) { - size_t element_size_bytes = dtype.bytes(); + if (dtype.is_scalable_vector()) { + // We cannot statically calculate workspace for scalable types + return Integer(); + } + size_t element_size_bytes = dtype.bytes() * dtype.lanes(); size_t num_elements = 1; for (const auto& ext : extents) { if (ext->IsInstance()) { diff --git a/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py b/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py index 12c892a04b07..00e9df9d8b34 100644 --- a/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py +++ b/tests/python/tir-analysis/test_tir_analysis_calculate_workspace.py @@ -91,6 +91,18 @@ def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handl # fmt: on +@T.prim_func +def prim_func_decl_vector_type(a: T.handle, b: T.handle): + T.func_attr({"tir.noalias": True}) + A = T.match_buffer(a, (4,), "float32x4") + B = T.match_buffer(b, (4,), "float32x4") + C = T.decl_buffer((4,), "float32x4") + for i in range(3): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + C[vi] + + @pytest.mark.parametrize("alignment,size,consts", [(1, 663552, 0), (10, 663560, 0)]) def test_global_allocates(alignment, size, consts): primfunc = primfunc_global_allocates @@ -105,6 +117,11 @@ def test_local_allocates(alignment, size, consts): assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, alignment) == size +def test_vector_type(): + primfunc = prim_func_decl_vector_type + assert tvm.tir.analysis.calculate_workspace_bytes(primfunc, 1) == 64 + + if __name__ == "__main__": test_global_allocates() test_local_allocates()