From aa5e927631f6a7b0a1e41e20c92a0f0f9c3d770c Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Tue, 4 Jun 2024 15:08:01 +0000
Subject: [PATCH] Add fp16->fp32 conv2d strategy selection tests

---
 .../strategy/test_select_implementation.py   | 60 ++++++++++++++++---
 1 file changed, 52 insertions(+), 8 deletions(-)

diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 01a914e793c1..b95bd4072af8 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -58,7 +58,7 @@ def test_concatenate(target, expected_implementation):
     assert impl.name == expected_implementation
 
 
-def _get_conv2d_impl(dtype, target):
+def _get_conv2d_impl(in_dtype, out_dtype, target):
     """Returns selected conv2d implementation for a given datatype and target"""
     data_shape = (1, 1, 1, 4)
     weight_shape = (1, 1, 4, 4)
@@ -68,21 +68,24 @@ def _get_conv2d_impl(dtype, target):
     kernel_size = (1, 1)
 
     out = relay.nn.conv2d(
-        relay.var("data", shape=data_shape, dtype=dtype),
-        relay.var("weight", shape=weight_shape, dtype=dtype),
+        relay.var("data", shape=data_shape, dtype=in_dtype),
+        relay.var("weight", shape=weight_shape, dtype=in_dtype),
         kernel_size=kernel_size,
         channels=channels,
         data_layout=data_layout,
         kernel_layout=kernel_layout,
-        out_dtype=dtype,
+        out_dtype=out_dtype,
     )
 
     with target:
         out = run_opt_pass(out, relay.transform.AlterOpLayout())
+    data_shape = out.type_args[0].shape
+    weight_shape = out.type_args[1].shape
+
     impl, _ = relay.backend.te_compiler.select_implementation(
         out.op,
         out.attrs,
-        [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
+        [te.placeholder(data_shape, in_dtype), te.placeholder(weight_shape, in_dtype)],
         out.checked_type,
         target,
         use_autotvm=False,
@@ -131,7 +134,7 @@ def test_int8_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "int8"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl
 
 
@@ -171,7 +174,7 @@ def test_fp32_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float32"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl
 
 
@@ -211,7 +214,48 @@ def test_fp16_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float16"
 
-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 15, reason=f"Requires LLVM 15+, got {llvm_version_major()}"
+)
+@pytest.mark.parametrize(
+    "target,expected_impl",
+    [
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
+            "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+        ),
+    ],
+)
+def test_fp16_to_fp32_conv2d(target, expected_impl):
+    target = tvm.target.Target(target)
+    in_dtype = "float16"
+    out_dtype = "float32"
+
+    selected_impl = _get_conv2d_impl(in_dtype, out_dtype, target)
+    assert selected_impl == expected_impl
 
 
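Note: the new parametrized cases can be run in isolation with pytest node
selection:

    pytest tests/python/relay/strategy/test_select_implementation.py::test_fp16_to_fp32_conv2d

For reference, each case reduces to a call like the one below. This is a
minimal sketch, assuming it runs inside the test module where
_get_conv2d_impl is defined; the expected implementation name is taken from
the parametrization above:

    import tvm

    # fp16 input/weights accumulated into an fp32 output on an Armv8.2-a AArch64 target
    target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon")
    selected = _get_conv2d_impl("float16", "float32", target)
    assert selected == "conv2d_NHWC_hybrid_without_transform.arm_cpu"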