Skip to content

Commit

Permalink
[Relax][Pass] Skip data type node for CSE pass (#16493)
Browse files Browse the repository at this point in the history
* [Relax][Pass] Skip data type node for CSE pass
- The problem is seen when an arg of relax op is dtype

* Add comments to code
  • Loading branch information
abhikran-quic authored Jan 31, 2024
1 parent 06f7810 commit 0628bdb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ class SubexprCounter : public ExprVisitor {
// 4. StringImm nodes (not much benefit from binding to a var)
// 5. Scalar constants (not much benefit from binding to a var)
// 6. Shape expressions (exist to hold several PrimValue objects)
// 7. DataType nodes (no need to modify dtype nodes)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
e->IsInstance<ConstantNode>())) {
e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
int count = 0;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,5 +339,29 @@ def sum(
tvm.ir.assert_structural_equal(Expected, After)


def test_do_not_eliminate_dtype():
@I.ir_module
class Before:
@R.function
def foo() -> R.Tensor((32, 64), "int32"):
obj: R.Object = R.vm.alloc_storage(
R.shape([24576]), runtime_device_index=0, dtype="uint8"
)
a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
)
ret_val: R.Tensor([32, 64], dtype="int32") = R.builtin.alloc_tensor(
R.shape([32, 64]), R.dtype("int32"), R.prim_value(0)
)
_t1: R.Tuple = R.vm.kill_object(a)
_t3: R.Tuple = R.vm.kill_object(obj)
lv: R.Tensor([32, 64], dtype="int32") = ret_val
return lv

Expected = Before

verify(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 0628bdb

Please sign in to comment.