[Relax] Handle binary operations between Tensor and PrimValue #16827
Conversation
Prior to this commit, binary operations were only defined between two tensors. This commit allows binary operations to apply between a tensor and a `relax::PrimValue`. When inferring the output `StructInfo`, binary operations with a `PrimValue` produce the same output as using a 0-d tensor. When legalizing operations containing a `PrimValue`, they are lowered to primitive TIR arguments.
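As a quick illustration of what this enables, here is a minimal sketch (not taken from the PR itself; the function name and signature are assumptions) of a Relax function that passes an `R.Prim`-annotated parameter directly to a binary op:

from tvm.script import relax as R

@R.function
def scale(A: R.Tensor(['n'], 'float32'), s: R.Prim('float32')) -> R.Tensor(['n'], 'float32'):
    # During StructInfo inference the PrimValue operand behaves like a 0-d
    # tensor, so the result keeps A's shape and dtype. During legalization
    # it is lowered to a primitive TIR argument of the generated PrimFunc.
    out = R.multiply(A, s)
    return out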
579125e to 687950e
Thank you for pursuing these changes and also making a few refactors that improve readability. I have a couple of concerns listed below about how to handle `Object` types (I think arithmetic ops shouldn't accept them, though admittedly we presently don't have a way to express "I would be fine with either a tensor or prim value").
Is there a particular use case for arithmetic with `PrimValue`s and tensors? I guess it makes sense to be able to pass a `PrimValue` directly to one of these ops without requiring an explicit conversion. I would be a little hesitant to have arithmetic on `PrimValue`s via shape expressions and then also pass them around to Relax arithmetic ops.
src/relax/op/op_common.h
Outdated
} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
  return tensor->dtype;
} else if (sinfo.as<ObjectStructInfoNode>()) {
  return DataType::Void();
Would this necessarily be expected behavior? An `Object` could be anything, including things that `dtype` does not make sense for at all.
Yeah, I went back and forth on it. There isn't currently a standard for whether `FInferStructInfo` should raise an error when the arguments are provably invalid, or if it should raise an error when the arguments are not provably valid. On the one hand, `StructInfoLCA` returns `ObjectStructInfo` as the common base class of `TensorStructInfo` and `PrimStructInfo`, so an `ObjectStructInfo` could contain a valid instance of either. On the other hand, the current struct inference requires that the input be validated as `TensorStructInfo`.
Overall, I'm not sure which is the better behavior. For now, I'm updating this PR to explicitly require either `TensorStructInfo` or `PrimStructInfo`, and to raise an exception for `ObjectStructInfo`, since allowing `ObjectStructInfo` would be an independent change.
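For reference, the LCA relationship mentioned above can be checked from Python. This is a minimal sketch, assuming `StructInfoLCA` is exposed as `tvm.relax.analysis.struct_info_lca`:

from tvm import relax

tensor_sinfo = relax.TensorStructInfo(shape=[16], dtype="float32")
prim_sinfo = relax.PrimStructInfo("float32")

# The least common ancestor of TensorStructInfo and PrimStructInfo is
# ObjectStructInfo, which is why an ObjectStructInfo argument could hold
# a valid instance of either.
lca = relax.analysis.struct_info_lca(tensor_sinfo, prim_sinfo)
print(lca)  # expected: ObjectStructInfo()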
Personally, I'm in favor of asking for a MatchCast if we can't draw a conclusion. Down the line, inserting MatchCasts via normalization rules would be a good policy.
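For concreteness, that suggestion would look roughly like the following at the Relax level. This is a hypothetical sketch (the function and variable names are made up), where a `MatchCast` asserts a concrete struct info before the value reaches the arithmetic op:

from tvm.script import relax as R

@R.function
def f(A: R.Tensor(['n'], 'float32'), x: R.Object) -> R.Tensor(['n'], 'float32'):
    # Assert that the incoming object is actually a float32 PrimValue before
    # using it as an arithmetic operand, rather than having the op accept R.Object.
    x_prim = R.match_cast(x, R.Prim('float32'))
    out = R.add(A, x_prim)
    return out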
True. Currently, `FInferStructInfo` is called prior to `FNormalize`, so inference could be inspecting an expression that hasn't yet been normalized. This was useful for providing `FNormalize` for `R.Prim` (if `PrimStructInfo` contains a known value, in-line that value), but I'm wondering if we should re-visit that.
src/relax/op/tensor/binary.cc
Outdated
} else if (lhs_sinfo.as<ObjectStructInfoNode>() && rhs_sinfo.as<ObjectStructInfoNode>()) {
  return ObjectStructInfo();
}
I'm not sure it's appropriate to accept `Object`s for an arithmetic operation. This relates to your idea about using normalization rules to turn type requirements into explicit checks with `MatchCast`, but I think this would be a case to ask a user to put in a `MatchCast` to assert the types work.
ICHECK(n->dtype.is_int() && n->dtype.is_scalar())
    << "TypeError: Relax only uses scalar integer TIR variables, but gets: " << n;
What was the reason for removing this requirement? Are we using handle-typed vars now?
Good point. I needed to remove the `n->dtype.is_int()` check, as the value could be `R.Prim("float32")`, but the check for `n->dtype.is_scalar()` should be kept.
Primarily for cases where a dynamic computation requires use of a dynamic shape (e.g. RMS norm), or simpler cases, like computing the mean:

@R.function
def mean(A: R.Tensor(['m', 'n'], 'float32')) -> R.Tensor(['m'], 'float32'):
    n = T.int64()
    sum = R.sum(A, axis=1, keepdims=False)
    output = sum / n.astype('float32')  # Allow expressing this step.
    return output
And the PR is now updated to require operands to have either `TensorStructInfo` or `PrimStructInfo`.
Thanks for making the requested changes. I don't see much harm in adding support for such ops, though we should be mindful of the added complexity of having more ways to express equivalent computations. I don't think it's a problem for now.
@tqchen The requested change has been made, and CI is passing. Any other changes that should be made before merging?
} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
  return tensor->dtype;
} else {
  LOG(FATAL) << "TypeError: "
Originally our error message would ask for `TensorStructInfo`. In this particular case, would this error message be less informative than before? Given that this is a global change across all binary ops, it would be good to cross-check the usages here and make the error more informative.
Good point, and this no longer tells the user which operation it was. Updated.
Sorry, I didn't yet have time to do a full look through; I will spend some time on it this weekend.
No problem, and thank you. This isn't a high-priority PR to land, and can certainly wait until after the weekend.
Thanks @Lunderberg, should be good to go after CI.
@tqchen Thank you! I've resolved the unit test whose failure was specific to this PR. There are a few other CI failures, which look like they're triggered by a bug in