Skip to content

Commit

Permalink
[TIR] Fix get_block_access_region for let bindings (#16553)
Browse files Browse the repository at this point in the history
* [TIR] Fix get_block_access_region for let bindings

The current implementation of `block_access_region_detector` does not
consider the let bindings inside the block. To be more specific:

- The let bindings inside the block can be the index of buffer access
  indices
- The let bindings var is defined inside the block, so the block
  annotation cannot use those vars.
- We need to substitute the let bindings inside the block to the
  block annotation.
  • Loading branch information
Hzfengsy authored Feb 11, 2024
1 parent 0449a16 commit 3fd9bac
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
17 changes: 15 additions & 2 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>

#include "../transforms/ir_utils.h"
namespace tvm {
namespace tir {
Expand Down Expand Up @@ -78,6 +80,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*! \brief let bindings inside the block */
std::unordered_map<const VarNode*, PrimExpr> let_bindings_;
/*!\ brief Internal analyzer. */
arith::Analyzer ana_;

Expand Down Expand Up @@ -111,6 +115,7 @@ class BlockReadWriteDetector : public StmtExprVisitor {
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const CallNode* op) override;
Expand Down Expand Up @@ -149,7 +154,8 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
PrimExpr remapped_index = Substitute(index, let_bindings_);
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
Expand All @@ -176,6 +182,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
}
}

void BlockReadWriteDetector::VisitStmt_(const LetStmtNode* op) {
let_bindings_[op->var.get()] = op->value;
StmtVisitor::VisitStmt_(op);
let_bindings_.erase(op->var.get());
}

void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode* buffer_var = op->args[1].as<VarNode>();
Expand Down Expand Up @@ -225,7 +237,8 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
PrimExpr remapped_index = Substitute(index, let_bindings_);
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.ir import Range
from tvm.script import tir as T


@T.prim_func
Expand Down Expand Up @@ -355,14 +357,33 @@ def test_access_of_decompose_reduction():
tvm.ir.assert_structural_equal(block.writes, ret[1])


def test_buffer_access_with_let_binding():
@T.prim_func
def func(
storage: T.Buffer((16, 16, 16), "float32"),
seq_slot_ids: T.Buffer((16,), "int32"),
history_slot_ids: T.Buffer((16,), "int32"),
output: T.Buffer((16, 16), "float32"),
):
for i, s in T.grid(16, 16):
with T.block("copy"):
vi, vs = T.axis.remap("SS", [i, s])
T.reads(
seq_slot_ids[vi],
history_slot_ids[vi],
storage[seq_slot_ids[vi], history_slot_ids[vi], vs],
)
T.writes(output[vi, vs])
seq_id: T.int32 = seq_slot_ids[vi]
history_id: T.int32 = history_slot_ids[vi]
output[vi, vs] = storage[seq_id, history_id, vs]

block = func.body.block.body.body.body.block
buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])


if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
test_opaque_access()
test_opaque_access_with_tvm_access_ptr()
test_match_buffer()
test_access_in_if_then_else_func()
test_access_in_branch_func()
test_access_of_padding_pattern()
test_access_of_reduction()
test_access_of_decompose_reduction()
tvm.testing.main()

0 comments on commit 3fd9bac

Please sign in to comment.