Skip to content

Commit

Permalink
[TIR] Fix the thread binding iter_var dtype in Bind primitive (#16074)
Browse files Browse the repository at this point in the history
As a follow-up to #16041, this PR fixes the dtype of the iter_var generated
by the schedule primitive `bind`. The iter_var now has the same dtype as
the loop_var.

Note that this PR changes the internal (TIR-level) interface of the
`bind` primitive, but it does not change the user-facing interface (the
Python side and concrete_schedule.cc).
  • Loading branch information
Hzfengsy authored Nov 5, 2023
1 parent b144145 commit 1de5aa5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 10 deletions.
4 changes: 1 addition & 3 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,7 @@ void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis
"`vthread.x`, `vthread.y` and `vthread.z` instead";
}
TVM_TIR_SCHEDULE_BEGIN();
tir::Bind(state_, this->GetSRef(loop_rv),
IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex,
/*thread_tag=*/thread_axis));
tir::Bind(state_, this->GetSRef(loop_rv), thread_axis);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("bind", this->error_render_level_);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref);
* \param loop_sref The sref of the loop to be bound to the thread axis
* \param thread_axis The thread axis to be bound to the loop
*/
TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis);
TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis);
/*!
* \brief Unroll the input loop. It requires nothing
* \param self The state of the schedule
Expand Down
19 changes: 13 additions & 6 deletions src/tir/schedule/primitive/for_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind
* `for_kind` is `kThreadBinding`
*/
void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind,
Optional<IterVar> thread_axis) {
Optional<String> thread_axis) {
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);

/*
Expand All @@ -164,14 +164,21 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
CheckParallelizability(self, GetRef<For>(loop), for_kind,
thread_axis.defined()
? runtime::ThreadScope::Create(thread_axis.value()->thread_tag)
: runtime::ThreadScope{-1, -1});
thread_axis.defined() ? runtime::ThreadScope::Create(thread_axis.value())
: runtime::ThreadScope{-1, -1});

// Step 3. Loop update and IR replacement
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
new_loop->kind = for_kind;
new_loop->thread_binding = std::move(thread_axis);
if (thread_axis.defined()) {
  // Build the thread-binding IterVar here (rather than receiving a pre-built one) so that its
  // var is created with the dtype of the loop variable being bound.
  // Read the tag once and reuse it for both the var name hint and the thread tag.
  const String& thread_tag = thread_axis.value();
  new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr),                           //
                                     /*var=*/Var(thread_tag, loop->loop_var.dtype()),  //
                                     /*iter_type=*/kThreadIndex,                       //
                                     /*thread_tag=*/thread_tag);
} else {
  // No thread axis was supplied, so the loop carries no binding.
  new_loop->thread_binding = NullOpt;
}
self->Replace(loop_sref, For(new_loop), {});
}

Expand All @@ -183,7 +190,7 @@ void Vectorize(ScheduleState self, const StmtSRef& loop_sref) {
ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt);
}

void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) {
void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis) {
ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis);
}

Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_schedule_for_kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,5 +668,34 @@ def test_scatter_parallelize():
verify_trace_roundtrip(s, mod=scatter_compute)


def test_bind_thread_iter_var_dtype():
    """Bind the outer int64 loop to `threadIdx.x` and compare the result against `expected`."""

    # Input workload: a 128x128 elementwise op whose loop extents are int64.
    @T.prim_func(private=True)
    def before(
        A: T.Buffer((T.int64(128), T.int64(128))),
        B: T.Buffer((T.int64(128), T.int64(128))),
    ) -> None:
        for i, j in T.grid(T.int64(128), T.int64(128)):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    # Reference result: the outer loop becomes a thread binding with an int64 extent.
    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((T.int64(128), T.int64(128))),
        B: T.Buffer((T.int64(128), T.int64(128))),
    ) -> None:
        for tx in T.thread_binding(T.int64(128), thread="threadIdx.x"):
            for j in range(T.int64(128)):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [tx, j])
                    B[vi, vj] = A[vi, vj] * 2.0

    sch = tir.Schedule(before, debug_mask="all")
    block_b = sch.get_block("B")
    outer_loop, _ = sch.get_loops(block_b)
    sch.bind(outer_loop, "threadIdx.x")
    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)
    verify_trace_roundtrip(sch, mod=before)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 1de5aa5

Please sign in to comment.