Skip to content

Commit

Permalink
[SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule (#17048)
Browse files Browse the repository at this point in the history
This commit extends the SME conv2d NHWC schedule to support convolutions with float16 inputs (data and kernel) and a float32 output using the tensor intrinsics added in #16981.
  • Loading branch information
Anndrey24 authored Jun 5, 2024
1 parent 2a62c72 commit 4b82974
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 58 deletions.
39 changes: 30 additions & 9 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tvm import relay, topi, tir
from tvm.tir.schedule.analysis import has_block
from tvm.dlight.gpu.matmul import auto_inline_consumers

from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
Expand Down Expand Up @@ -255,9 +256,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
if is_aarch64 and data.dtype in ["float32", "float16"]:
if (
target.features.has_sme
and data.dtype in ["float32"]
and kernel.dtype in ["float32"]
and out_type.dtype in ["float32"]
and kernel.dtype == data.dtype
and out_type.dtype == "float32"
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
Expand Down Expand Up @@ -536,6 +536,7 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
"""conv2d_winograd_without_weight_transform arm cpu strategy"""
layout = attrs.data_layout
data = inputs[0]
kernel = inputs[1]
strategy = _op.OpStrategy()
is_aarch64 = target.features.is_aarch64
has_dot_prod = target.features.has_dotprod
Expand Down Expand Up @@ -581,13 +582,31 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
wrap_topi_schedule(interleaved_schedule),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
# Non-quantized cases
elif data.dtype in ["float32", "float16"]:
# Non-quantized cases
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
)
# The SME schedule for float16->float32 prearranges the two matrices to be multiplied
# using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects
# the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)).
# This means that the flattened weights matrix B needs to be transposed to (N, K).
if (
target.features.has_sme
and kernel.dtype == "float16"
and data.dtype == "float16"
and out_type.dtype == "float32"
):
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_SME_transposed_B),
lambda: None,
name="conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
)
else:
strategy.add_implementation(
wrap_compute_conv2d_gemm(
topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform
),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
)
else:
raise RuntimeError(
f"Unsupported conv2d_NHWC_without_transform layout {layout}"
Expand Down Expand Up @@ -819,6 +838,8 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
return True
elif has_block(sch, "conv2d_gemm_output"):
conv2d_block = sch.get_block("conv2d_gemm_output")
auto_inline_consumers(sch, conv2d_block)
topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
return True

Expand Down
22 changes: 12 additions & 10 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
tile_M = 4
tile_K = 16
elif use_sme:
tile_M = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
tile_M = 2 * tvm.tir.get_vscale_expr(in_dtype)
if in_dtype == "float16":
tile_K = tvm.tir.get_vscale_expr(in_dtype)
else:
tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
else:
# In non-SME, non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
Expand Down Expand Up @@ -139,17 +142,16 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
tile_N = 4
tile_K = 16
elif use_sme:
tile_N = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
# In non-SME, non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
tile_N = 2 * tvm.tir.get_vscale_expr(in_dtype)
if in_dtype == "float16":
# Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
tile_N = 32 * tvm.tir.vscale()
tile_K = tvm.tir.get_vscale_expr(in_dtype)
else:
# Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B)
tile_N = 16 * tvm.tir.vscale()
tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
# In non-SME, non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
# Each load from B' contains 4 * scalable vectors (i.e. 4 * SVL columns from B)
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
tile_N = 4 * tvm.tir.get_vscale_expr(in_dtype)
tile_K = 4
elif in_dtype == "float16" and target.features.has_fp16_simd:
# Each load from B' contains 32 elements (i.e. 32 columns from B)
Expand Down
81 changes: 73 additions & 8 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block
from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name

from ..utils import traverse_inline, get_const_tuple
from .. import nn
Expand Down Expand Up @@ -680,6 +681,43 @@ def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation
)


@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu")
def compute_conv2d_NHWC_SME_transposed_B(
    cfg,
    data,
    kernel,
    strides,
    padding,
    dilation,
    out_dtype,
    kernel_size,
    output_channels,
):
    """Compute conv2d NHWC with SME from an already transposed, flattened kernel.

    The 2-D ``kernel`` is unpacked as ``(N, K)``, padded at the end of both
    axes with the amounts returned by ``get_conv2d_weights_padding`` for the
    B-matrix tile sizes, and then handed to
    ``compute_conv2d_gemm_without_weight_transform`` with A not interleaved
    and both scalable vectors and SME enabled.

    All remaining arguments are forwarded unchanged to
    ``compute_conv2d_gemm_without_weight_transform``.
    """
    rows, cols = get_const_tuple(kernel.shape)

    # Positional flags: interleave_A=False, then use_scalable_vectors and
    # use_sme both True.
    tiles = get_tiling_B_transformed(False, data.dtype, True, True)
    row_tile, col_tile = tiles

    paddings = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(
        rows, cols, row_tile, col_tile
    )
    extra_rows, extra_cols = paddings

    # Only trailing padding is added; the leading edge of each axis is untouched.
    padded_kernel = tvm.topi.nn.pad(
        kernel,
        pad_before=(0, 0),
        pad_after=(extra_rows, extra_cols),
        name="weight_padding",
    )

    gemm_flags = {
        "interleave_A": False,
        "use_scalable_vectors": True,
        "use_sme": True,
    }
    return compute_conv2d_gemm_without_weight_transform(
        cfg,
        data,
        padded_kernel,
        strides,
        padding,
        dilation,
        out_dtype,
        kernel_size,
        output_channels,
        **gemm_flags,
    )


