
[Relax] Handle binary operations between Tensor and PrimValue #16827

Merged

Conversation

Lunderberg
Contributor

Prior to this commit, binary operations were only defined between two tensors. This commit allows binary operations to apply between a tensor and a relax::PrimValue.

When inferring the output StructInfo, binary operations with a PrimValue produce the same output as using a 0-d tensor. When legalizing operations containing a PrimValue, they are lowered to primitive TIR arguments.
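As a rough, illustrative sketch (not code taken from this PR; the function and argument names are made up), this allows a Relax function to combine a tensor operand with an R.Prim operand directly:

from tvm.script import relax as R

@R.function
def scale(x: R.Tensor((16,), "float32"), s: R.Prim("float32")) -> R.Tensor((16,), "float32"):
    # Before this change, `s` would first need to be wrapped in a 0-d tensor.
    # The inferred StructInfo of `y` is the same as if `s` were a 0-d tensor,
    # and legalization passes `s` through as a primitive TIR argument.
    y = x * s
    return y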

@Lunderberg Lunderberg force-pushed the relax_binary_operations_with_primvalue branch from 579125e to 687950e on April 1, 2024 at 12:43
Contributor

@slyubomirsky slyubomirsky left a comment


Thank you for pursuing these changes and also making a few refactors that improve readability. I have a couple of concerns listed below about how to handle Object types (I think arithmetic ops shouldn't accept them, though admittedly we presently don't have a way to express, "I would be fine with either a tensor or prim value").

Is there a particular use case for arithmetic with PrimValues and tensors? I guess it makes sense to be able to pass a PrimValue directly to one of these ops without requiring an explicit conversion. I would be a little hesitant to have arithmetic on PrimValues via shape expressions and then also pass them around to Relax arithmetic ops.

} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->dtype;
} else if (sinfo.as<ObjectStructInfoNode>()) {
return DataType::Void();
Contributor


Would this necessarily be expected behavior? An Object could be anything, including things for which a dtype makes no sense at all.

Contributor Author


Yeah, I went back and forth on it. There isn't currently a standard for whether FInferStructInfo should raise an error when the arguments are provably invalid, or if it should raise an error when the arguments are not provably valid. On the one hand, StructInfoLCA returns ObjectStructInfo as the common base class of TensorStructInfo and PrimStructInfo, so an ObjectStructInfo could contain a valid instance of either. On the other hand, the current struct inference requires that the input be validated as TensorStructInfo.

Overall, I'm not sure which is the better behavior. For now, I'm updating this PR to explicitly require either TensorStructInfo or PrimStructInfo, and to raise an exception for ObjectStructInfo, since allowing ObjectStructInfo would be an independent change.

Contributor


Personally, I'm in favor of asking for a MatchCast if we can't draw a conclusion. Down the line, inserting MatchCasts via normalization rules would be a good policy.
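For reference, a minimal sketch of the MatchCast approach (the names and shapes below are illustrative, not from this PR): the user asserts the expected StructInfo before handing the value to an arithmetic op.

from tvm.script import relax as R

@R.function
def add_one(x: R.Object) -> R.Tensor((16,), "float32"):
    # R.match_cast checks the annotation at runtime and refines the struct
    # info, so the following op sees a TensorStructInfo operand rather than
    # an ObjectStructInfo one.
    x_t = R.match_cast(x, R.Tensor((16,), "float32"))
    y = R.add(x_t, R.const(1.0, "float32"))
    return y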

Contributor Author


True. Currently, FInferStructInfo is called prior to FNormalize, so inference could be inspecting an expression that hasn't yet been normalized. This was useful for providing FNormalize for R.Prim (if PrimStructInfo contains a known value, in-line that value), but I'm wondering if we should re-visit that.

Comment on lines 50 to 52
} else if (lhs_sinfo.as<ObjectStructInfoNode>() && rhs_sinfo.as<ObjectStructInfoNode>()) {
return ObjectStructInfo();
}
Contributor


I'm not sure it's appropriate to accept Objects for an arithmetic operation. This relates to your idea about using normalization rules to turn type requirements into explicit checks with MatchCast, but I think this would be a case to ask a user to put in a MatchCast to assert the types work.

Comment on lines -44 to -46
ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses "
"scalar integer TIR variables, but gets: "
<< n;
Contributor


What was the reason for removing this requirement? Are we using handle-typed vars now?

Contributor Author


Good point. I needed to remove the n->dtype.is_int() check, as the value could be R.Prim("float32"), but the check for n->dtype.is_scalar() should be kept.

@Lunderberg
Contributor Author

Is there a particular use case for arithmetic with PrimValues and tensors?

Primarily for cases where a dynamic computation requires use of a dynamic shape (e.g. RMS_norm), or for simpler cases like computing the mean:

@R.function
def mean(A: R.Tensor(['m', 'n'], 'float32')) -> R.Tensor(['m'], 'float32'):
    n = T.int64()
    sum = R.sum(A, axis=1, keepdims=False)
    output = sum / n.astype('float32')  # Allow expressing this step.
    return output

@Lunderberg
Contributor Author

And the PR is now updated to require operands to have either TensorStructInfo or PrimStructInfo.

Contributor

@slyubomirsky slyubomirsky left a comment


Thanks for making the requested changes. I don't see much harm in adding support for such ops, though we should be mindful of the added complexity of having more ways to express equivalent computations. I don't think it's a problem for now.

tests/python/relax/test_op_binary.py (outdated review thread, resolved)
@Lunderberg
Contributor Author

@tqchen The requested change has been made, and CI is passing. Any other changes that should be made before merging?

} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->dtype;
} else {
LOG(FATAL) << "TypeError: "
Member


Originally our error message would ask for TensorStructInfo. In this particular case, would this error message be less informative than before? Given that this is a global change across all binary ops, it would be good to cross-check the usages here and make the error more informative.

Contributor Author


Good point, and the error message no longer tells the user which operation it was. Updated.

@tqchen
Member

tqchen commented Apr 5, 2024

Sorry, I didn't yet have time to do a full look-through; I will spend some time on it this weekend.

@Lunderberg
Contributor Author

No problem, and thank you. This isn't a high-priority PR to land, and can certainly wait until after the weekend.

@tqchen
Member

tqchen commented Apr 17, 2024

Thanks @Lunderberg, should be good to go after CI.

@Lunderberg
Contributor Author

@tqchen Thank you! I've resolved the unit test whose failure was specific to this PR.

There are a few other CI failures, which look like they're triggered by a bug in tvm.device('cuda').exist: if no GPUs are present, it raises an exception when it should return False. This is an independent failure mode, and I've submitted #16903, which resolves it.
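For reference, a minimal illustration of the expected behavior (this assumes a machine without a CUDA device):

import tvm

# With the fix, this should print False on a CPU-only machine rather than raise.
print(tvm.device("cuda").exist)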

@Lunderberg Lunderberg merged commit 622bd15 into apache:main Apr 18, 2024
18 checks passed
@Lunderberg Lunderberg deleted the relax_binary_operations_with_primvalue branch April 18, 2024 21:42