Add fp16->fp32 conv2d strategy selection tests
Anndrey24 committed Jun 4, 2024
1 parent fc28a8b commit aa5e927
Showing 1 changed file with 52 additions and 8 deletions: tests/python/relay/strategy/test_select_implementation.py
@@ -58,7 +58,7 @@ def test_concatenate(target, expected_implementation):
     assert impl.name == expected_implementation


-def _get_conv2d_impl(dtype, target):
+def _get_conv2d_impl(in_dtype, out_dtype, target):
     """Returns selected conv2d implementation for a given datatype and target"""
     data_shape = (1, 1, 1, 4)
     weight_shape = (1, 1, 4, 4)
@@ -68,21 +68,24 @@ def _get_conv2d_impl(dtype, target):
     kernel_size = (1, 1)

     out = relay.nn.conv2d(
-        relay.var("data", shape=data_shape, dtype=dtype),
-        relay.var("weight", shape=weight_shape, dtype=dtype),
+        relay.var("data", shape=data_shape, dtype=in_dtype),
+        relay.var("weight", shape=weight_shape, dtype=in_dtype),
         kernel_size=kernel_size,
         channels=channels,
         data_layout=data_layout,
         kernel_layout=kernel_layout,
-        out_dtype=dtype,
+        out_dtype=out_dtype,
     )

     with target:
         out = run_opt_pass(out, relay.transform.AlterOpLayout())
+        data_shape = out.type_args[0].shape
+        weight_shape = out.type_args[1].shape

         impl, _ = relay.backend.te_compiler.select_implementation(
             out.op,
             out.attrs,
-            [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
+            [te.placeholder(data_shape, in_dtype), te.placeholder(weight_shape, in_dtype)],
             out.checked_type,
             target,
             use_autotvm=False,
@@ -131,7 +134,7 @@ def test_int8_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "int8"

-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl


@@ -171,7 +174,7 @@ def test_fp32_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float32"

-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl


@@ -211,7 +214,7 @@ def test_fp16_conv2d(target, expected_impl):
     target = tvm.target.Target(target)
     dtype = "float16"

-    selected_impl = _get_conv2d_impl(dtype, target)
+    selected_impl = _get_conv2d_impl(dtype, dtype, target)
     assert selected_impl == expected_impl


+@pytest.mark.skipif(
+    llvm_version_major() < 15, reason=f"Requires LLVM 15+, got {llvm_version_major()}"
+)
+@pytest.mark.parametrize(
+    "target,expected_impl",
+    [
+        (
+            "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
+            "conv2d_nhwc_spatial_pack.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v9a",
+            "conv2d_NHWC_hybrid_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme",
+            "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
+        ),
+    ],
+)
+def test_fp16_to_fp32_conv2d(target, expected_impl):
+    target = tvm.target.Target(target)
+    in_dtype = "float16"
+    out_dtype = "float32"
+
+    selected_impl = _get_conv2d_impl(in_dtype, out_dtype, target)
+    assert selected_impl == expected_impl
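
For reference, a minimal sketch of how the updated helper signature is exercised (hypothetical standalone usage, not part of the commit; assumes the helper and imports from tests/python/relay/strategy/test_select_implementation.py are in scope and TVM is built against LLVM 15+):

    import tvm

    # Ask the arm_cpu conv2d strategy which schedule it selects for an fp16
    # convolution that accumulates into fp32 on an AArch64 target.
    target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon")
    impl = _get_conv2d_impl("float16", "float32", target)
    assert impl == "conv2d_NHWC_hybrid_without_transform.arm_cpu"

Calling the helper with the same dtype twice reproduces the existing int8, fp32, and fp16 cases, which is why the earlier tests now pass _get_conv2d_impl(dtype, dtype, target).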


