Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SME][TOPI] Add conv2d NHWC SME fp16->fp32 schedule #17048

Merged
merged 3 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tvm import relay, topi, tir
from tvm.tir.schedule.analysis import has_block
from tvm.dlight.gpu.matmul import auto_inline_consumers

from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
Expand Down Expand Up @@ -255,9 +256,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
if is_aarch64 and data.dtype in ["float32", "float16"]:
if (
target.features.has_sme
and data.dtype in ["float32"]
and kernel.dtype in ["float32"]
and out_type.dtype in ["float32"]
and kernel.dtype == data.dtype
and out_type.dtype == "float32"
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
Expand Down Expand Up @@ -536,6 +536,7 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
"""conv2d_winograd_without_weight_transform arm cpu strategy"""
layout = attrs.data_layout
data = inputs[0]
kernel = inputs[1]
strategy = _op.OpStrategy()
is_aarch64 = target.features.is_aarch64
has_dot_prod = target.features.has_dotprod
Expand Down Expand Up @@ -583,11 +584,25 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
)
elif data.dtype in ["float32", "float16"]:
# Non-quantized cases
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
)
if (
target.features.has_sme
and kernel.dtype == "float16"
and data.dtype == "float16"
and out_type.dtype == "float32"
):
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_SME_transposed_B),
lambda: None,
name="conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu",
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
)
else:
strategy.add_implementation(
wrap_compute_conv2d_gemm(
topi.arm_cpu.compute_conv2d_NHWC_hybrid_without_transform
),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_hybrid_without_transform),
name="conv2d_NHWC_hybrid_without_transform.arm_cpu",
)
else:
raise RuntimeError(
f"Unsupported conv2d_NHWC_without_transform layout {layout}"
Expand Down Expand Up @@ -819,6 +834,8 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
return True
elif has_block(sch, "conv2d_gemm_output"):
conv2d_block = sch.get_block("conv2d_gemm_output")
auto_inline_consumers(sch, conv2d_block)
topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
return True

Expand Down
22 changes: 12 additions & 10 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
tile_M = 4
tile_K = 16
elif use_sme:
tile_M = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
tile_M = 2 * tvm.tir.get_vscale_expr(in_dtype)
if in_dtype == "float16":
tile_K = tvm.tir.get_vscale_expr(in_dtype)
else:
tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
else:
# In non-SME, non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
Expand Down Expand Up @@ -139,17 +142,16 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
tile_N = 4
tile_K = 16
elif use_sme:
tile_N = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
# In non-SME, non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
tile_N = 2 * tvm.tir.get_vscale_expr(in_dtype)
if in_dtype == "float16":
# Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
tile_N = 32 * tvm.tir.vscale()
tile_K = tvm.tir.get_vscale_expr(in_dtype)
else:
# Each load from B' contains 16 * vscale elements (i.e. 16 * vscale columns from B)
tile_N = 16 * tvm.tir.vscale()
tile_K = 2 * tvm.tir.get_vscale_expr(in_dtype)
# In non-SME, non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
# Each load from B' contains 4 * scalable vectors (i.e. 4 * SVL columns from B)
# We are loading 4 rows from B', in the dimension of reduction (i.e. 4 rows from B)
tile_N = 4 * tvm.tir.get_vscale_expr(in_dtype)
tile_K = 4
elif in_dtype == "float16" and target.features.has_fp16_simd:
# Each load from B' contains 32 elements (i.e. 32 columns from B)
Expand Down
81 changes: 73 additions & 8 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block
from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name

from ..utils import traverse_inline, get_const_tuple
from .. import nn
Expand Down Expand Up @@ -680,6 +681,43 @@ def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation
)


@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu")
def compute_conv2d_NHWC_SME_transposed_B(
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
kernel_size,
output_channels,
):
"""Compute conv2d NHWC hybrid SME transposed B"""
N, K = get_const_tuple(kernel.shape)
tile_N, tile_K = get_tiling_B_transformed(False, data.dtype, True, True)
pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K)

kernel = tvm.topi.nn.pad(
kernel, pad_before=(0, 0), pad_after=(pad_N, pad_K), name="weight_padding"
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
)

return compute_conv2d_gemm_without_weight_transform(
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
kernel_size,
output_channels,
interleave_A=False,
use_scalable_vectors=True,
use_sme=True,
)


