Skip to content

Commit

Permalink
[TOPI] Remove blockIdx.z in topi sort (#16977)
Browse files Browse the repository at this point in the history
As `blockIdx.z` is not allowed in WebGPU, this PR folds the former `blockIdx.z`
extent into `blockIdx.y` (recovering both indices with `%` and `//`) to support WebGPU
  • Loading branch information
Hzfengsy authored May 10, 2024
1 parent 2565aa3 commit 825dc1f
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,16 @@ def traverse(op):
return s


def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)

by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)

return tx, bx, by, bz
return tx, bx, by


def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
Expand All @@ -87,13 +85,13 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after

# Copy the keys_in to initial output
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
by, bz = by % axis_mul_before, by // axis_mul_before
idx = (by * shape[axis] + tid) * axis_mul_after + bz
with ib.if_scope(tid < shape[axis]):
keys_out[idx] = keys_in[idx]
Expand Down Expand Up @@ -122,11 +120,11 @@ def _odd_even_sort(
):
nthread_tx = block_size // 2
nthread_bx = ceil_div(size, block_size)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after
with ib.new_scope():
ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
by, bz = by % axis_mul_before, by // axis_mul_before
tid = 2 * tx
start = bx * block_size

Expand Down Expand Up @@ -222,7 +220,6 @@ def _sort_common(

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_by = axis_mul_before * axis_mul_after
nthread_bz = 1
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)

Expand Down Expand Up @@ -334,12 +331,13 @@ def assign_j():
ntx = max_threads
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
else:
ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
tx, bx, by = _get_threads(ib, ntx, nbx, nthread_by * nbz)
by, bz = by % nthread_by, by // nthread_by

def mergepath(
source,
Expand Down Expand Up @@ -471,18 +469,17 @@ def do_merge(first, last):
width,
tvm.tir.indexmod(l2_width, 2) == 0,
)
nthread_by = axis_mul_before
nthread_bz = axis_mul_after
nthread_by = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = ceil_div(size, nthread_tx)
## if the final sorted data ended up in the swap, copy it to the real output
with ib.if_scope(
tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
):
with ib.new_scope():
tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
tx, bx, by = _get_threads(ib, nthread_tx, nthread_bx, nthread_by)
tid = bx * nthread_tx + tx
idx = (by * axis_mul_after + bz) * size + tid
idx = by * size + tid
with ib.if_scope(tid < size):
keys[idx] = keys_swap[idx]
if values is not None:
Expand Down

0 comments on commit 825dc1f

Please sign in to comment.