From a26b99fe7020f3e990a248329e7b1462ce48dc7e Mon Sep 17 00:00:00 2001 From: Andrei Hutu Date: Wed, 5 Jun 2024 10:24:09 +0000 Subject: [PATCH] Add comments --- python/tvm/relay/op/strategy/arm_cpu.py | 6 +++++- python/tvm/topi/arm_cpu/conv2d_alter_op.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index b196285f6bd0..35fd2b7a78d7 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -582,8 +582,12 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) + # Non-quantized cases elif data.dtype in ["float32", "float16"]: - # Non-quantized cases + # The SME schedule for float16->float32 prearranges the two matrices to be multiplied + # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects + # the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)). + # This means that the flattened weights matrix B needs to be transposed to (N, K). if ( target.features.has_sme and kernel.dtype == "float16" diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index ed398f80e6ef..2476cb92b915 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -172,6 +172,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): KH, KW, IC, OC = get_const_tuple(kernel.shape) K = KH * KW * IC N = OC + # The SME schedule for float16->float32 prearranges the two matrices to be multiplied + # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects + # the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)). + # This means that the flattened weights matrix B needs to be transposed to (N, K). transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2]) transposed_flattened_kernel_expr = relay.reshape(transposed_kernel_expr, newshape=(N, K)) new_kernel_expr = transposed_flattened_kernel_expr