-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relax] Express dynamic arguments of strided_slice as arguments #16826
[Relax] Express dynamic arguments of strided_slice as arguments #16826
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good, we've discussed the topic before and I'm glad to see greater expressiveness via PrimValues where it's appropriate (as opposed to requiring things to be fixed in attributes).
src/relax/op/tensor/index.cc
Outdated
// PrimExpr bounds_offset = tvm::if_then_else(stride < 0, -1, 0); | ||
// index = tvm::min(tvm::max(index, 0 + bounds_offset), extent + bounds_offset); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should just be deleted, I assume
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and deleted!
// TODO(Lunderberg): Implement this check using `IsBaseOf`. Doing | ||
// so will require a way to represent a `relax::TupleStructInfo` of | ||
// unknown length, where each element has the same `StructInfo`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I think the idea of having a list type has come up before. I think we've had lists via Object
s and PackedFunc
s before, probably for gradient. Reifying it in the type system could be feasible if it's a common enough use case.
@Archermmt Can you look into the failing MSC unit tests? It seems that it is a limitation in the Relax to MSC conversion, that it doesn't correctly handle in-line leaf nodes that occur within a |
Prior to this commit, `relax.op.strided_slice` stored the `axes`, `begin`, `end`, and `strides` in the `CallNode::attrs`. However, the attributes are only intended to store static values. The indices used used for `relax.op.strided_slice` must frequently be in terms of symbolic shape variables, which should not be stored in the attributes. While some utilities have special handling for `relax.op.strided_slice` (e.g. `tvm::relax::Bind`), many do not (e.g. `tvm::relax::WellFormed` and `tvm::relax::FreeSymbolicVars`). As a result, the symbolic expressions in `relax.op.strided_slice` will fail to be updated in generic utilities, and will fail to trigger safeguards when this occurs. This commit changes the representation of `relax.op.strided_slice` to store all arguments in the `relax::CallNode::args`, rather than the `relax::CallNode::attrs`. As mentioned in a comment from apache#13987, which initially implemented `relax.op.strided_slice`, this was an intended refactor once `relax::PrimValue` was fully supported.
8e4dd20
to
f596467
Compare
Merged from main into PR branch to resolve conflict. |
CI is passing, and this PR is approved. I'm going to give one final merge from main into this PR, as there have been a couple of recent changes in |
CI passed with main, and a quick smoke test with MLC-LLM passed as well, so it's time to merge. |
Prior to this commit,
relax.op.strided_slice
stored theaxes
,begin
,end
, andstrides
in theCallNode::attrs
. However, the attributes are only intended to store static values. The indices used used forrelax.op.strided_slice
must frequently be in terms of symbolic shape variables, which should not be stored in the attributes. While some utilities have special handling forrelax.op.strided_slice
(e.g.tvm::relax::Bind
), many do not (e.g.tvm::relax::WellFormed
andtvm::relax::FreeSymbolicVars
). As a result, the symbolic expressions inrelax.op.strided_slice
will fail to be updated in generic utilities, and will fail to trigger safeguards when this occurs.This commit changes the representation of
relax.op.strided_slice
to store all arguments in therelax::CallNode::args
, rather than therelax::CallNode::attrs
. As mentioned in a comment from #13987, which initially implementedrelax.op.strided_slice
, this was an intended refactor oncerelax::PrimValue
was fully supported.