Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Express dynamic arguments of strided_slice as arguments #16826

Merged

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, relax.op.strided_slice stored the axes, begin, end, and strides in the CallNode::attrs. However, the attributes are only intended to store static values. The indices used used for relax.op.strided_slice must frequently be in terms of symbolic shape variables, which should not be stored in the attributes. While some utilities have special handling for relax.op.strided_slice (e.g. tvm::relax::Bind), many do not (e.g. tvm::relax::WellFormed and
tvm::relax::FreeSymbolicVars). As a result, the symbolic expressions in relax.op.strided_slice will fail to be updated in generic utilities, and will fail to trigger safeguards when this occurs.

This commit changes the representation of relax.op.strided_slice to store all arguments in the relax::CallNode::args, rather than the relax::CallNode::attrs. As mentioned in a comment from #13987, which initially implemented relax.op.strided_slice, this was an intended refactor once relax::PrimValue was fully supported.

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good, we've discussed the topic before and I'm glad to see greater expressiveness via PrimValues where it's appropriate (as opposed to requiring things to be fixed in attributes).

Comment on lines 162 to 163
// PrimExpr bounds_offset = tvm::if_then_else(stride < 0, -1, 0);
// index = tvm::min(tvm::max(index, 0 + bounds_offset), extent + bounds_offset);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should just be deleted, I assume

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and deleted!

Comment on lines +305 to +304
// TODO(Lunderberg): Implement this check using `IsBaseOf`. Doing
// so will require a way to represent a `relax::TupleStructInfo` of
// unknown length, where each element has the same `StructInfo`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I think the idea of having a list type has come up before. I think we've had lists via Objects and PackedFuncs before, probably for gradient. Reifying it in the type system could be feasible if it's a common enough use case.

@Lunderberg
Copy link
Contributor Author

@Archermmt Can you look into the failing MSC unit tests? It seems that it is a limitation in the Relax to MSC conversion, that it doesn't correctly handle in-line leaf nodes that occur within a relax::Tuple. For now, I've marked the failing unit tests with @pytest.mark.xfail, but it would be good to resolve the breakages either in this or a follow-up PR.

Prior to this commit, `relax.op.strided_slice` stored the `axes`,
`begin`, `end`, and `strides` in the `CallNode::attrs`.  However, the
attributes are only intended to store static values.  The indices used
used for `relax.op.strided_slice` must frequently be in terms of
symbolic shape variables, which should not be stored in the
attributes.  While some utilities have special handling for
`relax.op.strided_slice` (e.g. `tvm::relax::Bind`), many do
not (e.g. `tvm::relax::WellFormed` and
`tvm::relax::FreeSymbolicVars`).  As a result, the symbolic
expressions in `relax.op.strided_slice` will fail to be updated in
generic utilities, and will fail to trigger safeguards when this
occurs.

This commit changes the representation of `relax.op.strided_slice` to
store all arguments in the `relax::CallNode::args`, rather than the
`relax::CallNode::attrs`.  As mentioned in a comment from
apache#13987, which initially implemented
`relax.op.strided_slice`, this was an intended refactor once
`relax::PrimValue` was fully supported.
@Lunderberg Lunderberg force-pushed the relax_use_primvalue_for_strided_slice branch from 8e4dd20 to f596467 Compare April 16, 2024 21:28
@Lunderberg
Copy link
Contributor Author

Merged from main into PR branch to resolve conflict.

@Lunderberg
Copy link
Contributor Author

CI is passing, and this PR is approved. I'm going to give one final merge from main into this PR, as there have been a couple of recent changes in main that may impact the test cases, and then merging.

@Lunderberg
Copy link
Contributor Author

CI passed with main, and a quick smoke test with MLC-LLM passed as well, so it's time to merge.

@Lunderberg Lunderberg merged commit 20d7696 into apache:main May 1, 2024
18 checks passed
@Lunderberg Lunderberg deleted the relax_use_primvalue_for_strided_slice branch May 1, 2024 14:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants