Skip to content

Commit

Permalink
Introduce outer reduction for metal
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jun 3, 2024
1 parent b87d1f9 commit b43c030
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 422 deletions.
92 changes: 47 additions & 45 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from functools import reduce
from typing import List, Optional, Union

from tvm import DataType, arith, ir, tir
from tvm import arith, ir, tir
from tvm.target import Target

from ..base import (
Expand All @@ -31,6 +31,7 @@
try_inline_contiguous_spatial,
)
from .base import GPUScheduleRule
from .utils import auto_vectorize, get_bytes, get_extent


def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
Expand All @@ -49,17 +50,6 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
return buffer_store.value.b


def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
    """Return the extent of the loop that ``loop_rv`` refers to.

    The loop is looked up with ``sch.get``. If its extent is a
    ``tir.IntImm``, the wrapped ``.value`` is returned; otherwise the
    extent expression itself is returned unchanged.
    """
    extent = sch.get(loop_rv).extent
    if isinstance(extent, tir.IntImm):
        return extent.value
    return extent


def get_bytes(dtype: Union[DataType, str]) -> int:
    """Return the size reported by ``itemsize()`` for ``dtype``.

    ``dtype`` may be a ``DataType`` or a string; a string is first
    converted by passing it to the ``DataType`` constructor.
    """
    data_type = DataType(dtype) if isinstance(dtype, str) else dtype
    return data_type.itemsize()


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
"""Check if the block is a GEMV.
Expand Down Expand Up @@ -207,17 +197,13 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
return None
elif is_inner_reduction:
return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue)
elif target.kind.name == "opencl" and "android" in str(target.host):
else:
ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue)
if ret is None:
return self.sch_outer_reduction_fallback(
sch, target, block, vector_input_buffers, epilogue
)
return sch
else:
return self.sch_outer_reduction_fallback(
sch, target, block, vector_input_buffers, epilogue
)

def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument
self,
Expand Down Expand Up @@ -535,9 +521,11 @@ def apply(

TILE_S, TILE_R = (
1,
len_c
if len_c > 1
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
(
len_c
if len_c > 1
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)
),
)
VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)

Expand Down Expand Up @@ -614,9 +602,9 @@ def apply(
sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
# sch.bind(batch, "blockIdx.z")
sch.bind(bx, "blockIdx.x")
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")
sch.vectorize(vec_c)
sch.bind(ts, TAG_S)
sch.bind(tr, TAG_R)
auto_vectorize(sch, vec_c, VEC_C)

# decompose independent scale read to outer loop
block_rf_stmt = sch.get(rf)
Expand All @@ -635,26 +623,26 @@ def apply(
V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared")
sch.compute_at(V_shared, r, preserve_unit_loops=True)
l = sch.get_loops(block=V_shared)[-1]
_, v_tile, tx, ty, vec = sch.split(
_, v_tile, ts, tr, vec = sch.split(
l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
sch.bind(tr, TAG_R)
sch.bind(ts, TAG_S)
auto_vectorize(sch, vec, LOAD_V_VEC)

# reduce tile_s * tr * vec to tile_s * tr
sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
sch.reorder(ts, tr, vec_c)
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")
sch.bind(ts, TAG_S)
sch.bind(tr, TAG_R)

# reduce tile_s * tr to tile_s
sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
tr, ts = sch.get_loops(block=gemv)[1:]
sch.reorder(ts, tr)
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")
sch.bind(ts, TAG_S)
sch.bind(tr, TAG_R)

sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
Expand All @@ -665,7 +653,7 @@ def apply(
sch.annotate(
block_or_loop=sch.get_loops(rf2)[3],
ann_key="pragma_auto_unroll_max_step",
ann_val=DEC_PACK,
ann_val=UNROLL,
)
sch.annotate(
block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1
Expand All @@ -678,14 +666,14 @@ def apply(
sch.reverse_compute_at(epilogue, bx)
sch.set_scope(block, 0, "shared")
_, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
_, tx = sch.split(sch.fuse(*s), factors=[None, TX])
sch.bind(tx, "threadIdx.x")
_, ts = sch.split(sch.fuse(*s), factors=[None, TS])
sch.bind(ts, TAG_S)
else:
sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
ts_tile_s = sch.get_loops(epilogue)[-1]
ts, _ = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
sch.bind(ts, "threadIdx.x")
sch.bind(ts, TAG_S)
sch.set_scope(block, 0, "local")
return sch

Expand All @@ -698,15 +686,27 @@ def apply(
get_extent(sch, c),
)

TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 1
UNROLL = 4
TS, TR = 64, 4
DEC_PACK = 8
SCALE_PACK = 4
LOAD_V_SHARED = False
LOAD_V_VEC = 4
LOAD_V_TILE = 8

if target.kind.name == "opencl" and "android" in str(target.host):
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 8
UNROLL = 8
TS, TR = 64, 4
LOAD_V_SHARED = False
LOAD_V_VEC = 4
LOAD_V_TILE = 8
elif target.kind.name == "metal":
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 4
UNROLL = 8
TS, TR = 128, 4
LOAD_V_SHARED = False
LOAD_V_VEC = 4
LOAD_V_TILE = 4
else:
return None

if LOAD_V_SHARED is False:
LOAD_V_TILE = 1
Expand All @@ -723,9 +723,11 @@ def apply(

_, TILE_R = (
1,
len_c
if len_c > 1
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
(
len_c
if len_c > 1
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)
),
)
LOAD_V_VEC = min(get_max_factor(TILE_R, [1, 2, 4, 8]), LOAD_V_VEC)
VEC_LOAD = 1
Expand Down
Loading

0 comments on commit b43c030

Please sign in to comment.