Skip to content

Commit

Permalink
[Compute-inline] Prefer T.where for reverse compute-inlined block wit…
Browse files Browse the repository at this point in the history
…h predicate (#17128)

* prefer T.where for reverse compute-inlined block with predicate

* update ut scripts

---------

Co-authored-by: wrongtest <[email protected]>
  • Loading branch information
wrongtest-intellif and wrongtest authored Jul 5, 2024
1 parent 3e08e70 commit 0fc047c
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 49 deletions.
44 changes: 25 additions & 19 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitStmt_;

/*! \brief Generate the predicate after inlining based on the consumer predicate */
Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) {
// Bind the producer block iter domains for simplification
Map<Var, PrimExpr> subst_map;
Block producer_block = producer_block_realize->block;
for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
const PrimExpr& binding = producer_block_realize->iter_values[i];
subst_map.Set(iter->var, binding);
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
}
if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
Expand All @@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer_.Simplify(predicate);
ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
if (is_one(predicate)) {
return Block(block);
}
if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
if (!StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
return producer_block_realize;
}
if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
if (!if_->else_case.defined()) {
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
if (!StructuralEqual()(predicate, if_predicate)) {
predicate = analyzer_.Simplify(predicate && if_->condition);
producer_block.CopyOnWrite()->body = if_->then_case;
}
}
block->body = IfThenElse(predicate, if_->then_case);
return Block(block);
}
block->body = IfThenElse(predicate, block->body);
return Block(block);
PrimExpr outer_predicate = Substitute(predicate, subst_map);
auto n = producer_block_realize.CopyOnWrite();
n->block = producer_block;
n->predicate = analyzer_.Simplify(outer_predicate);
return GetRef<BlockRealize>(n);
}

Stmt VisitStmt_(const BlockNode* op) final {
Block src_block = GetRef<Block>(op);
Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
if (op == producer_block_) {
tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
block_reuse.Set(src_block, tgt_block);
Stmt VisitStmt_(const BlockRealizeNode* op) final {
Block src_block = op->block;
BlockRealize tgt_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (src_block.get() == producer_block_) {
tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize);
block_reuse.Set(src_block, tgt_block_realize->block);
}
return std::move(tgt_block);
return std::move(tgt_block_realize);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
Expand Down
20 changes: 10 additions & 10 deletions tests/python/dlight/test_gpu_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1, v2])
if v1 < m:
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -200,10 +200,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[0, v1, v2])
if v1 < m:
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on

mod = tvm.IRModule({"main": func})
Expand Down Expand Up @@ -466,10 +466,10 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
if v1 < n:
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]

# fmt: on

Expand Down Expand Up @@ -596,9 +596,9 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.writes(var_T_multiply_intermediate[v1, v2])
if v1 < n:
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))

# fmt: on

Expand Down Expand Up @@ -666,10 +666,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
T.where(ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1, v2])
if v1 < m:
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down
20 changes: 10 additions & 10 deletions tests/python/dlight/test_gpu_matmul_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15)
T.reads(compute_reindex_pad_local[v0, v1, v2])
T.writes(compute[v1, v2])
if v1 < m and v2 < 15:
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -417,11 +417,11 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64),
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n)
T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(p_output0_intermediate[0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
if v1 < n:
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -690,11 +690,11 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m)
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
T.writes(matmul_1[0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
if v1 < m:
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -831,10 +831,10 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0, v1, v2])
T.writes(C[v1, 0, v2])
if v1 < batch_size:
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
# fmt: on


Expand Down Expand Up @@ -971,10 +971,10 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f
v0 = T.axis.spatial(1, ax0_1)
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
T.reads(C_reindex_pad_shared[v0, v1, v2])
T.writes(C[v1, 0, v2])
if v1 < batch_size:
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,11 +856,11 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1
v3 = T.axis.spatial(1, 0)
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16])
T.block_attr({"meta_schedule.cooperative_fetch": 4})
if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 127:
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
# fmt: on

decision_0 = [
Expand Down
Loading

0 comments on commit 0fc047c

Please sign in to comment.