def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"""
Perform TIR scheduling for conv2d NHWC.
Expand All @@ -688,7 +726,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
primfunc = sch.mod["main"]
buffer_names = primfunc.params
buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
dtype = buffer_list[0].dtype
in_dtype = buffer_list[0].dtype
out_dtype = "float32"

# Determine PrimFunc blocks
block_list = [
Expand All @@ -698,6 +737,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"A_padded_K",
"A_padded_M",
"weight_flatten",
"weight_padding",
"weight_transpose",
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
"C",
"conv2d_gemm_output",
]
Expand All @@ -716,8 +757,8 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
M_padded = sch.get(m).extent
N_padded = sch.get(n).extent
K_padded = sch.get(k).extent
tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme)
tile_M, tile_K = get_tiling_A(False, in_dtype, use_sme)
tile_N, _ = get_tiling_B_transformed(False, in_dtype, use_scalable_vectors, use_sme)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
tile_K = T.cast(tile_K, K_padded.dtype)
Expand All @@ -729,12 +770,15 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
# pylint: disable=import-outside-toplevel
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.tir.tensor_intrin.arm_cpu import (
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
)

transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
in_dtype, out_dtype
)

# Interleave the padded im2col matrix utilizing the matrix tile
interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
Expand All @@ -743,24 +787,40 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE)
sch.tensorize(ki, transpose_interleave_intrin_name)

# Interleave the padded weights matrix utilizing the matrix tile
if in_dtype == "float16":
interleave_b_block = sch.cache_read(gemm_block, 1, "global")
sch.transform_layout(interleave_b_block, ("write", 0), lambda n, k: (k, n))
n, k = sch.get_loops(interleave_b_block)
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.reorder(ko, no, ki, ni)
sch.tensorize(ki, transpose_interleave_intrin_name)

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
tile_M, _ = get_tiling_A(False, out_dtype, True)
tile_N, _ = get_tiling_B_transformed(False, out_dtype, True, True)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.parallel(b)
sch.reorder(b, mo, no, mi, ni, k)

# Tensorize the GeMM output matrix initialization to zero
# Tensorize the GeMM initialization
init_block = sch.decompose_reduction(gemm_block, mi)
sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)

# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
sme_gemm_interleaved_intrin_name = (
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype),
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
Expand Down Expand Up @@ -878,6 +938,11 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
weight_flatten_block = func_blocks["weight_flatten"]
sch.compute_inline(weight_flatten_block)

# Weight transpose
if func_blocks["weight_transpose"] and func_blocks["weight_padding"]:
weight_padding_block = func_blocks["weight_padding"]
sch.compute_inline(weight_padding_block)

# Conv2d output block
output_block = func_blocks["conv2d_gemm_output"]
n, h, w, c = sch.get_loops(output_block)
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)

if (
topi_tmpl == "conv2d_NHWC_hybrid_SME.arm_cpu"
and data_dtype == "float16"
and kernel_dtype == "float16"
and out_dtype == "float32"
):
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, IC, OC = get_const_tuple(kernel.shape)
K = KH * KW * IC
N = OC
transposed_kernel_expr = relay.transpose(inputs[1], axes=[3, 0, 1, 2])
transposed_flattened_kernel_expr = relay.reshape(transposed_kernel_expr, newshape=(N, K))
new_kernel_expr = transposed_flattened_kernel_expr
new_kernel = te.placeholder((N, K), kernel.dtype)
new_workload_name = "conv2d_NHWC_hybrid_SME_transposed_B.arm_cpu"
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
new_workload = autotvm.task.args_to_workload(
[data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
new_workload_name,
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv2d_gemm_without_weight_transform(
inputs[0], new_kernel_expr, **new_attrs
)

# Only microTVM does layout alteration for NHWC layout with real data types
if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
return None
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ def compute_conv2d_gemm_without_weight_transform(
tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
- tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
)
elif use_sme and in_dtype == "float16" and out_dtype == "float32":
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
lambda b, x, y: te.sum(
A[b, x, k].astype(out_dtype) * B_interleaved_t[y, k].astype(out_dtype),
axis=k,
),
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors or use_sme:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,12 @@ def conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors=Fa
kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padding"
)

if use_sme or use_scalable_vectors:
if use_sme and kernel.dtype == "float16":
return te.compute(
(N_padded, K_padded), lambda x, y: kernel_flat[y, x], name="weight_transpose"
)

if use_scalable_vectors or use_sme:
return kernel_flat

if kernel.dtype in ["int8", "uint8"]:
Expand Down
Loading
Loading