From c0abab769ff152d87f84963f18a98d2f7c9bdf31 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 24 Jun 2024 21:24:32 +0800 Subject: [PATCH] [TIR][DLight] Enable SimdGroup op for Metal (#17112) --- include/tvm/tir/builtin.h | 44 ++- python/tvm/dlight/gpu/matmul.py | 145 ++++++++ python/tvm/script/ir_builder/tir/ir.py | 8 + python/tvm/tir/__init__.py | 6 + python/tvm/tir/op.py | 191 +++++++++- python/tvm/tir/tensor_intrin/metal.py | 350 ++++++++++++++++++ src/runtime/thread_storage_scope.h | 7 + src/target/source/codegen_metal.cc | 82 +++- src/target/source/codegen_metal.h | 3 + src/tir/op/builtin.cc | 12 + .../dlight/test_gpu_matmul_tensorize.py | 283 +++++++++++++- 11 files changed, 1124 insertions(+), 7 deletions(-) create mode 100644 python/tvm/tir/tensor_intrin/metal.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 5836eb8ea93a..120c1b71be72 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers(); TVM_DLL const Op& mma_store(); /*! - * \brief tvm intrinsic for zero-initalizing an MMA accumulation registor. + * \brief tvm intrinsic for zero-initializing an MMA accumulation register. * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its * 4 accumulation registers. @@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store(); */ TVM_DLL const Op& mma_fill(); +// Metal SimdGroup matrix intrinsics + +/*! + * \brief tvm intrinsic for initializing a simdgroup matrix with a given value. + * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, + * keeping the similar interface with Metal Spec. + * + * void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value, + * int col = 8, int row = 8); + */ +TVM_DLL const Op& make_filled_simdgroup_matrix(); + +/*! 
+ * \brief tvm intrinsic for loading data from device memory or threadgroup memory to simdgroup. + * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, + * keeping the similar interface with Metal Spec. + * + * void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride, + int col = 8, int row = 8, bool transpose_matrix = false); + */ +TVM_DLL const Op& simdgroup_load(); + +/*! + * \brief tvm intrinsic for storing data from simdgroup to device memory or threadgroup memory. + * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, + * keeping the similar interface with Metal Spec. + * + * void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride, + * int col = 8, int row = 8, bool transpose_matrix = false); + */ +TVM_DLL const Op& simdgroup_store(); + +/*! + * \brief tvm intrinsic for multiply and accumulate two matrices in simdgroup + * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, + * keeping the similar interface with Metal Spec. + * + * void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a, + * Var b, PrimExpr index_b, Var c, PrimExpr index_c); + */ +TVM_DLL const Op& simdgroup_multiply_accumulate(); + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index f4ef1f50448b..a5759941caf5 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int: return int(sm_version) if sm_version.isdigit() else -1 +class MetalMatmul(GPUScheduleRule): + """ + The schedule rule for Metal matmul computation. 
+ """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.metal import ( # pylint: disable=import-outside-toplevel + get_simdgroup_intrin_group, + ) + + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Step 0. Configs + block_size_x: int = 16 + block_size_y: int = 16 + block_size_k: int = 32 + micro_size: int = 8 + warp_size: int = 32 + ty_len: int = 1 + tz_len: int = 4 + vector_size: int = 4 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + ty_len * block_size_x, + tz_len * block_size_y, + block_size_k, + ], + ) + + # Step 3. 
Schedule matmul to use simdgroup intrinsics + batch, i, j, k = sch.get_loops(main_block) + bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size]) + by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size]) + k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size]) + sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2) + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + sch.bind(batch, "blockIdx.z") + sch.bind(ty, "threadIdx.y") + sch.bind(tz, "threadIdx.z") + + def fetch_to_shared(block, idx): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k0, preserve_unit_loops=True) + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size]) + + sch.bind(_tz, "threadIdx.z") + sch.bind(_ty, "threadIdx.y") + sch.bind(_tx, "threadIdx.x") + sch.vectorize(vec) + + return block_read + + a_g2s = fetch_to_shared(main_block, 0) + b_g2s = fetch_to_shared(main_block, 1) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to simdgroup fragments + A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup") + B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup") + sch.compute_at(A_simdgroup, k1) + sch.compute_at(B_simdgroup, k1) + + C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup") + C_s2g = sch.cache_write(C_simd2s, 0, "shared") + sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True) + sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True) + + intrin_group = get_simdgroup_intrin_group( + load_scope="shared", + store_scope="shared", + dtype="float16", + trans_a=False, + trans_b=True, + ) + sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i)) + + def tensorize_block(block: tir.schedule.BlockRV, intrin: str): + *_, i, j = sch.get_loops(block) + io, ii = 
sch.split(i, [None, micro_size]) + jo, ji = sch.split(j, [None, micro_size]) + sch.reorder(io, jo, ii, ji) + sch.tensorize(ii, intrin) + + C_init = sch.decompose_reduction(main_block, k0) + tensorize_block(A_simdgroup, intrin_group["load_a"]) + tensorize_block(B_simdgroup, intrin_group["load_b"]) + tensorize_block(C_simd2s, intrin_group["store"]) + tensorize_block(C_init, intrin_group["init"]) + + *_, i, j, k = sch.get_loops(main_block) + sch.tensorize(i, intrin_group["compute"]) + + auto_inline_consumer_chain(sch, C_s2g) + fused = sch.fuse(*sch.get_loops(C_s2g)[-2:]) + _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size]) + sch.bind(_tz, "threadIdx.z") + sch.bind(_ty, "threadIdx.y") + sch.bind(_tx, "threadIdx.x") + sch.vectorize(vec) + + return sch + + class MatmulTensorization(GPUScheduleRule): """ The schedule rule for float16 tensor core matmul computation. @@ -848,6 +988,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring tensorize_sch = MatmulTensorization().apply(func, target, _) if tensorize_sch is not None: return tensorize_sch + elif target.kind.name == "metal": + try: + return MetalMatmul().apply(func, target, _) + except: # pylint: disable=bare-except + pass # Step 2. Get schedule config. 
config = self.get_configs(target) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 18abc0ca5d01..caefc6a6bc16 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1887,6 +1887,10 @@ def wrapped(*args, **kwargs): ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier) ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx) ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier) +make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix) +simdgroup_load = _op_wrapper(_tir_op.simdgroup_load) +simdgroup_store = _op_wrapper(_tir_op.simdgroup_store) +simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate) create_barriers = _op_wrapper(_tir_op.create_barriers) assume = _op_wrapper(_tir_op.assume) undef = _op_wrapper(_tir_op.undef) @@ -2177,6 +2181,10 @@ def wrapped(*args, **kwargs): "ptx_arrive_barrier", "ptx_arrive_barrier_expect_tx", "ptx_wait_barrier", + "make_filled_simdgroup_matrix", + "simdgroup_load", + "simdgroup_store", + "simdgroup_multiply_accumulate", "create_barriers", "mma_store", "mma_fill", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 0fee976eb130..5360ab2b9697 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -73,6 +73,12 @@ ptx_wait_barrier, create_barriers, ) +from .op import ( + make_filled_simdgroup_matrix, + simdgroup_load, + simdgroup_multiply_accumulate, + simdgroup_store, +) from .op import vectorlow, vectorhigh, vectorcombine from .op import infinity, reinterpret from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 95a85ab77d36..81d6604259a3 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -14,7 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=redefined-builtin, invalid-name +# pylint: disable=redefined-builtin, invalid-name, too-many-arguments """Operators used in TIR expression.""" from typing import Any, Optional, Union @@ -1567,6 +1567,195 @@ def create_barriers(barrier_count): return call_intrin("", "tir.create_barriers", barrier_count) +def make_filled_simdgroup_matrix( + d: Var, + index: PrimExpr, + value: PrimExpr, + col: int = 8, + row: int = 8, +): + """Create a filled SIMDGroup matrix + + Parameters + ---------- + d : var + The simdgroup var + + index : PrimExpr + The index of the matrix. + + value : PrimExpr + The value to fill. + + col : int + The number of columns. + + row : int + The number of rows. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.make_filled_simdgroup_matrix", d, index, value, col, row) + + +def simdgroup_load( + d: Var, + index: PrimExpr, + ptr: PrimExpr, + stride: PrimExpr, + col: int = 8, + row: int = 8, + transpose_matrix: bool = False, +): + """Load data from device memory or threadgroup memory to simdgroup + + Parameters + ---------- + d : var + The simdgroup var + + index : PrimExpr + The index of the matrix. + + ptr : PrimExpr + The pointer. + + stride : PrimExpr + The stride. + + col : int + The number of columns. + + row : int + The number of rows. + + transpose_matrix : bool + Whether to transpose the matrix. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return call_intrin( + "handle", + "tir.simdgroup_load", + d, + index, + ptr, + stride, + col, + row, + transpose_matrix, + ) + + +def simdgroup_store( + d: PrimExpr, + index: PrimExpr, + ptr: PrimExpr, + stride: PrimExpr, + col: int = 8, + row: int = 8, + transpose_matrix: bool = False, +): + """Store data from simdgroup to device memory or threadgroup memory + + Parameters + ---------- + d : PrimExpr + The SIMDGroup. + + index : PrimExpr + The index of the matrix. + + ptr : PrimExpr + The pointer. + + stride : PrimExpr + The stride. + + col : int + The number of columns. + + row : int + The number of rows. + + + transpose_matrix : bool + Whether to transpose the matrix. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "handle", "tir.simdgroup_store", d, index, ptr, stride, col, row, transpose_matrix + ) + + +def simdgroup_multiply_accumulate( + d: Var, + index_d: PrimExpr, + a: Var, + index_a: PrimExpr, + b: Var, + index_b: PrimExpr, + c: Var, + index_c: PrimExpr, +): + """Multiply and accumulate two matrices in simdgroup + i.e. d = a * b + c + + Parameters + ---------- + d : Var + The destination matrix. + + index_d : PrimExpr + The index of the destination matrix. + + a : Var + The first matrix. + + index_a : PrimExpr + The index of the first matrix. + + b : Var + The second matrix. + + index_b : PrimExpr + The index of the second matrix. + + c : Var + The third matrix. + + index_c : PrimExpr + The index of the third matrix. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return call_intrin( + "handle", + "tir.simdgroup_multiply_accumulate", + d, + index_d, + a, + index_a, + b, + index_b, + c, + index_c, + ) + + def vectorlow(dtype, vec): """Get the low level half of the vector diff --git a/python/tvm/tir/tensor_intrin/metal.py b/python/tvm/tir/tensor_intrin/metal.py new file mode 100644 index 000000000000..be34a9e266c8 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/metal.py @@ -0,0 +1,350 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable +"""Intrinsics for tensorization on Apple GPU.""" +from typing import Dict, Literal, Tuple + +from tvm.script import tir as T +from tvm.tir import Buffer, PrimExpr, PrimFunc, TensorIntrin + +######## simdgroup matrix intrinsics ######## + + +def get_simdgroup_index(buffer: Buffer, stride: PrimExpr, col: int, row: int): + """Compute simdgroup index using elem_offset of the buffer""" + + # NOTE: The usage of `col` versus `row` needs further checking. + # Currently, Metal only supports 8x8, which means the values of `col` and `row` are the same + frag_index_m = buffer.elem_offset // stride // col + frag_index_n = buffer.elem_offset % stride // row + + num_fragments_per_row = stride // row + return frag_index_m * num_fragments_per_row + frag_index_n + + +def get_make_filled_simdgroup_matrix_intrin( + dtype: str, col: int = 8, row: int = 8 +) -> Tuple[PrimFunc, PrimFunc]: + @T.prim_func + def desc(a: T.handle) -> None: + A = T.match_buffer(a, (col, row), dtype, scope="metal.simdgroup", offset_factor=1) + with T.block("root"): + T.reads() + T.writes(A[0:col, 0:row]) + for i, j in T.grid(col, row): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = T.float32(0) + + @T.prim_func + def impl(a: T.handle) -> None: + d0, d1 = T.int32(), T.int32() + A = T.match_buffer( + a, (col, row), dtype, scope="metal.simdgroup", strides=[d1, d0], offset_factor=1 + ) + with T.block("root"): + T.reads() + T.writes(A[0:col, 0:row]) + T.make_filled_simdgroup_matrix( + A.data, + index=get_simdgroup_index(A, d1, col, row), + value=T.float32(0), + col=col, + row=row, + ) + + return desc, impl + + +def get_simdgroup_load_intrin( + dtype: str, + scope: Literal["global", "shared"], + col: int = 8, + row: int = 8, + transpose_matrix: bool = False, +) -> Tuple[PrimFunc, PrimFunc]: + align = col * row + + @T.prim_func + def desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (col, row), dtype, 
align=align, scope=scope, offset_factor=1) + C = T.match_buffer( + c, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1 + ) + with T.block("root"): + T.reads(A[0:col, 0:row]) + T.writes(C[0:col, 0:row]) + for i, j in T.grid(col, row): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + if transpose_matrix: + # C[vii, vjj] = A[vjj, vii] + C[vjj, vii] = A[vii, vjj] + else: + C[vii, vjj] = A[vii, vjj] + + @T.prim_func + def impl(a: T.handle, c: T.handle) -> None: + s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32() + A = T.match_buffer( + a, + (col, row), + dtype, + align=align, + scope=scope, + strides=[s1, s0], + offset_factor=1, + ) + C = T.match_buffer( + c, + (col, row), + dtype, + align=align, + scope="metal.simdgroup", + strides=[d1, d0], + offset_factor=1, + ) + with T.block("root"): + T.reads(A[0:col, 0:row]) + T.writes(C[0:col, 0:row]) + T.simdgroup_load( + C.data, + index=get_simdgroup_index(C, d1, col, row), + ptr=A.access_ptr("r"), + stride=s1, + col=col, + row=row, + transpose_matrix=transpose_matrix, + ) + + return desc, impl + + +def get_simdgroup_store_intrin( + dtype: str, + scope: Literal["global", "shared"], + col: int = 8, + row: int = 8, + transpose_matrix: bool = False, +) -> Tuple[PrimFunc, PrimFunc]: + align = col * row + + @T.prim_func + def desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1 + ) + C = T.match_buffer(c, (col, row), dtype, align=align, scope=scope, offset_factor=1) + with T.block("root"): + T.reads(A[0:col, 0:row]) + T.writes(C[0:col, 0:row]) + for i, j in T.grid(col, row): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + if transpose_matrix: + C[vjj, vii] = A[vii, vjj] + else: + C[vii, vjj] = A[vii, vjj] + + @T.prim_func + def impl(a: T.handle, c: T.handle) -> None: + s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32() + A = T.match_buffer( + a, + (col, row), + dtype, 
+ align=align, + scope="metal.simdgroup", + strides=[s1, s0], + offset_factor=1, + ) + C = T.match_buffer( + c, (col, row), dtype, align=align, scope=scope, strides=[d1, d0], offset_factor=1 + ) + with T.block("root"): + T.reads(A[0:col, 0:row]) + T.writes(C[0:col, 0:row]) + T.simdgroup_store( + A.data, + index=get_simdgroup_index(A, s1, col, row), + ptr=C.access_ptr("w"), + stride=d1, + col=col, + row=row, + transpose_matrix=transpose_matrix, + ) + + return desc, impl + + +def get_simdgroup_multiply_accumulate_intrin( + m_dim: int, n_dim: int, k_dim: int, dtype: str +) -> Tuple[PrimFunc, PrimFunc]: + @T.prim_func + def desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (m_dim, k_dim), dtype, scope="metal.simdgroup", offset_factor=1) + B = T.match_buffer(b, (k_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1) + C = T.match_buffer(c, (m_dim, n_dim), dtype, scope="metal.simdgroup", offset_factor=1) + with T.block("root"): + T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + for i, j, k in T.grid(m_dim, n_dim, k_dim): + with T.block(""): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] += A[vii, vkk] * B[vkk, vjj] + + @T.prim_func + def impl(a: T.handle, b: T.handle, c: T.handle) -> None: + a0, a1, b0, b1, c0, c1 = T.int32(), T.int32(), T.int32(), T.int32(), T.int32(), T.int32() + A = T.match_buffer( + a, (m_dim, k_dim), dtype, scope="metal.simdgroup", strides=[a1, a0], offset_factor=1 + ) + B = T.match_buffer( + b, (k_dim, n_dim), dtype, scope="metal.simdgroup", strides=[b1, b0], offset_factor=1 + ) + C = T.match_buffer( + c, (m_dim, n_dim), dtype, scope="metal.simdgroup", strides=[c1, c0], offset_factor=1 + ) + with T.block("root"): + T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim]) + T.writes(C[0:m_dim, 0:n_dim]) + T.simdgroup_multiply_accumulate( + C.data, + get_simdgroup_index(C, c1, m_dim, n_dim), + A.data, + get_simdgroup_index(A, a1, 
m_dim, k_dim), + B.data, + get_simdgroup_index(B, b1, k_dim, n_dim), + C.data, + get_simdgroup_index(C, c1, m_dim, n_dim), + ) + + return desc, impl + + +# Make filled simdgroup matrix intrinsics + +SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN = "simdgroup_make_filled_8x8x8_f16" +TensorIntrin.register( + SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN, + *get_make_filled_simdgroup_matrix_intrin("float16", 8, 8), +) + +SIMDGROUP_FILLED_8x8x8_f32_INTRIN = "simdgroup_fill_8x8x8_f32" +TensorIntrin.register( + SIMDGROUP_FILLED_8x8x8_f32_INTRIN, *get_make_filled_simdgroup_matrix_intrin("float32", 8, 8) +) + +SIMDGROUP_FILLED_8x8x8_bf16_INTRIN = "simdgroup_fill_8x8x8_bf16" +TensorIntrin.register( + SIMDGROUP_FILLED_8x8x8_bf16_INTRIN, *get_make_filled_simdgroup_matrix_intrin("bfloat16", 8, 8) +) + +# Load intrinsics + +SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN = "simdgroup_load_8x8x8_f16_shared" +TensorIntrin.register( + SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN, + *get_simdgroup_load_intrin("float16", "shared", 8, 8, False), +) + +SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN = "simdgroup_load_8x8x8_f16_shared_trans" +TensorIntrin.register( + SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN, + *get_simdgroup_load_intrin("float16", "shared", 8, 8, True), +) + +# Store intrinsics + +SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN = "simdgroup_store_8x8x8_f16_global" +TensorIntrin.register( + SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN, + *get_simdgroup_store_intrin("float16", "global", 8, 8, False), +) + +SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN = "simdgroup_store_8x8x8_f16_shared" +TensorIntrin.register( + SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN, + *get_simdgroup_store_intrin("float16", "shared", 8, 8, False), +) +# Multiply accumulate intrinsics + +SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN = "simdgroup_multiply_accumulate_8x8x8_f16" +TensorIntrin.register( + SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN, + *get_simdgroup_multiply_accumulate_intrin(8, 8, 8, "float16"), +) + + +def get_simdgroup_intrin_group( + load_scope: 
Literal["shared"], + store_scope: Literal["global", "shared"], + dtype: str, + trans_a: bool = False, + trans_b: bool = False, +) -> Dict[str, str]: + """Get a group of intrinsics for tensorization on Apple GPU. + + Parameters + ---------- + load_scope : Literal["shared"] + The memory scope of the input buffer. + + store_scope : Literal["global", "shared"] + The memory scope of the result buffer. + + dtype : str + The data type of the input and output buffers. + + trans_a : bool + Whether the input matrix A is transposed. + + trans_b : bool + Whether the input matrix B is transposed. + + Returns + ------- + ret : Dict[str, str] + A group of tensor intrinsics. + """ + assert load_scope in ["shared"] + assert store_scope in ["global", "shared"] + assert dtype in ["float16", "bfloat16", "float32"] + + shape = "8x8x8" + dtype = "f16" if dtype == "float16" else "bf16" if dtype == "bfloat16" else "f32" + trans_a = "_trans" if trans_a else "" + trans_b = "_trans" if trans_b else "" + + # e.g. simdgroup_load_8x8x8_f16_shared + load_a_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_a}" + # e.g. simdgroup_load_8x8x8_f16_shared_trans + load_b_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_b}" + # e.g. simdgroup_multiply_accumulate_8x8x8_f16 + compute_intrin = f"simdgroup_multiply_accumulate_{shape}_{dtype}" + # e.g. simdgroup_make_filled_8x8x8_f16 + init_intrin = f"simdgroup_make_filled_{shape}_{dtype}" + # e.g. simdgroup_store_8x8x8_f16_global + store_intrin = f"simdgroup_store_{shape}_{dtype}_{store_scope}" + + return { + "init": init_intrin, + "load_a": load_a_intrin, + "load_b": load_b_intrin, + "compute": compute_intrin, + "store": store_intrin, + } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 747b90581207..d1af2cb701a0 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -70,6 +70,8 @@ enum class StorageRank { kMMAMatrixB = 10, /*! 
\brief mma scope memory of accumulator */ kMMAMatrixC = 11, + /*! \brief Metal SIMD group memory */ + kMetalSimdGroup = 12, }; /*! @@ -126,6 +128,8 @@ struct StorageScope { return "m16n8k8.matrixB" + tag; case StorageRank::kMMAMatrixC: return "m16n8k8.matrixC" + tag; + case StorageRank::kMetalSimdGroup: + return "metal.simdgroup" + tag; default: LOG(FATAL) << "unknown storage scope"; } @@ -175,6 +179,9 @@ struct StorageScope { } else if (s.compare(0, 15, "m16n8k8.matrixC") == 0) { r.rank = StorageRank::kMMAMatrixC; r.tag = s.substr(15, std::string::npos); + } else if (s.compare(0, 15, "metal.simdgroup") == 0) { + r.rank = StorageRank::kMetalSimdGroup; + r.tag = s.substr(15, std::string::npos); } else { LOG(FATAL) << "unknown storage scope " << s; } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index e729af417ca8..290851498843 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -25,10 +25,10 @@ #include #include +#include #include #include #include -#include #include "../../runtime/metal/metal_module.h" #include "../../runtime/thread_storage_scope.h" @@ -262,6 +262,9 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } + } else if (t.is_bfloat16()) { + os << "bfloat"; + return; } LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; } @@ -296,9 +299,43 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) os << "device "; } else if (scope == "shared") { os << "threadgroup "; - } else { + } else if (scope == "local") { os << "thread "; + } else { + LOG(FATAL) << "Unknown storage scope `" << scope << "`"; + } +} + +void CodeGenMetal::VisitStmt_(const AllocateNode* op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + this->PrintIndent(); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation for now"; + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + if (scope == "metal.simdgroup") { + ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || + op->dtype == DataType::BFloat(16)) + << "Only float16, float32, and bfloat16 are supported, but got " << op->dtype; + ICHECK(constant_size % 64 == 0) + << "Only 8x8 matrix is supported, but got " << constant_size << " bytes\n"; + + std::ostringstream dtype_os; + PrintType(op->dtype, dtype_os); + std::string dtype_str = dtype_os.str(); + simdgroup_dtype_[op->buffer_var.get()] = dtype_str; + stream << "simdgroup_" << dtype_str << "8x8 " << vid << '[' << constant_size / 64 << "];\n"; + } else { + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << constant_size << "];\n"; } + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); } void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) @@ -322,7 +359,46 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT CHECK(!op->op.as()) << "CodegenMetal does not support inter-function calls, " << "but expression " << GetRef(op) << " calls PrimFunc " << op->op; - if (op->op.same_as(builtin::reinterpret())) { + auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { + ICHECK(col->IsInstance() && row->IsInstance()) + << "Only constant shape is supported for simdgroup matrix, but got " << col << "x" << row; + int col_val = col.as()->value; + int row_val = row.as()->value; + ICHECK(col_val == 8 && row_val == 8) + << "Only 8x8 matrix is supported, but got " << col_val << "x" << row_val; + }; + if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { + ICHECK_EQ(op->args.size(), 5); + Var var = runtime::Downcast(op->args[0]); + // Get the data type of the simdgroup matrix + auto it = simdgroup_dtype_.find(var.get()); + ICHECK(it != 
simdgroup_dtype_.end()) + << "Cannot find variable allocation for simdgroup: " << var; + const std::string& dtype_str = it->second; + f_check_simdgroup_shape(op->args[3], op->args[4]); + os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<" + << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" + << PrintExpr(op->args[2]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_load())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " + << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " + << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_store())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " + << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " + << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) { + ICHECK_EQ(op->args.size(), 8); + os << "simdgroup_multiply_accumulate(" // + << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " // + << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], " // + << PrintExpr(op->args[4]) << "[" << PrintExpr(op->args[5]) << "], " // + << PrintExpr(op->args[6]) << "[" << PrintExpr(op->args[7]) << "])"; + } else if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; this->PrintType(op->dtype, os); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 9cff3211ce44..9bc0e15d155f 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -27,6 +27,7 @@ #include #include +#include #include "codegen_c.h" @@ -50,6 +51,7 @@ class CodeGenMetal final : public 
CodeGenC { // print store of single element. void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor + void VisitStmt_(const AllocateNode* op) final; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) @@ -59,6 +61,7 @@ class CodeGenMetal final : public CodeGenC { using CodeGenC::PrintType; private: + std::unordered_map simdgroup_dtype_; int thread_index_bits_{32}; int thread_work_dim_{0}; Target target_; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 67d01aa92389..0404fd28230e 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -328,6 +328,18 @@ TIR_DEFINE_BUILTIN_FUNC(mma_fill) .set_attr("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); +TIR_DEFINE_BUILTIN_FUNC(make_filled_simdgroup_matrix) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(simdgroup_load) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(simdgroup_store) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TScriptDtypePrintLocation", diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py index 095447766e28..59ccfec55cc5 100644 --- a/tests/python/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/dlight/test_gpu_matmul_tensorize.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=missing-docstring +# pylint: disable=missing-docstring, unused-variable, invalid-name +# flake8: noqa: E501 import pytest import tvm.testing from tvm import dlight as dl -from tvm.script import ir as I from tvm.script import tir as T from tvm.target import Target @@ -698,5 +698,284 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. # fmt: on +class MetalBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("metal"): + return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) + + return transform + + +class TestMatmulMetal(MetalBeforeAfter): + # fmt: off + @T.prim_func(private=True) + def before( + var_A: T.handle, + B: T.Buffer((28672, 4096), "float16"), + var_C: T.handle, + ): + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") + for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): + with T.block("C"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.writes(C[v_i0, v_i1, v_i2]) + with T.init(): + C[v_i0, v_i1, v_i2] = T.float16(0) + C[v_i0, v_i1, v_i2] += A[v_i0, v_i1, v_k] * B[v_i2, v_k] + + @T.prim_func + def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") + # with T.block("root"): + A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") + A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") + B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup") + 
C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup") + C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared") + for ax0 in T.thread_binding(1, thread="blockIdx.z"): + for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(448, thread="blockIdx.y"): + for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): + with T.block("C_init_o"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) + T.reads() + T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) + T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) + for ax3_0 in range(128): + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) + v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + 
ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) + T.reads(A[v1, 0, v2]) + T.writes(A_reindex_pad_shared[v0, v1, v2]) + A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0)) + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) + v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) + T.reads(B[v1, v2]) + T.writes(B_reindex_shared[v0, v1, v2]) + B_reindex_shared[v0, v1, v2] = B[v1, v2] + for ax3_1 in range(4): + for ax0_0, ax1_0_1 in T.grid(2, 1): + with T.block("A_reindex_pad_shared_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) + v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) + T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) + C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, 
T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) + for ax0_0, ax1_0_1 in T.grid(2, 1): + with T.block("B_reindex_shared_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) + v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) + T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) + A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) + C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) + for ax1_2, ax2_2 in T.grid(2, 2): + with T.block("C_update_o"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) + v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1) + T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), 
scope="metal.simdgroup", offset_factor=1) + B_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) + C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) + for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): + with T.block("C_reindex_pad_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, ax0_1) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) + T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) + C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) + T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, 
T.bool(False)) + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("C_reindex_pad_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) + v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) + T.reads(C_reindex_pad_shared[v0, v1, v2]) + T.writes(C[v1, 0, v2]) + if v1 < batch_size: + C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] + # fmt: on + + +class TestMatmulMetalInt4Quant(MetalBeforeAfter): + # fmt: off + @T.prim_func(private=True) + def before( + B0: T.Buffer((28672, 512), "uint32"), + B1: T.Buffer((28672, 128), "float16"), + var_A: T.handle, + var_C: T.handle + ): + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") + compute = T.alloc_buffer((28672, 4096), "float16") + B = T.alloc_buffer((28672, 4096), "float16") + for i0, i1 in T.grid(28672, 4096): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15))) + for i0, i1 in T.grid(28672, 4096): + with T.block("dequantize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0, v_i1 // 32] + for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + C[v_i0, v_i1, v_i2] = 
T.float16(0) + C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] + + @T.prim_func(private=True) + def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "float16"), var_A: T.handle, var_C: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") + # with T.block("root"): + A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") + A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") + B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup") + C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup") + C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared") + for ax0 in T.thread_binding(1, thread="blockIdx.z"): + for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"): + for ax2_0 in T.thread_binding(448, thread="blockIdx.y"): + for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): + with T.block("NT_matmul_init_o"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) + T.reads() + T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", 
strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) + T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) + for ax3_0 in range(128): + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) + v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) + T.reads(A[v1, 0, v2]) + T.writes(A_reindex_pad_shared[v0, v1, v2]) + A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0)) + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) + v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) + T.reads(B0[v1, v2 // 8], B1[v1, v2 // 32]) + T.writes(B_reindex_shared[v0, v1, v2]) + B_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0[v1, v2 // 8], 
T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1[v1, v2 // 32] + for ax3_1 in range(4): + for ax0_0, ax1_0_1 in T.grid(2, 1): + with T.block("A_reindex_pad_shared_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) + v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) + T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) + C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) + for ax0_0, ax1_0_1 in T.grid(2, 1): + with T.block("B_reindex_shared_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) + v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) + T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) + A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) + C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + 
T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) + for ax1_2, ax2_2 in T.grid(2, 2): + with T.block("NT_matmul_update_o"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) + v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1) + T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) + B = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) + C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) + T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) + for ax0_1, ax1_0_1, 
ax2_0_1 in T.grid(1, 2, 2): + with T.block("C_reindex_pad_metal.simdgroup_o"): + v0_o = T.axis.spatial(1, ax0_1) + v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) + v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) + T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) + A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) + C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) + T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) + for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): + for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): + for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): + for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_ax2_fused_4 in T.vectorized(4): + with T.block("C_reindex_pad_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) + v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) + T.reads(C_reindex_pad_shared[v0, v1, v2]) + T.writes(C[v1, 0, v2]) + if v1 < batch_size: + C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] + + if __name__ == "__main__": tvm.testing.main()