From b43c0309d87410ece1e46afd98c57733379b14bc Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 3 Jun 2024 12:59:33 +0800 Subject: [PATCH] Introduce outer reduction for metal --- python/tvm/dlight/gpu/gemv.py | 92 ++--- python/tvm/dlight/gpu/low_batch_gemv.py | 227 ++++++++--- python/tvm/dlight/gpu/utils.py | 24 +- tests/python/dlight/test_gpu_gemv.py | 359 ++---------------- .../python/dlight/test_gpu_low_batch_gemv.py | 146 +++++++ 5 files changed, 426 insertions(+), 422 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 9ad6f3f89af3..ce1c5986e1ca 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -18,7 +18,7 @@ from functools import reduce from typing import List, Optional, Union -from tvm import DataType, arith, ir, tir +from tvm import arith, ir, tir from tvm.target import Target from ..base import ( @@ -31,6 +31,7 @@ try_inline_contiguous_spatial, ) from .base import GPUScheduleRule +from .utils import auto_vectorize, get_bytes, get_extent def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -49,17 +50,6 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: return buffer_store.value.b -def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): - loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent - - -def get_bytes(dtype: Union[DataType, str]) -> int: - if isinstance(dtype, str): - dtype = DataType(dtype) - return dtype.itemsize() - - def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: """Check if the block is a GEMV. @@ -207,17 +197,13 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None elif is_inner_reduction: return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) - elif target.kind.name == "opencl" and "android" in str(target.host): + else: ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) if ret is None: return self.sch_outer_reduction_fallback( sch, target, block, vector_input_buffers, epilogue ) return sch - else: - return self.sch_outer_reduction_fallback( - sch, target, block, vector_input_buffers, epilogue - ) def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, @@ -535,9 +521,11 @@ def apply( TILE_S, TILE_R = ( 1, - len_c - if len_c > 1 - else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ( + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1) + ), ) VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) @@ -614,9 +602,9 @@ def apply( sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c) # sch.bind(batch, "blockIdx.z") sch.bind(bx, "blockIdx.x") - sch.bind(ts, "threadIdx.x") - sch.bind(tr, "threadIdx.y") - sch.vectorize(vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + auto_vectorize(sch, vec_c, VEC_C) # decompose independent scale read to outer loop block_rf_stmt = sch.get(rf) @@ -635,26 +623,26 @@ def apply( V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, r, preserve_unit_loops=True) l = sch.get_loops(block=V_shared)[-1] - _, v_tile, tx, ty, vec = sch.split( + _, v_tile, ts, tr, vec = sch.split( l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], preserve_unit_iters=True ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) + sch.bind(tr, TAG_R) + sch.bind(ts, TAG_S) + auto_vectorize(sch, vec, LOAD_V_VEC) # reduce tile_s * tr * vec to tile_s * tr sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, ts = sch.get_loops(block=rf2)[1:] sch.reorder(ts, tr, vec_c) - sch.bind(ts, "threadIdx.x") - sch.bind(tr, "threadIdx.y") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) # reduce tile_s * tr to tile_s sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) tr, ts = sch.get_loops(block=gemv)[1:] sch.reorder(ts, tr) - sch.bind(ts, "threadIdx.x") - sch.bind(tr, "threadIdx.y") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2]) sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) @@ -665,7 +653,7 @@ def apply( sch.annotate( block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_auto_unroll_max_step", - ann_val=DEC_PACK, + ann_val=UNROLL, ) sch.annotate( block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1 @@ -678,14 +666,14 @@ def apply( sch.reverse_compute_at(epilogue, bx) sch.set_scope(block, 0, "shared") _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) - sch.bind(tx, "threadIdx.x") + _, ts = sch.split(sch.fuse(*s), factors=[None, TS]) + sch.bind(ts, TAG_S) else: sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) ts_tile_s = sch.get_loops(epilogue)[-1] ts, _ = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) - sch.bind(ts, "threadIdx.x") + sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") return sch @@ -698,15 +686,27 @@ def apply( get_extent(sch, c), ) - TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" - VEC_C = 1 - UNROLL = 4 - TS, TR = 64, 4 DEC_PACK = 8 SCALE_PACK = 4 - LOAD_V_SHARED = False - LOAD_V_VEC = 4 - LOAD_V_TILE = 8 + + if target.kind.name == "opencl" and "android" in str(target.host): + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 8 + UNROLL = 8 + TS, TR = 64, 4 + LOAD_V_SHARED = False + LOAD_V_VEC = 4 + LOAD_V_TILE = 8 + elif target.kind.name == "metal": + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 4 + UNROLL = 8 + TS, TR = 128, 4 + LOAD_V_SHARED = False + LOAD_V_VEC = 4 + LOAD_V_TILE = 4 + else: + return None if LOAD_V_SHARED is False: LOAD_V_TILE = 1 @@ -723,9 +723,11 @@ def apply( _, TILE_R = ( 1, - len_c - if len_c > 1 - else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ( + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1) + ), ) LOAD_V_VEC = min(get_max_factor(TILE_R, [1, 2, 4, 8]), LOAD_V_VEC) VEC_LOAD = 1 diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 20911f0e7d9c..b528086a1626 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -16,9 +16,9 @@ # under the License. """A rule for low-batch GEMM / decode-GEMM using GEMV schedule.""" from functools import reduce -from typing import List, Optional, Set, Union +from typing import List, Literal, Optional, Set, Union -from tvm import DataType, arith, ir, tir +from tvm import arith, ir, tir from tvm.target import Target from ..base import ( @@ -30,6 +30,7 @@ try_inline_contiguous_spatial, ) from .base import GPUScheduleRule +from .utils import auto_vectorize, get_bytes, get_extent def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -48,17 +49,6 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: return buffer_store.value.b -def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): - loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent - - -def get_bytes(dtype: Union[DataType, str]) -> int: - if isinstance(dtype, str): - dtype = DataType(dtype) - return dtype.itemsize() - - def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: """Check if the block is a low batch GEMM. @@ -170,7 +160,7 @@ def normalize( ): return None iter_to_info = {i.var: i for i in block_info.iters} - batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + batch_loops, s_loops, r_loops = [], [], [] inner_axis = access.args[-1].source.source is_inner_reduction = iter_to_info[inner_axis].kind == "R" @@ -179,14 +169,7 @@ def normalize( info = iter_to_info.get(var) loop = info.loop_rv is_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None - loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) - # we only support the reduction dim being grouped atm - if not is_reduction: - return None - c_loops.append(c_loop) + # No C loops as we do not compute_inline weights into main block if is_reduction: r_loops.append(loop) elif all([var in buf_vars for buf_vars in buffers_use_vars]): @@ -196,14 +179,9 @@ def normalize( assert s_loops assert r_loops - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] dynamic_loops = [iter_to_info[var].loop_rv for var in dynamic_iter_vars] assert len(dynamic_loops) == 1 - if not batch_loops: - batch_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops) - sch.fuse(*batch_loops) + sch.reorder(*dynamic_loops, *s_loops, *r_loops) sch.fuse(*s_loops) sch.fuse(*r_loops) return is_inner_reduction @@ -292,6 +270,18 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- batch_pad, ) return sch + elif self.bucket <= 4: + self.sch_outer_reduction( + sch, + target, + block, + dequantize_block, + pad_input_block, + vector_input_buffers, + epilogue, + batch_pad, + ) + return sch else: return None @@ -332,9 +322,7 @@ def apply( ): # rfactor: reduce to tx * vec_c - _, b, s, r, c = sch.get_loops(block=gemv) - s = sch.fuse(b, s) - r = sch.fuse(r, c) + _, s, r = sch.get_loops(block=gemv) bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) r, tr, tile_r_vec_n, vec_c = sch.split( r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True @@ -516,15 +504,8 @@ def apply( return sch # Specify the `len_tx` and `len_ty` according to the loop extent - _, batch, s, r, c = sch.get_loops(block=block) - len_batch, len_s, len_r, len_c = ( - get_extent(sch, batch), - get_extent(sch, s), - get_extent(sch, r), - get_extent(sch, c), - ) - len_S = len_batch * len_s - len_R = len_r * len_c + _, s, r = sch.get_loops(block=block) + len_s, len_r = get_extent(sch, s), get_extent(sch, r) TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" if target.kind.name == "cuda": @@ -532,8 +513,8 @@ def apply( LOAD_V_SHARED = True LOAD_V_VEC = 8 UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: + if isinstance(len_s, int): + if len_s > len_r: TS, TR = 4, 64 else: TS, TR = 16, 32 @@ -542,8 +523,8 @@ def apply( LOAD_V_SHARED = False LOAD_V_VEC = -1 UNROLL = 8 - if isinstance(len_S, int): - if len_S > len_R: + if isinstance(len_s, int): + if len_s > len_r: TS, TR = 8, 32 else: TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" @@ -553,8 +534,8 @@ def apply( LOAD_V_SHARED = True LOAD_V_VEC = 8 UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: + if isinstance(len_s, int): + if len_s > len_r: TS, TR = 1, 128 else: TS, TR = 8, 64 @@ -570,8 +551,8 @@ def apply( LOAD_V_SHARED = True LOAD_V_VEC = 4 UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: + if isinstance(len_s, int): + if len_s > len_r: TS, TR = 4, 32 else: TS, TR = 16, 32 @@ -588,7 +569,7 @@ def apply( UNROLL = 64 TS, TR = 1, 64 - if not isinstance(len_S, int): + if not isinstance(len_s, int): TS, TR = 1, 64 while TS * TR > target.max_num_threads: @@ -597,12 +578,7 @@ def apply( else: TR //= 2 - TILE_S, TILE_R = ( - 2, - len_c - if len_c > 1 - else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), - ) + TILE_S, TILE_R = 2, max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1) VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) VEC_LOAD = 1 return apply( @@ -620,3 +596,144 @@ def apply( LOAD_V_VEC=LOAD_V_VEC, UNROLL=UNROLL, ) + + def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + dequantize_block: Optional[tir.schedule.BlockRV], + pad_input_block: Optional[tir.schedule.BlockRV], + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + batch_pad: int, + ): + """Schedule the outer reduction block.""" + + # Need to detect from the block + DEC_PACK = 8 + SCALE_PACK = 4 + + def apply( + sch: tir.Schedule, + main_block: tir.schedule.BlockRV, + TAG_S: Literal["threadIdx.x", "threadIdx.y"], + TAG_R: Literal["threadIdx.x", "threadIdx.y"], + TS: int, + TR: int, + VEC: int, + UNROLL: int, + ): + # rfactor: reduce to tx * vec_c + b, s, r = sch.get_loops(main_block) + by, batch = sch.split(b, [None, batch_pad], preserve_unit_iters=True) + bx, ts = sch.split(s, [None, TS], preserve_unit_iters=True) + r, tr, scale_c, vec_c = sch.split( + r, [None, TR, SCALE_PACK, DEC_PACK], preserve_unit_iters=True + ) + sch.reorder(by, bx, ts, r, batch, scale_c, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + by, bx, ts, batch, tr_vec_c = sch.get_loops(block=main_block) + tr, vec_c = sch.split(tr_vec_c, [TR, DEC_PACK], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + by, bx, ts, r, batch, scale_c, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, [TR, DEC_PACK], preserve_unit_iters=True) + sch.reorder(by, bx, ts, tr, r, scale_c, batch, vec_c) + sch.bind(by, "blockIdx.y") + sch.bind(bx, "blockIdx.x") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + auto_vectorize(sch, vec_c, VEC) + + if dequantize_block is not None: + sch.compute_at(dequantize_block, scale_c, preserve_unit_loops=True) + sch.set_scope(dequantize_block, 0, "local") + auto_vectorize(sch, sch.fuse(*sch.get_loops(dequantize_block)[6:]), VEC) + + B0_local = sch.cache_read(dequantize_block, 0, "local") + sch.compute_at(B0_local, r, preserve_unit_loops=True) + auto_vectorize(sch, sch.fuse(*sch.get_loops(B0_local)[5:]), VEC) + + B1_local = sch.cache_read(dequantize_block, 1, "local") + sch.compute_at(B1_local, r, preserve_unit_loops=True) + auto_vectorize(sch, sch.fuse(*sch.get_loops(B1_local)[5:]), VEC) + else: + # Only support quantized workloads for now + sch = None + return + + if LOAD_V_SHARED: + sch.set_scope(pad_input_block, 0, "shared") + sch.compute_at(pad_input_block, r, preserve_unit_loops=True) + sch.storage_align(pad_input_block, 0, axis=-2, factor=8, offset=1) + tr, ts, v = sch.split(sch.fuse(*sch.get_loops(pad_input_block)[5:]), [TR, TS, None]) + sch.bind(tr, TAG_R) + sch.bind(ts, TAG_S) + auto_vectorize(sch, v, VEC) + else: + sch.compute_inline(pad_input_block) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, bx, preserve_unit_loops=True) + tr, vec_c, batch, ts = sch.get_loops(rf2)[2:] + sch.reorder(ts, tr, batch, vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(main_block, bx, preserve_unit_loops=True) + tr, batch, ts = sch.get_loops(main_block)[2:] + sch.reorder(batch, ts, tr) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + # unroll(batch, 1) + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[4]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[4]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + epilogue = sch.get_consumers(main_block) + # Schedule epilogue + if epilogue: + epilogue = epilogue[0] + if is_broadcast_epilogue( # pylint: disable=no-else-raise + sch, main_block, epilogue + ): + raise NotImplementedError + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + batch, ts = sch.get_loops(epilogue)[2:] + sch.bind(ts, TAG_S) + sch.set_scope(main_block, 0, "local") + + if target.kind.name == "metal": + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + TS, TR = 64, 4 + LOAD_V_SHARED = True + VEC = 4 + UNROLL = 8 + else: + # fallback configuration + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + TS, TR = 32, 4 + LOAD_V_SHARED = False + VEC = 1 + UNROLL = 64 + + return apply( + sch, + block, + TAG_S, + TAG_R, + TS, + TR, + VEC, + UNROLL, + ) diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py index e27a6969ad88..875a9524bb9b 100644 --- a/python/tvm/dlight/gpu/utils.py +++ b/python/tvm/dlight/gpu/utils.py @@ -16,12 +16,32 @@ # under the License. # pylint: disable=missing-docstring """Utility methods for generic GPU.""" -from typing import List, Optional +from typing import List, Optional, Union -from tvm import tir +from tvm import DataType, tir from tvm.target import Target +def get_bytes(dtype: Union[DataType, str]) -> int: + if isinstance(dtype, str): + dtype = DataType(dtype) + return dtype.itemsize() + + +def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): + loop: tir.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + + +def auto_vectorize(sch: tir.Schedule, loop: tir.schedule.LoopRV, max_vec: int): + """Auto vectorize the loop.""" + extent = get_extent(sch, loop) + if not isinstance(extent, int): + return + v = loop if extent <= max_vec else sch.split(loop, factors=[None, max_vec])[-1] + sch.vectorize(v) + + def max_threads_per_block(target: Target) -> int: """Get the maximum number of threads per block for a given target. diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 0f7b6f45ae3f..20cb703f7f60 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -672,6 +672,7 @@ def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer( def test_outer_reduction_adreno(): + # fmt: off @T.prim_func(private=True) def before( lv575: T.Buffer((1376, 4096), "uint32"), @@ -687,377 +688,95 @@ def before( for i, j in T.grid(11008, 4096): with T.block("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) - T.writes(p_output0_intermediate_1[v_i, v_j]) - p_output0_intermediate_1[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4) - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv576[v_i // 32, v_j] + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15)))- T.float16(7)) * lv576[v_i // 32, v_j] for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): with T.block("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] - ) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] for ax0, ax1, ax2 in T.grid(1, 1, 4096): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) - def expected( - lv575: T.Buffer((1376, 4096), "uint32"), - lv576: T.Buffer((344, 4096), "float16"), - lv574: T.Buffer((1, 1, 11008), "float16"), - lv570: T.Buffer((1, 1, 4096), "float16"), - p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), - ): + def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - var_matmul_intermediate_rf_local = T.alloc_buffer( - (32, 1, 1, 4096), "float16", scope="local" - ) - var_matmul_intermediate_rf_local_1 = T.alloc_buffer( - (4, 1, 1, 4096), "float16", scope="local" - ) + var_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 4096), "float16", scope="local") + var_matmul_intermediate_rf_local_1 = T.alloc_buffer((4, 1, 1, 4096), "float16", scope="local") lv576_local = T.alloc_buffer((344, 4096), "float16", scope="local") lv575_local = T.alloc_buffer((1376, 4096), "uint32", scope="local") for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - for ( - ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init - ) in T.thread_binding(4, thread="threadIdx.y"): - for ( - ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init - ) in T.vectorized(8): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(8): with T.block("matmul_rf_init"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( - 32, - ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 - + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init, - ) - v0 = T.axis.spatial( - 4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 - ) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) T.reads() - T.writes( - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, - 0, - 0, - v0, - ] - ) - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0 - ] = T.float16(0) - for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding( - 4, thread="threadIdx.y" - ): + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(86, 1): for ax0, ax1 in T.grid(1, 1): with T.block("lv576_local"): - v0 = T.axis.spatial( - 344, - ax1_0_fused_ax1_1_fused_0 * 4 - + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 - + ax0, - ) - v1 = T.axis.spatial( - 4096, - u_fused_ax0_fused_fused_0 * 64 - + u_fused_ax0_fused_fused_1 - + ax1, - ) + v0 = T.axis.spatial(344, ax1_0_fused_ax1_1_fused_0 * 4 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) + v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) T.reads(lv576[v0, v1]) T.writes(lv576_local[v0, v1]) lv576_local[v0, v1] = lv576[v0, v1] for ax1_0_fused_ax1_1_fused_3 in range(4): for ax0, ax1 in T.grid(1, 1): with T.block("lv575_local"): - v0 = T.axis.spatial( - 1376, - ax1_0_fused_ax1_1_fused_0 * 16 - + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 - * 4 - + ax1_0_fused_ax1_1_fused_3 - + ax0, - ) - v1 = T.axis.spatial( - 4096, - u_fused_ax0_fused_fused_0 * 64 - + u_fused_ax0_fused_fused_1 - + ax1, - ) + v0 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 4 + ax1_0_fused_ax1_1_fused_3 + ax0) + v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) T.reads(lv575[v0, v1]) T.writes(lv575_local[v0, v1]) lv575_local[v0, v1] = lv575[v0, v1] - for ( - ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 - ) in T.vectorized(8): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(8): with T.block("matmul_rf_update"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( - 32, - ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 - * 8 - + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, - ) - v0 = T.axis.spatial( - 4096, - u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1, - ) - ( - vax1_0_fused_ax1_1_fused_0, - vax1_0_fused_ax1_1_fused_1, - vax1_0_fused_ax1_1_fused_3, - ) = T.axis.remap( - "RRR", - [ - ax1_0_fused_ax1_1_fused_0, - ax1_0_fused_ax1_1_fused_1, - ax1_0_fused_ax1_1_fused_3, - ], - ) - T.reads( - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, - 0, - 0, - v0, - ], - lv574[ - 0, - 0, - vax1_0_fused_ax1_1_fused_0 * 128 - + vax1_0_fused_ax1_1_fused_1 * 128 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - * 32 - + vax1_0_fused_ax1_1_fused_3 * 8 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - % 8, - ], - lv575_local[ - vax1_0_fused_ax1_1_fused_0 * 16 - + vax1_0_fused_ax1_1_fused_1 * 16 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - * 4 - + vax1_0_fused_ax1_1_fused_3, - v0, - ], - lv576_local[ - vax1_0_fused_ax1_1_fused_0 * 4 - + vax1_0_fused_ax1_1_fused_1 * 4 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - + vax1_0_fused_ax1_1_fused_3 // 4, - v0, - ], - ) - T.writes( - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, - 0, - 0, - v0, - ], - ) - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, - 0, - 0, - v0, - ] = var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, - 0, - 0, - v0, - ] + lv574[ - 0, - 0, - vax1_0_fused_ax1_1_fused_0 * 128 - + vax1_0_fused_ax1_1_fused_1 * 128 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - * 32 - + vax1_0_fused_ax1_1_fused_3 * 8 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - % 8, - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv575_local[ - vax1_0_fused_ax1_1_fused_0 * 16 - + vax1_0_fused_ax1_1_fused_1 * 16 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - * 4 - + vax1_0_fused_ax1_1_fused_3, - v0, - ], - T.Cast( - "uint32", - ( - vax1_0_fused_ax1_1_fused_0 * 128 - + vax1_0_fused_ax1_1_fused_1 * 128 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - * 32 - + vax1_0_fused_ax1_1_fused_3 * 8 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - % 8 - ) - % 8, - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - * lv576_local[ - vax1_0_fused_ax1_1_fused_0 * 4 - + vax1_0_fused_ax1_1_fused_1 * 4 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused - // 8 - + vax1_0_fused_ax1_1_fused_3 // 4, - v0, - ] - ) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0], lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8], lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] + lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) for ax2 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): with T.block("matmul_rf_init"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( - T.axis.spatial(4, ax0) - ) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(4, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) T.reads() - T.writes( - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] - ) - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0 - ] = T.float16(0) - for ax1 in T.serial( - 8, - annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}, - ): + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = T.float16(0) + for ax1 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): with T.block("matmul_rf_update"): - ( - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, - ) = T.axis.remap("SR", [ax0, ax1]) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) - T.reads( - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ], - var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, - 0, - 0, - v0, - ], - ) - T.writes( - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] - ) - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] = ( - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] - + var_matmul_intermediate_rf_local[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 - + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, - 0, - 0, - v0, - ] - ) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0]) + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0] for ax1 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): with T.block("matmul"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( - T.axis.reduce(4, ax0) - ) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(4, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1) - T.reads( - var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] - ) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) T.writes(var_matmul_intermediate_local[0, 0, v0]) with T.init(): var_matmul_intermediate_local[0, 0, v0] = T.float16(0) - var_matmul_intermediate_local[0, 0, v0] = ( - var_matmul_intermediate_local[0, 0, v0] - + var_matmul_intermediate_rf_local_1[ - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, - 0, - 0, - v0, - ] - ) + var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.x"): for ax0_fused_1 in range(1): with T.block("T_add"): - v0 = T.axis.spatial( - 4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1 - ) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) - p_output0_intermediate[0, 0, v0] = ( - lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] - ) - + p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + # fmt: on mod = tvm.IRModule({"main": before}) with Target("opencl", host="llvm -mtriple=aarch64-linux-android"): mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod) diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index 6072664b3a45..c3a06a1e3057 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -381,5 +381,151 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" tvm.ir.assert_structural_equal(mod["main"], expected) +def test_outer_reduction(): + # fmt: off + @T.prim_func(private=True) + def before( + B0: T.Buffer((512, 6144), "uint32"), + B1: T.Buffer((128, 6144), "float16"), + var_A: T.handle, + var_C: T.handle + ): + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 6144), "float16") + compute = T.alloc_buffer((4096, 6144), "float16") + B = T.alloc_buffer((4096, 6144), "float16") + for i0, i1 in T.grid(4096, 6144): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0 // 8, v_i1], T.Cast("uint32", v_i0 % 8 * 4)), T.uint32(15))) + for i0, i1 in T.grid(4096, 6144): + with T.block("dequantize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0 // 32, v_i1] + for i0, i1, i2, k in T.grid(batch_size, 1, 6144, 4096): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + C[v_i0, v_i1, v_i2] = T.float16(0) + C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_k, v_i2] + + @T.prim_func(private=True) + def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "float16"), var_A: T.handle, var_C: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + batch_size = T.int32() + A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") + C = T.match_buffer(var_C, (batch_size, 1, 6144), "float16") + # with T.block("root"): + B_local = T.alloc_buffer((4096, 6144), "float16", scope="local") + A_pad_shared = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 4096), "float16", scope="shared") + C_pad_local = T.alloc_buffer(((batch_size + 3) // 4 * 4, 1, 6144), "float16", scope="local") + C_pad_rf_local = T.alloc_buffer((32, (batch_size + 3) // 4 * 4, 1, 6144), "float16", scope="local") + C_pad_rf_local_1 = T.alloc_buffer((4, (batch_size + 3) // 4 * 4, 1, 6144), "float16", scope="local") + B0_local = T.alloc_buffer((512, 6144), "uint32", scope="local") + B1_local = T.alloc_buffer((128, 6144), "float16", scope="local") + for ax0_0 in T.thread_binding((batch_size + 3) // 4, thread="blockIdx.y"): + for ax1_fused_0 in T.thread_binding(96, thread="blockIdx.x"): + for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_fused_1_ax2_fused_3_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0_1_init, ax2_fused_1_ax2_fused_3_fused_1_0_init in T.grid(4, 2): + for ax2_fused_1_ax2_fused_3_fused_1_1_init in T.vectorized(4): + with T.block("matmul_rf_init"): + vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(32, ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0_init * 4 + ax2_fused_1_ax2_fused_3_fused_1_1_init) + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1_init) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) + T.reads() + T.writes(C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1]) + C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = T.float16(0) + for ax2_fused_0 in range(32): + for ax0_ax1_fused in T.vectorized(4): + with T.block("B0_local"): + v0 = T.axis.spatial(512, ax2_fused_0 * 16 + ax2_fused_1_ax2_fused_3_fused_0 * 4 + ax0_ax1_fused) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) + T.reads(B0[v0, v1]) + T.writes(B0_local[v0, v1]) + B0_local[v0, v1] = B0[v0, v1] + for ax0_ax1_fused in T.vectorized(1): + with T.block("B1_local"): + v0 = T.axis.spatial(128, ax2_fused_0 * 4 + ax2_fused_1_ax2_fused_3_fused_0) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) + T.reads(B1[v0, v1]) + T.writes(B1_local[v0, v1]) + B1_local[v0, v1] = B1[v0, v1] + for ax0_ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(2): + with T.block("A_pad"): + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(4096, ax2_fused_0 * 128 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 128) + T.reads(A[v0, 0, v1]) + T.writes(A_pad_shared[v0, 0, v1]) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 1]]}) + A_pad_shared[v0, 0, v1] = T.if_then_else(v0 < batch_size, A[v0, 0, v1], T.float16(0)) + for ax2_fused_2 in range(4): + for ax0_ax1_fused_0 in range(2): + for ax0_ax1_fused_1 in T.vectorized(4): + with T.block("dequantize"): + v0 = T.axis.spatial(4096, ax2_fused_0 * 128 + ax2_fused_1_ax2_fused_3_fused_0 * 32 + ax2_fused_2 * 8 + ax0_ax1_fused_0 * 4 + ax0_ax1_fused_1) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) + T.reads(B0_local[v0 // 8, v1], B1_local[v0 // 32, v1]) + T.writes(B_local[v0, v1]) + B_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0_local[v0 // 8, v1], T.Cast("uint32", v0 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1_local[v0 // 32, v1] + for ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(4, 2): + for ax2_fused_1_ax2_fused_3_fused_1_1 in T.vectorized(4): + with T.block("matmul_rf_update"): + vax2_fused_1_ax2_fused_3_fused = T.axis.spatial(32, ax2_fused_1_ax2_fused_3_fused_0 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + ax2_fused_1_ax2_fused_3_fused_1_1) + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax0_1) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1_fused_1) + vax2_fused_0, vax2_fused_2 = T.axis.remap("RR", [ax2_fused_0, ax2_fused_2]) + T.reads(C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1], A_pad_shared[v0, 0, vax2_fused_0 * 128 + vax2_fused_1_ax2_fused_3_fused // 8 * 32 + vax2_fused_2 * 8 + vax2_fused_1_ax2_fused_3_fused % 8], B_local[vax2_fused_0 * 128 + vax2_fused_1_ax2_fused_3_fused // 8 * 32 + vax2_fused_2 * 8 + vax2_fused_1_ax2_fused_3_fused % 8, v1]) + T.writes(C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1]) + C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] = C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused, v0, 0, v1] + A_pad_shared[v0, 0, vax2_fused_0 * 128 + vax2_fused_1_ax2_fused_3_fused // 8 * 32 + vax2_fused_2 * 8 + vax2_fused_1_ax2_fused_3_fused % 8] * B_local[vax2_fused_0 * 128 + vax2_fused_1_ax2_fused_3_fused // 8 * 32 + vax2_fused_2 * 8 + vax2_fused_1_ax2_fused_3_fused % 8, v1] + for ax3 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + for ax2_init in range(4): + with T.block("matmul_rf_init"): + vax2_fused_1_ax2_fused_3_fused_0 = T.axis.spatial(4, ax0) + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2_init) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax3) + T.reads() + T.writes(C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]) + C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = T.float16(0) + for ax2, ax1 in T.grid(4, 8): + with T.block("matmul_rf_update"): + vax2_fused_1_ax2_fused_3_fused_0, vax2_fused_1_ax2_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax2) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax3) + T.reads(C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1], C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 8 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1]) + T.writes(C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]) + C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] = C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + C_pad_rf_local[vax2_fused_1_ax2_fused_3_fused_0 * 8 + vax2_fused_1_ax2_fused_3_fused_1, v0, 0, v1] + for ax1 in range(4): + for ax2 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + with T.block("matmul"): + vax2_fused_1_ax2_fused_3_fused_0 = T.axis.reduce(4, ax0) + v0 = T.axis.spatial((batch_size + 3) // 4 * 4, ax0_0 * 4 + ax1) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax2) + T.reads(C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1]) + T.writes(C_pad_local[v0, 0, v1]) + with T.init(): + C_pad_local[v0, 0, v1] = T.float16(0) + C_pad_local[v0, 0, v1] = C_pad_local[v0, 0, v1] + C_pad_rf_local_1[vax2_fused_1_ax2_fused_3_fused_0, v0, 0, v1] + for ax0 in range(4): + for ax1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("C_pad"): + v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0) + v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1) + T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size) + T.reads(C_pad_local[v0, 0, v1]) + T.writes(C[v0, 0, v1]) + C[v0, 0, v1] = C_pad_local[v0, 0, v1] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("metal"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) # pylint: disable=not-callable + tvm.ir.assert_structural_equal(mod["main"], expected) + + if __name__ == "__main__": tvm.testing.main()