[SME] Utilize predication in fp32 matmul and conv2d schedules
Prior to this commit, the matmul and conv2d schedules required padding
of the inputs to some multiple of vscale and a final "unpadding" stage.

Instead, we can leverage predicated operations to avoid the
requirement for padding. Both the transpose interleave and outer
product fp32 intrinsics are updated to use predication. The
`get_active_lane_mask` intrinsic is utilized to generate a variably
sized mask of active lanes depending on the global position the tensor
intrinsic is operating on.

For now this relies on using `offset_of` and `stride` information from
the tensor we're predicating an access on. Likely we will want to
build on this in the future with a more intuitive API for determining
the current tile location.

Support for batched conv2d was removed since it causes numerical
issues, which are suspected to be due to how the current tile is
determined (see the paragraph above).
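
To illustrate the mechanism this commit relies on (an explanatory sketch, not
part of the commit itself): `get_active_lane_mask(base, limit)` produces a
per-lane boolean vector in which lane i is active iff base + i < limit, so
lanes that would fall outside the valid extent of a tile row are masked off
instead of being written to padding. A minimal plain-Python emulation with
hypothetical names:

    # Plain-Python sketch of a predicated row store -- not TVM/TIR code.
    def get_active_lane_mask(base, limit, lanes):
        # Lane i is active while base + i is still below 'limit'.
        return [base + i < limit for i in range(lanes)]

    def predicated_store(dst, base, values, mask):
        # Only lanes whose predicate is True touch memory, so no padding
        # (or final "unpadding" stage) is needed.
        for i, (value, active) in enumerate(zip(values, mask)):
            if active:
                dst[base + i] = value

    row = [0.0] * 10                       # one 10-element row of the output
    mask = get_active_lane_mask(6, 10, 8)  # 8-lane vector, only 4 lanes in bounds
    predicated_store(row, 6, [1.0] * 8, mask)
    # Elements 6..9 are written; nothing is stored past the end of the row.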

Change-Id: I79620200c9a94e2ca9d7297c4ed2abf87549cc41
lhutton1 committed Jun 13, 2024
1 parent 5618628 commit 6db06bc
Showing 7 changed files with 189 additions and 89 deletions.
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -110,6 +110,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d arm cpu strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
data_shape = data.shape
kernel_shape = kernel.shape
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
@@ -258,6 +260,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
target.features.has_sme
and kernel.dtype == data.dtype
and out_type.dtype == "float32"
and data_shape[0] == 1
# The schedule uses tensorization which does not work when the
# reduction axis of the gemm has unit iters. See
# https://github.com/apache/tvm/issues/16566
and (data_shape[3] * kernel_shape[0] * kernel_shape[1]) > 1
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
134 changes: 110 additions & 24 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -176,7 +176,51 @@ def _create_ptrue_mask(dtype):
return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))


def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
def _create_active_lane_mask(tensor, relative_offsets, vertical_limit):
"""
Get the active lane mask intrinsic call for predicated accesses.
Parameters
----------
tensor : tvm.tir.Buffer
The tensor the buffer access will be performed on.
relative_offsets : Tuple[PrimExpr, PrimExpr]
The vertical and horizontal offsets into the accumulator tile.
vertical_limit : PrimExpr
An absolute offset specifying the limit at which rows should be stored.
Returns
-------
PrimExpr
The active lane mask intrinsic.
"""
vertical_offset, horizontal_offset = relative_offsets
stride = tensor.strides[0]

# The base is the offset of the first value we wish to store
base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0])

# The limit is the maximum offset in the current row of 'base' that we wish to allow values
# to be stored. Calculating this limit is a bit tricky since we can only request offsets of
# elements in the tensorized tile of the output tensor. One way to calculate this is to find
# the offset of the first value in the row of the output tensor that 'base' is in and add
# 'stride' to it.
limit = (
base
- T.int32(horizontal_offset)
- T.int32((tensor.offset_of([0, 0])[0] % stride))
+ T.int32(stride)
)
limit = T.Min(limit, T.Cast("int32", vertical_limit) * stride)

return T.get_active_lane_mask(
"uint1xvscalex4",
T.Cast("int32", base),
T.Cast("int32", limit),
)
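
# Worked example of the 'limit' arithmetic above (illustration only, with
# assumed numbers): for a row-major tensor with stride = 10 whose tensorized
# tile starts at flat offset 24 (i.e. row 2, column 4), an access at relative
# offsets (vertical=1, horizontal=0) gives
#   base  = 24 + 1 * 10 + 0         = 34
#   limit = 34 - 0 - (24 % 10) + 10 = 40   (one past the end of the row holding 34)
#   limit = min(40, vertical_limit * 10)   (clamped to the last valid row)
# so the mask enables lanes 0..5 only -- the six in-bounds elements of that row.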


def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
"""
Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
the Scalable Matrix Extension (SME).
@@ -247,9 +291,6 @@ def impl():
strides=[T.int32(), 1],
)

# Disable predication
ptrue = _create_ptrue_mask("float32")

with T.block("root"):
T.reads(A[0:SVF2, 0:SVF2])
T.writes(A_t[0:SVF2, 0:SVF2])
@@ -263,19 +304,22 @@

input_ptr = A.access_ptr("r", offset=offset)
sub_tile = T.int32(sub_tile_idx)
predicate = _create_active_lane_mask(
A, (row_offset + slice_idx, col_offset), cols
)
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1w.horiz",
T.uint32(4),
ptrue,
predicate,
input_ptr,
sub_tile,
slice_idx,
)
)

# Store columns to the ouptut matrix
# Store columns to the output matrix
with T.serial(0, SVF) as slice_idx:
for sub_tile_idx in range(0, sub_tile_count):
col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
@@ -284,12 +328,15 @@

output_ptr = A_t.access_ptr("w", offset=offset)
sub_tile = T.int32(sub_tile_idx)
predicate = _create_active_lane_mask(
A_t, (row_offset + slice_idx, col_offset), rows
)
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
ptrue,
predicate,
output_ptr,
sub_tile,
slice_idx,
@@ -445,7 +492,24 @@ def impl():
return desc, impl()


def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, extent_rows):
if in_dtype == "float32" and out_dtype == "float32":
sme_transpose_interleave_intrin_name = (
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + f"_{extent_cols}_{extent_rows}"
)
tir.TensorIntrin.register(
sme_transpose_interleave_intrin_name,
*get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, extent_rows),
override=True,
)
return sme_transpose_interleave_intrin_name
elif in_dtype == "float16" and out_dtype == "float32":
return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
else:
raise ValueError("Input/output data type combination not supported.")
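
# Illustrative usage (not part of this diff): the returned name refers to a
# registered TensorIntrin and can be passed straight to Schedule.tensorize,
# as the conv2d/matmul schedules do below. The variable names are assumptions.
#
#   intrin_name = get_transpose_interleave_intrin_name(
#       in_dtype, out_dtype, M_padded, K_padded
#   )
#   sch.tensorize(ki, intrin_name)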


def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype):
"""
Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
outer product operations from the Scalable Matrix Extension (SME).
@@ -579,15 +643,39 @@ def impl():
k_row = k * rows_per_iter
in_dtype_svf = tir.get_vscale_expr(in_dtype)

a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)])
b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)])

# Ideally we'd rely on predicating the loads and use the same predicate
# for the outer product operation. However, support for predicated
# buffers is not currently supported by multiple lowering passes such as
# "LowerMatchBuffer", therefore the predicate is passed directly to the
# outer product operation for now.
if in_dtype == "float32":
a_high = T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
b_high = T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
a_low = (
T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
_create_active_lane_mask(A, (k_row, 0), K),
)
b_low = (
T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
_create_active_lane_mask(B, (k_row, 0), K),
)
a_high = (
T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
_create_active_lane_mask(A, (k_row, in_dtype_svf), K),
)
b_high = (
T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
_create_active_lane_mask(B, (k_row, in_dtype_svf), K),
)
else:
a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
a_low = (T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
b_low = (T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
a_high = (
T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
ptrue,
)
b_high = (
T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
ptrue,
)

input_combinations = [
(a_low, b_low),
@@ -606,10 +694,10 @@
fmopa_intrin,
T.uint32(5),
sub_tile,
ptrue,
ptrue,
input_1,
input_2,
input_1[1],
input_2[1],
input_1[0],
input_2[0],
)
)

@@ -626,7 +714,9 @@
"void",
"llvm.aarch64.sme.st1w.horiz",
T.uint32(4),
_create_ptrue_mask("float32"),
_create_active_lane_mask(
C, (vert_offset + slice_idx, horiz_offset), M
),
output_ptr,
T.int32(sub_tile_idx),
T.int32(slice_idx),
@@ -691,10 +781,6 @@ def impl(c: T.handle) -> None:
# in versions of LLVM >= 15. Installations with older versions of LLVM will
# not be able to use them.
if llvm_version_major() >= 15:
TensorIntrin.register(
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
*get_sme_transpose_interleave_2svlx2svl_fp32_intrin(),
)
TensorIntrin.register(
ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
*get_sme_transpose_interleave_block2_2svl_fp16_intrin(),
29 changes: 16 additions & 13 deletions python/tvm/topi/arm_cpu/conv2d.py
@@ -773,6 +773,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
get_transpose_interleave_intrin_name,
)

transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
@@ -787,7 +788,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, transpose_interleave_intrin_name)
sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))

# Interleave the padded weights matrix utilizing the matrix tile
if in_dtype == "float16":
Expand All @@ -797,7 +798,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.reorder(ko, no, ki, ni)
sch.tensorize(ki, transpose_interleave_intrin_name)
sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
@@ -816,7 +817,7 @@

# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = (
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{M_padded}_{K_padded}_{in_dtype}"
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
@@ -922,16 +923,18 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
reshape_block = func_blocks["T_reshape"]
A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
if use_sme:
sch.compute_inline(reshape_block)
elif A_pad_block:
sch.compute_inline(reshape_block)
b, m, k = sch.get_loops(A_pad_block)
_, k_inner = sch.split(k, [None, tile_N])
sch.vectorize(k_inner)
sch.compute_at(A_pad_block, mi)
else:
sch.compute_at(reshape_block, mi)
use_explicit_predication = use_sme and in_dtype == "float32"
if not use_explicit_predication:
if use_sme:
sch.compute_inline(reshape_block)
elif A_pad_block:
sch.compute_inline(reshape_block)
b, m, k = sch.get_loops(A_pad_block)
_, k_inner = sch.split(k, [None, tile_N])
sch.vectorize(k_inner)
sch.compute_at(A_pad_block, mi)
else:
sch.compute_at(reshape_block, mi)

# Weight flattening
if func_blocks["weight_flatten"]:
39 changes: 26 additions & 13 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -133,23 +133,25 @@ def compute_conv2d_gemm_without_weight_transform(
)

# Pad to tiles (if necessary)
pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
use_explicit_predication = use_sme and in_dtype == "float32"
if not use_explicit_predication:
pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)

M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N + pad_N
M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N + pad_N

pad_before = (0, 0, 0)
pad_after = (0, pad_M, pad_K)
pad_before = (0, 0, 0)
pad_after = (0, 0, 0) if use_sme else (0, pad_M, pad_K)

if pad_K != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
elif pad_M != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
if pad_K != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
elif pad_M != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")

idxm = tvm.tir.indexmod
k = te.reduce_axis((0, K_padded), "k")
k = te.reduce_axis((0, K), "k")

# Determine matrix multiplication compute definition
target = Target.current(allow_none=False)
@@ -300,7 +302,18 @@ def compute_conv2d_gemm_without_weight_transform(
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors or use_sme:
elif use_explicit_predication:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M, N),
lambda b, x, y: te.sum(
A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype),
axis=k,
),
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
Expand Down