def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"""
Perform TIR scheduling for conv2d NHWC.
Expand All @@ -688,7 +726,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
primfunc = sch.mod["main"]
buffer_names = primfunc.params
buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
dtype = buffer_list[0].dtype
in_dtype = buffer_list[0].dtype
out_dtype = "float32"

# Determine PrimFunc blocks
block_list = [
Expand All @@ -698,6 +737,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"A_padded_K",
"A_padded_M",
"weight_flatten",
"weight_padding",
"weight_transpose",
"C",
"conv2d_gemm_output",
]
Expand All @@ -716,8 +757,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
M_padded = sch.get(m).extent
N_padded = sch.get(n).extent
K_padded = sch.get(k).extent
tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme)
tile_M, tile_K = get_tiling_A(False, in_dtype, use_sme)
tile_N, _ = get_tiling_B_transformed(False, in_dtype, use_scalable_vectors, use_sme)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
tile_K = T.cast(tile_K, K_padded.dtype)
Expand All @@ -729,12 +770,15 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
# pylint: disable=import-outside-toplevel
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.tir.tensor_intrin.arm_cpu import (
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
)

transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
in_dtype, out_dtype
)

# Interleave the padded im2col matrix utilizing the matrix tile
interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
Expand All @@ -743,24 +787,40 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
sch.tensorize(ki, transpose_interleave_intrin_name)

# Interleave the padded weights matrix utilizing the matrix tile
if in_dtype == "float16":
interleave_b_block = sch.cache_read(gemm_block, 1, "global")
sch.transform_layout(interleave_b_block, ("write", 0), lambda n, k: (k, n))
n, k = sch.get_loops(interleave_b_block)
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.reorder(ko, no, ki, ni)
sch.tensorize(ki, transpose_interleave_intrin_name)

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
tile_M, _ = get_tiling_A(False, out_dtype, True)
tile_N, _ = get_tiling_B_transformed(False, out_dtype, True, True)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.parallel(b)
sch.reorder(b, mo, no, mi, ni, k)

# Tensorize the GeMM output matrix initialization to zero
# Tensorize the GeMM initialization
init_block = sch.decompose_reduction(gemm_block, mi)
sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)

# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
sme_gemm_interleaved_intrin_name = (
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
Expand Down Expand Up @@ -878,6 +938,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
weight_flatten_block = func_blocks["weight_flatten"]
sch.compute_inline(weight_flatten_block)

# Weight transpose
if func_blocks["weight_transpose"] and func_blocks["weight_padding"]:
weight_padding_block = func_blocks["weight_padding"]
sch.compute_inline(weight_padding_block)

# Conv2d output block
output_block = func_blocks["conv2d_gemm_output"]
n, h, w, c = sch.get_loops(output_block)
Expand Down
28 changes: 28 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)

if (
topi_tmpl == "conv2d_NHWC_hybrid_SME.arm_cpu"
and data_dtype == "float16"
and kernel_dtype == "float16"
and out_dtype == "float32"
):
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, IC, OC = get_const_tuple(kernel.shape)
K = KH * KW * IC
N = OC
# The SME schedule for float16->float32 prearranges the two matrices to be multiplied
# using the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic which expects
# the reduction axis K as the second dimension of the matrix (i.e. shape = (_, K)).
# This means that the flattened weights matrix B needs to be transposed to (N, K).
transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2])
transposed_flattened_kernel_expr = relay.reshape(transposed_kernel_expr, newshape=(N, K))
new_kernel_expr = transposed_flattened_kernel_expr
new_kernel = te.placeholder((N, K), kernel.dtype)
new_workload_name = "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu"
new_workload = autotvm.task.args_to_workload(
[data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
new_workload_name,
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv2d_gemm_without_weight_transform(
inputs[0], new_kernel_expr, **new_attrs
)

# Only microTVM does layout alteration for NHWC layout with real data types
if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
return None
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ def compute_conv2d_gemm_without_weight_transform(
tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
)
elif use_sme and in_dtype == "float16" and out_dtype == "float32":
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
lambda b, x, y: te.sum(
A[b, x, k].astype(out_dtype) * B_interleaved_t[y, k].astype(out_dtype),
axis=k,
),
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors or use_sme:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,12 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
)

if use_sme or use_scalable_vectors:
if use_sme and kernel.dtype == "float16":
return te.compute(
(N_padded, K_padded), lambda x, y: kernel_flat[y, x], name="weight_transpose"
)

if use_scalable_vectors or use_sme:
return kernel_flat

if kernel.dtype in ["int8", "uint8"]:
Expand Down
Loading

0 comments on commit 4b82974

Please sign in to comment.