diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 12f19462f704..35fd2b7a78d7 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -23,6 +23,7 @@ from tvm import relay, topi, tir from tvm.tir.schedule.analysis import has_block +from tvm.dlight.gpu.matmul import auto_inline_consumers from ....auto_scheduler import is_auto_scheduler_enabled from ....meta_schedule import is_meta_schedule_enabled @@ -255,9 +256,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if is_aarch64 and data.dtype in ["float32", "float16"]: if ( target.features.has_sme - and data.dtype in ["float32"] - and kernel.dtype in ["float32"] - and out_type.dtype in ["float32"] + and kernel.dtype == data.dtype + and out_type.dtype == "float32" ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME), @@ -536,6 +536,7 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ """conv2d_winograd_without_weight_transform arm cpu strategy""" layout = attrs.data_layout data = inputs[0] + kernel = inputs[1] strategy = _op.OpStrategy() is_aarch64 = target.features.is_aarch64 has_dot_prod = target.features.has_dotprod @@ -581,13 +582,31 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ wrap_topi_schedule(interleaved_schedule), name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", ) + # Non-quantized cases elif data.dtype in ["float32", "float16"]: - # Non-quantized cases - strategy.add_implementation( - wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform), - name="conv2d_NHWC_hybrid_without_transform.arm_cpu", - ) + # The SME schedule for float16->float32 prearranges the two matrices to be multiplied + # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects + # the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)). + # This means that the flattened weights matrix B needs to be transposed to (N, K). + if ( + target.features.has_sme + and kernel.dtype == "float16" + and data.dtype == "float16" + and out_type.dtype == "float32" + ): + strategy.add_implementation( + wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_SME_transposed_B), + lambda: None, + name="conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d_gemm( + topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform + ), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform), + name="conv2d_NHWC_hybrid_without_transform.arm_cpu", + ) else: raise RuntimeError( f"Unsupported conv2d_NHWC_without_transform layout {layout}" @@ -819,6 +838,8 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool: topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch) return True elif has_block(sch, "conv2d_gemm_output"): + conv2d_block = sch.get_block("conv2d_gemm_output") + auto_inline_consumers(sch, conv2d_block) topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch) return True diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py index 5c4b3c045661..f690b2273112 100644 --- a/python/tvm/topi/arm_cpu/arm_utils.py +++ b/python/tvm/topi/arm_cpu/arm_utils.py @@ -68,8 +68,11 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False): tile_M = 4 tile_K = 16 elif use_sme: - tile_M = 2 * 4 * tvm.tir.vscale() - tile_K = 2 * 4 * tvm.tir.vscale() + tile_M = 2 * tvm.tir.get_vscale_expr(in_dtype) + if in_dtype == "float16": + tile_K = tvm.tir.get_vscale_expr(in_dtype) + else: + tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype) else: # In non-SME, non-quantized cases, A is not interleaved. # We are loading 4 rows from A. @@ -139,17 +142,16 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False, tile_N = 4 tile_K = 16 elif use_sme: - tile_N = 2 * 4 * tvm.tir.vscale() - tile_K = 2 * 4 * tvm.tir.vscale() - # In non-SME, non-quantized cases, A is not interleaved. - elif use_scalable_vectors: + tile_N = 2 * tvm.tir.get_vscale_expr(in_dtype) if in_dtype == "float16": - # Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B) - tile_N = 32 * tvm.tir.vscale() + tile_K = tvm.tir.get_vscale_expr(in_dtype) else: - # Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B) - tile_N = 16 * tvm.tir.vscale() + tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype) + # In non-SME, non-quantized cases, A is not interleaved. + elif use_scalable_vectors: + # Each load from B' contains 4 * scalable vectors (i.e. 4 * SVL columns from B) # We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B) + tile_N = 4 * tvm.tir.get_vscale_expr(in_dtype) tile_K = 4 elif in_dtype == "float16" and target.features.has_fp16_simd: # Each load from B' contains 32 elements (i.e. 32 columns from B) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index d0fe251e7e23..a6c951c07830 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -24,6 +24,7 @@ from tvm.script import tir as T import tvm.contrib.nnpack from tvm.tir.schedule.analysis import has_block +from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name from ..utils import traverse_inline, get_const_tuple from .. import nn @@ -680,6 +681,43 @@ def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation ) +@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu") +def compute_conv2d_NHWC_SME_transposed_B( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, +): + """Compute conv2d NHWC hybrid SME transposed B""" + N, K = get_const_tuple(kernel.shape) + tile_N, tile_K = get_tiling_B_transformed(False, data.dtype, True, True) + pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K) + + kernel = tvm.topi.nn.pad( + kernel, pad_before=(0, 0), pad_after=(pad_N, pad_K), name="weight_padding" + ) + + return compute_conv2d_gemm_without_weight_transform( + cfg, + data, + kernel, + strides, + padding, + dilation, + out_dtype, + kernel_size, + output_channels, + interleave_A=False, + use_scalable_vectors=True, + use_sme=True, + ) + + def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): """ Perform TIR scheduling for conv2d NHWC. @@ -688,7 +726,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): primfunc = sch.mod["main"] buffer_names = primfunc.params buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names] - dtype = buffer_list[0].dtype + in_dtype = buffer_list[0].dtype + out_dtype = "float32" # Determine PrimFunc blocks block_list = [ @@ -698,6 +737,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): "A_padded_K", "A_padded_M", "weight_flatten", + "weight_padding", + "weight_transpose", "C", "conv2d_gemm_output", ] @@ -716,8 +757,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): M_padded = sch.get(m).extent N_padded = sch.get(n).extent K_padded = sch.get(k).extent - tile_M, tile_K = get_tiling_A(False, dtype, use_sme) - tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme) + tile_M, tile_K = get_tiling_A(False, in_dtype, use_sme) + tile_N, _ = get_tiling_B_transformed(False, in_dtype, use_scalable_vectors, use_sme) tile_M = T.cast(tile_M, M_padded.dtype) tile_N = T.cast(tile_N, N_padded.dtype) tile_K = T.cast(tile_K, K_padded.dtype) @@ -729,12 +770,15 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): # pylint: disable=import-outside-toplevel from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, ) + transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name( + in_dtype, out_dtype + ) + # Interleave the padded im2col matrix utilizing the matrix tile interleave_t_A_block = sch.cache_read(gemm_block, 0, "global") sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m)) @@ -743,24 +787,40 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) sch.parallel(b) sch.reorder(b, ko, mo, ki, mi) - sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE) + sch.tensorize(ki, transpose_interleave_intrin_name) + + # Interleave the padded weights matrix utilizing the matrix tile + if in_dtype == "float16": + interleave_b_block = sch.cache_read(gemm_block, 1, "global") + sch.transform_layout(interleave_b_block, ("write", 0), lambda n, k: (k, n)) + n, k = sch.get_loops(interleave_b_block) + ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) + no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True) + sch.reorder(ko, no, ki, ni) + sch.tensorize(ki, transpose_interleave_intrin_name) # Split and reorder the loops of the GeMM for tensorization b, m, n, k = sch.get_loops(gemm_block) + tile_M, _ = get_tiling_A(False, out_dtype, True) + tile_N, _ = get_tiling_B_transformed(False, out_dtype, True, True) + tile_M = T.cast(tile_M, M_padded.dtype) + tile_N = T.cast(tile_N, N_padded.dtype) mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True) no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True) sch.parallel(b) sch.reorder(b, mo, no, mi, ni, k) - # Tensorize the GeMM output matrix initialization to zero + # Tensorize the GeMM initialization init_block = sch.decompose_reduction(gemm_block, mi) sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT) # Tensorize the GeMM update - sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}" + sme_gemm_interleaved_intrin_name = ( + ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}" + ) tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype), override=True, ) sch.tensorize(mi, sme_gemm_interleaved_intrin_name) @@ -878,6 +938,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): weight_flatten_block = func_blocks["weight_flatten"] sch.compute_inline(weight_flatten_block) + # Weight transpose + if func_blocks["weight_transpose"] and func_blocks["weight_padding"]: + weight_padding_block = func_blocks["weight_padding"] + sch.compute_inline(weight_padding_block) + # Conv2d output block output_block = func_blocks["conv2d_gemm_output"] n, h, w, c = sch.get_loops(output_block) diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index fe4569ceb1ad..2476cb92b915 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -162,6 +162,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): inputs[0], new_kernel_expr, **new_attrs ) + if ( + topi_tmpl == "conv2d_NHWC_hybrid_SME.arm_cpu" + and data_dtype == "float16" + and kernel_dtype == "float16" + and out_dtype == "float32" + ): + assert data_layout == "NHWC" and kernel_layout == "HWIO" + KH, KW, IC, OC = get_const_tuple(kernel.shape) + K = KH * KW * IC + N = OC + # The SME schedule for float16->float32 prearranges the two matrices to be multiplied + # using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects + # the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)). + # This means that the flattened weights matrix B needs to be transposed to (N, K). + transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2]) + transposed_flattened_kernel_expr = relay.reshape(transposed_kernel_expr, newshape=(N, K)) + new_kernel_expr = transposed_flattened_kernel_expr + new_kernel = te.placeholder((N, K), kernel.dtype) + new_workload_name = "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu" + new_workload = autotvm.task.args_to_workload( + [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC], + new_workload_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_gemm_without_weight_transform( + inputs[0], new_kernel_expr, **new_attrs + ) + # Only microTVM does layout alteration for NHWC layout with real data types if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]: return None diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index 0c3908bb7017..e637aa91e5b4 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -289,6 +289,17 @@ def compute_conv2d_gemm_without_weight_transform( tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1] ) + elif use_sme and in_dtype == "float16" and out_dtype == "float32": + assert len(B_interleaved_t.shape) == 2 + C = te.compute( + (batches, M_padded, N_padded), + lambda b, x, y: te.sum( + A[b, x, k].astype(out_dtype) * B_interleaved_t[y, k].astype(out_dtype), + axis=k, + ), + name="C", + ) + zero = tvm.tir.const(0) elif use_scalable_vectors or use_sme: assert len(B_interleaved_t.shape) == 2 C = te.compute( diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 8d61c622504b..205730ff22d6 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -654,7 +654,12 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding" ) - if use_sme or use_scalable_vectors: + if use_sme and kernel.dtype == "float16": + return te.compute( + (N_padded, K_padded), lambda x, y: kernel_flat[y, x], name="weight_transpose" + ) + + if use_scalable_vectors or use_sme: return kernel_flat if kernel.dtype in ["int8", "uint8"]: diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py index 2708094afb08..f4fa250ecfe0 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py @@ -120,7 +120,8 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu") -dtype = tvm.testing.parameter("float32") +in_dtype = tvm.testing.parameter("float16", "float32") +out_dtype = tvm.testing.parameter("float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( # Pad M, N, K @@ -154,30 +155,35 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): @tvm.testing.fixture() -def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation): +def ref_data( + in_dtype, out_dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation +): np.random.seed(0) in_height = in_width = in_size a_shape = (batch, in_height, in_width, in_channel) w_shape = (kernel, kernel, in_channel, num_filter) - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - return a_np, w_np + a_np = np.random.uniform(size=a_shape).astype(in_dtype) + w_np = np.random.uniform(size=w_shape).astype(in_dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = tvm.topi.testing.conv2d_nhwc_python( + a_np.astype(out_dtype), dw_np.astype(out_dtype), stride, padding + ).astype(out_dtype) + return a_np, w_np, dw_np, b_np @pytest.mark.skipif( llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" ) @tvm.testing.requires_aprofile_aem_fvp -def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): - a_np, w_np = ref_data - dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) +def test_conv2d_sme(target, ref_data, in_dtype, out_dtype, stride, padding, dilation): + a_np, w_np, dw_np, b_np = ref_data kernel_size = get_const_tuple(w_np.shape[:2]) out_channels = w_np.shape[3] - x = relay.var("data", shape=a_np.shape, dtype=dtype) - weight = relay.const(w_np, dtype=dtype) + x = relay.var("data", shape=a_np.shape, dtype=in_dtype) + weight = relay.const(w_np, dtype=in_dtype) conv2d = relay.nn.conv2d( x, weight, @@ -188,7 +194,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): padding=get_pad_tuple(padding, dw_np.shape[:2]), data_layout="NHWC", kernel_layout="HWIO", - out_dtype=dtype, + out_dtype=out_dtype, ) func = relay.Function(relay.analysis.free_vars(conv2d), conv2d) @@ -198,7 +204,7 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): inputs = {"data": a_np} params = {} - ref_outputs = generate_ref_data(ir_mod, inputs, params) + ref_outputs = {"output": b_np} target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme") runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True}) @@ -220,9 +226,12 @@ def test_conv2d_fp32(target, ref_data, dtype, stride, padding, dilation): runtime=runtime, params=params, ) - generated_func = executor_factory.lowered_ir_mods.items()[0][1][ - "tvmgen_default_fused_nn_conv2d" - ] + + if in_dtype == "float16": + func_name = "tvmgen_default_fused_nn_contrib_conv2d_gemm_without_weight_transform" + else: + func_name = "tvmgen_default_fused_nn_conv2d" + generated_func = executor_factory.lowered_ir_mods.items()[0][1][func_name] extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4) test_model = AOTTestModel( diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py index 01a914e793c1..b95bd4072af8 100644 --- a/tests/python/relay/strategy/test_select_implementation.py +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -58,7 +58,7 @@ def test_concatenate(target, expected_implementation): assert impl.name == expected_implementation -def _get_conv2d_impl(dtype, target): +def _get_conv2d_impl(in_dtype, out_dtype, target): """Returns selected conv2d implementation for a given datatype and target""" data_shape = (1, 1, 1, 4) weight_shape = (1, 1, 4, 4) @@ -68,21 +68,24 @@ def _get_conv2d_impl(dtype, target): kernel_size = (1, 1) out = relay.nn.conv2d( - relay.var("data", shape=data_shape, dtype=dtype), - relay.var("weight", shape=weight_shape, dtype=dtype), + relay.var("data", shape=data_shape, dtype=in_dtype), + relay.var("weight", shape=weight_shape, dtype=in_dtype), kernel_size=kernel_size, channels=channels, data_layout=data_layout, kernel_layout=kernel_layout, - out_dtype=dtype, + out_dtype=out_dtype, ) with target: out = run_opt_pass(out, relay.transform.AlterOpLayout()) + data_shape = out.type_args[0].shape + weight_shape = out.type_args[1].shape + impl, _ = relay.backend.te_compiler.select_implementation( out.op, out.attrs, - [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], + [te.placeholder(data_shape, in_dtype), te.placeholder(weight_shape, in_dtype)], out.checked_type, target, use_autotvm=False, @@ -131,7 +134,7 @@ def test_int8_conv2d(target, expected_impl): target = tvm.target.Target(target) dtype = "int8" - selected_impl = _get_conv2d_impl(dtype, target) + selected_impl = _get_conv2d_impl(dtype, dtype, target) assert selected_impl == expected_impl @@ -171,7 +174,7 @@ def test_fp32_conv2d(target, expected_impl): target = tvm.target.Target(target) dtype = "float32" - selected_impl = _get_conv2d_impl(dtype, target) + selected_impl = _get_conv2d_impl(dtype, dtype, target) assert selected_impl == expected_impl @@ -211,7 +214,48 @@ def test_fp16_conv2d(target, expected_impl): target = tvm.target.Target(target) dtype = "float16" - selected_impl = _get_conv2d_impl(dtype, target) + selected_impl = _get_conv2d_impl(dtype, dtype, target) + assert selected_impl == expected_impl + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason=f"Requires LLVM 15+, got {llvm_version_major()}" +) +@pytest.mark.parametrize( + "target,expected_impl", + [ + ( + "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon", + "conv2d_nhwc_spatial_pack.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a", + "conv2d_NHWC_hybrid_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme", + "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu", + ), + ], +) +def test_fp16_to_fp32_conv2d(target, expected_impl): + target = tvm.target.Target(target) + in_dtype = "float16" + out_dtype = "float32" + + selected_impl = _get_conv2d_impl(in_dtype, out_dtype, target) assert selected_impl == expected_impl diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 02f16b59c00b..d46db1b28b37 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -68,7 +68,7 @@ False, ), ( - "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16", topi.arm_cpu.compute_conv2d_NHWC_hybrid, topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR, True, @@ -173,13 +173,14 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): if target.features.has_sme and llvm_version_major() < 16: pytest.skip(f"LLVM {llvm_version_major()} does not support targetting SME.") - if target.features.has_sme and dtype == "float16": - pytest.skip(f"Conv2d fp16 targetting SME not implemented.") + # SME schedule always outputs float32 results, regardless of input dtype. + # Otherwise, output dtype is the same as input dtype. + out_dtype = "float32" if target.features.has_sme else dtype with target: a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev) - B = compute(A, W, stride, padding, dilation, dtype) + B = compute(A, W, stride, padding, dilation, out_dtype) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) if use_tir_schedule: primfunc = te.create_prim_func([A, W, B]) @@ -190,22 +191,22 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): func = tvm.build(s, [A, W, B], target) # Run only on AArch64 devices - # Do not run SVE schedules on non-SVE devices + # Do not run SVE/SME schedules on non-SVE/SME devices build_only = ( platform.machine() != "aarch64" - or (target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check()) or ( dtype == "float16" and target.features.has_fp16_simd and not tvm.testing.requires_arm_fp16.run_time_check() ) + or (target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check()) or (target.features.has_sme and not tvm.testing.requires_aarch64_sme.run_time_check()) ) if build_only: return func(a, w, b) - tol = get_tolerance(dtype, w_np, b_np) + tol = get_tolerance(out_dtype, w_np, b_np) tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"])