Commit e755e43

Fix tests and rebase

Change-Id: Iaddeb046bdecb0352a067174f6e6e4be335e94fd
lhutton1 committed Jun 13, 2024
1 parent 6db06bc commit e755e43
Showing 3 changed files with 12 additions and 10 deletions.
15 changes: 7 additions & 8 deletions python/tvm/topi/arm_cpu/conv2d.py
@@ -24,7 +24,6 @@
 from tvm.script import tir as T
 import tvm.contrib.nnpack
 from tvm.tir.schedule.analysis import has_block
-from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
 
 from ..utils import traverse_inline, get_const_tuple
 from .. import nn
@@ -776,10 +775,6 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         get_transpose_interleave_intrin_name,
     )
 
-    transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
-        in_dtype, out_dtype
-    )
-
     # Interleave the padded im2col matrix utilizing the matrix tile
     interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
     sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
@@ -788,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
     ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
     sch.parallel(b)
     sch.reorder(b, ko, mo, ki, mi)
-    sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))
+    sch.tensorize(
+        ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
+    )
 
     # Interleave the padded weights matrix utilizing the matrix tile
     if in_dtype == "float16":
@@ -798,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
         ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
         no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
         sch.reorder(ko, no, ki, ni)
-        sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))
+        sch.tensorize(
+            ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
+        )
 
     # Split and reorder the loops of the GeMM for tensorization
     b, m, n, k = sch.get_loops(gemm_block)
@@ -821,7 +820,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
     )
     tvm.tir.TensorIntrin.register(
         sme_gemm_interleaved_intrin_name,
-        *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
+        *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, K_padded, in_dtype),
         override=True,
     )
     sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
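
For readers unfamiliar with the schedule primitives in the hunks above, here is a minimal, self-contained sketch of the same cache_read / transform_layout / split / reorder pattern on a plain GEMM. The shapes, the tile size of 16, and all names are illustrative stand-ins, not code from this commit: the real schedule derives its tiles from the scalable vector length, and the final sch.tensorize step is omitted because it requires a registered intrinsic such as the SME ones referenced above.

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def gemm(
    a: T.Buffer((64, 64), "float32"),
    b: T.Buffer((64, 64), "float32"),
    c: T.Buffer((64, 64), "float32"),
):
    for m, n, k in T.grid(64, 64, 64):
        with T.block("gemm"):
            vm, vn, vk = T.axis.remap("SSR", [m, n, k])
            T.reads(a[vm, vk], b[vk, vn])
            T.writes(c[vm, vn])
            with T.init():
                c[vm, vn] = T.float32(0)
            c[vm, vn] = c[vm, vn] + a[vm, vk] * b[vk, vn]


sch = tvm.tir.Schedule(gemm)
gemm_block = sch.get_block("gemm")

# Stage input `a` through its own buffer; read index 0 is `a` per T.reads.
a_copy = sch.cache_read(gemm_block, 0, "global")
# Write the staged copy transposed, (m, k) -> (k, m), by analogy with the
# interleave blocks in the schedule above.
sch.transform_layout(a_copy, ("write", 0), lambda m, k: (k, m))

# Tile the copy loops; 16 stands in for the SVL-derived tile_M / tile_K.
m, k = sch.get_loops(a_copy)
mo, mi = sch.split(m, factors=(None, 16))
ko, ki = sch.split(k, factors=(None, 16))
sch.reorder(ko, mo, ki, mi)

print(sch.mod.script())
```

Printing the module after these steps shows the staged copy written in transposed order with tiled loops, which is the state the commit's transpose-interleave intrinsic is then tensorized into.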
4 changes: 2 additions & 2 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -143,15 +143,15 @@ def compute_conv2d_gemm_without_weight_transform(
     N_padded = N + pad_N
 
     pad_before = (0, 0, 0)
-    pad_after = (0, 0, 0) if use_sme else (0, pad_M, pad_K)
+    pad_after = (0, pad_M, pad_K)
 
     if pad_K != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
     elif pad_M != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
 
     idxm = tvm.tir.indexmod
-    k = te.reduce_axis((0, K), "k")
+    k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")
 
     # Determine matrix multiplication compute definition
     target = Target.current(allow_none=False)
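
The two changed lines above make the padding of A unconditional and pick the reduction extent based on use_explicit_predication. A small sketch of that pad-then-reduce idea in plain TE, with made-up sizes and names (M, K, N, tile_K here are not from this commit):

```python
import tvm
from tvm import te, topi

# Made-up sizes: K is not a multiple of the tile, so the K axis gets padded.
M, K, N = 32, 30, 32
tile_K = 8
K_padded = ((K + tile_K - 1) // tile_K) * tile_K  # 32
pad_K = K_padded - K  # 2

A = te.placeholder((M, K), name="A", dtype="float32")
B = te.placeholder((K_padded, N), name="B", dtype="float32")

# Zero-pad A along K so a reduction running to K_padded reads only valid data.
A_padded = topi.nn.pad(A, pad_before=(0, 0), pad_after=(0, pad_K), name="A_padded_K")

# With explicit predication the reduction keeps its true extent K and the tail
# is masked by the generated code; otherwise it runs to K_padded over the
# zeros, mirroring the changed reduce_axis line above.
use_explicit_predication = False
k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")
C = te.compute(
    (M, N),
    lambda m, n: te.sum(A_padded[m, k] * B[k, n], axis=k),
    name="C",
)
```

Because the padding is zeros, the extra iterations up to K_padded contribute nothing to the sum, so both extents compute the same result.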
3 changes: 3 additions & 0 deletions tests/python/topi/test_topi_conv2d_nhwc.py
@@ -176,6 +176,9 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
     if target.features.has_sme and a_np.shape[0] > 1:
         pytest.skip(f"Conv2d with batches > 1 targeting SME not implemented.")
 
+    if target.features.has_sme and (a_np.shape[3] * w_np.shape[0] * w_np.shape[1]) <= 1:
+        pytest.skip(f"Conv2d with unit reduction dimension targeting SME not supported.")
+
     # SME schedule always outputs float32 results, regardless of input dtype.
     # Otherwise, output dtype is the same as input dtype.
     out_dtype = "float32" if target.features.has_sme else dtype
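
A short worked example of what the new skip condition measures, assuming the test's usual NHWC activations and HWIO weights (the shapes below are hypothetical, not from the commit): the im2col GEMM reduces over in_channels * kernel_h * kernel_w, which collapses to 1 for a 1x1 kernel with a single input channel.

```python
import numpy as np

# NHWC activations: (batch, height, width, in_channels);
# HWIO weights: (kernel_h, kernel_w, in_channels, out_channels).
a_np = np.zeros((1, 14, 14, 1), dtype="float32")
w_np = np.zeros((1, 1, 1, 16), dtype="float32")

# Reduction dimension of the im2col GEMM.
K = a_np.shape[3] * w_np.shape[0] * w_np.shape[1]
assert K == 1  # exactly the case the new check skips for SME targets
```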
