-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relax] Handle binary operations between Tensor and PrimValue #16827
Changes from 2 commits
687950e
4f487d5
4f3f510
7a6ff0a
efe5323
5295220
b5e2608
16e468e
f097c78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -239,52 +239,91 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call, | |
const Map<String, Array<String>>& desired_layouts, | ||
const VarLayoutMap& var_layout_map); | ||
|
||
/*! | ||
* \brief Get the element dtype from StructInfo | ||
* | ||
* \param sinfo The StructInfo to expect | ||
* \return The inferred element dtype. | ||
* \throw Throw exception if the StructInfo doesn't have an element type. | ||
*/ | ||
inline DataType GetElementDType(const StructInfo& sinfo) { | ||
if (const auto* prim = sinfo.as<PrimStructInfoNode>()) { | ||
return prim->dtype; | ||
} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) { | ||
return tensor->dtype; | ||
} else if (sinfo.as<ObjectStructInfoNode>()) { | ||
return DataType::Void(); | ||
} else { | ||
LOG(FATAL) << "TypeError: " | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Originally, our error message would ask for `TensorStructInfo`. In this particular case, would this error message be less informative than before? Given that this is a global change across all binary ops, it would be good to cross-check the usages here and make the error more informative. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, and this no longer tells the user which operation it was. Updated. |
||
<< "Cannot determine element type of " << sinfo; | ||
} | ||
} | ||
|
||
/*! | ||
* \brief Infer the output datatype for binary arithmetic operators. | ||
* \param call The context Call to the operator. | ||
* \param ctx The error reporting context. | ||
* \param x1_sinfo The struct info of the first operand | ||
* \param x2_sinfo The struct info of the second operand | ||
* \param lhs_sinfo The struct info of the first operand | ||
* \param rhs_sinfo The struct info of the second operand | ||
* \return The inferred output dtype. | ||
* \throw Throw exception if the dtype of two input TensorStructInfo don’t match | ||
*/ | ||
inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, | ||
const TensorStructInfo& x1_sinfo, | ||
const TensorStructInfo& x2_sinfo) { | ||
if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { | ||
const StructInfo& lhs_sinfo, | ||
const StructInfo& rhs_sinfo) { | ||
auto lhs_dtype = GetElementDType(lhs_sinfo); | ||
auto rhs_dtype = GetElementDType(rhs_sinfo); | ||
if (lhs_dtype.is_void() || rhs_dtype.is_void()) { | ||
return DataType::Void(); | ||
} else if (x1_sinfo->dtype != x2_sinfo->dtype) { | ||
} else if (lhs_dtype != rhs_dtype) { | ||
ctx->ReportFatal(Diagnostic::Error(call) | ||
<< "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype | ||
<< " must be equal for binary operators"); | ||
<< "TypeErorr: " | ||
<< "Binary operators must have the same datatype for both operands. " | ||
<< "However, " << call << " uses datatype " << lhs_dtype | ||
<< " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype " | ||
<< rhs_dtype << " on the RHS (StructInfo of " << rhs_sinfo << ")."); | ||
} | ||
return x1_sinfo->dtype; | ||
return lhs_dtype; | ||
} | ||
|
||
/*! | ||
* \brief Infer the output virtual device for binary arithmetic operators. | ||
* \param call The context Call to the operator. | ||
* \param ctx The error reporting context. | ||
* \param x1_sinfo The struct info of the first operand | ||
* \param x2_sinfo The struct info of the second operand | ||
* \param lhs_sinfo The struct info of the first operand | ||
* \param rhs_sinfo The struct info of the second operand | ||
* \return The inferred output vdevice. | ||
* \throw Throw exception if the vdevice of two input TensorStructInfo don’t match | ||
*/ | ||
inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx, | ||
const TensorStructInfo& x1_sinfo, | ||
const TensorStructInfo& x2_sinfo) { | ||
if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) { | ||
return x2_sinfo->vdevice; | ||
const StructInfo& lhs_sinfo, | ||
const StructInfo& rhs_sinfo) { | ||
auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> { | ||
if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) { | ||
return tensor->vdevice; | ||
} else { | ||
return NullOpt; | ||
} | ||
}; | ||
|
||
auto lhs_vdevice = get_vdevice(lhs_sinfo); | ||
auto rhs_vdevice = get_vdevice(rhs_sinfo); | ||
|
||
if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) { | ||
return rhs_vdevice; | ||
} | ||
if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) { | ||
return x1_sinfo->vdevice; | ||
if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) { | ||
return lhs_vdevice; | ||
} | ||
if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) { | ||
if (lhs_vdevice.value() != rhs_vdevice.value()) { | ||
ctx->ReportFatal(Diagnostic::Error(call) | ||
<< "VDevice " << x1_sinfo->vdevice.value() << " and " | ||
<< x2_sinfo->vdevice.value() << " must be equal for binary operators"); | ||
<< "TypeErorr: " | ||
<< "Binary operators with Tensor arguments " | ||
<< "must have the same VDevice for both operands. " | ||
<< "However, " << call << " has a LHS on VDevice " << lhs_vdevice | ||
<< " and a RHS on VDevice " << rhs_vdevice); | ||
} | ||
return x1_sinfo->vdevice; | ||
return lhs_vdevice; | ||
} | ||
|
||
/*! | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this necessarily be expected behavior? An `Object` could be anything, including things that a dtype does not make sense for at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I went back and forth on it. There isn't currently a standard for whether `FInferStructInfo` should raise an error when the arguments are provably invalid, or if it should raise an error when the arguments are not provably valid. On the one hand, `StructInfoLCA` returns `ObjectStructInfo` as the common base class of `TensorStructInfo` and `PrimStructInfo`, so an `ObjectStructInfo` could contain a valid instance of either. On the other hand, the current struct inference requires that the input be validated as `TensorStructInfo`.

Overall, I'm not sure which is the better behavior. For now, I'm updating this PR to explicitly require either `TensorStructInfo` or `PrimStructInfo`, and to raise an exception for `ObjectStructInfo`, since allowing `ObjectStructInfo` would be an independent change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personally, I'm in favor of asking for a MatchCast if we can't draw a conclusion. Down the line, inserting MatchCasts via normalization rules would be a good policy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True. Currently, `FInferStructInfo` is called prior to `FNormalize`, so inference could be inspecting an expression that hasn't yet been normalized. This was useful for providing `FNormalize` for `R.Prim` (if `PrimStructInfo` contains a known value, in-line that value), but I'm wondering if we should re-visit that.