support aten._trilinear and improve einsum decomposition #3784

Conversation
…ts a trilinear einstein sum. WIP, it currently builds, but fails at lowering to linalg
Lowers to torch backend, but unable to lower to linalg
…ephen-aten-_trilinear-op
There's a discrepancy between the way the _trilinear and einsum ops handle the second test case (in torch Python). Troubleshooting this discrepancy to try to figure out why/where the two ops differ.
Add more test cases, Add PyTorch _trilinear "bug" to xfail set
…ephen-aten-_trilinear-op
torch.ops.aten._trilinear
Glad you got something working Stephen! The major review points here:
- We need to fail the conversion in the cases we don't support. It's not sufficient to xfail the tests for unsupported cases, because a downstream user of the tool isn't going to run a big model and say "Oh this random shape is messed up, it must be that one esoteric e2e test for this one op that I saw one day". We need to report a match failure so that the op actually doesn't get converted, so we don't have model support people spending days debugging a silently failing conversion.
- Related to 1. What does unroll_dim do? It needs to be included, or if unroll_dim != 0 we also need to report an "unimplemented" match failure.
- Not really major, but glad we don't have to use einsum. I think the einsum changes are generally good, but it might be better to move them into a different patch. I'm fine leaving them in here, but the commit messaging will seem odd if anyone wants to trace back the history.
If it is a genuine bug, let's at least file an issue in pytorch and emit a warning.
I don't think we will need to implement this, but no point in reporting a match failure if the unrollDim is non-constant. Just make a comment somewhere in the conversion that the unrollDim does not change the result of the operation, so we do not use it in the conversion.
Ah, good that it resolves some failing tests. We should rename the title to something like "support aten._trilinear and improve einsum decomposition".
…s not included in sumDim, Add note in func description that `unrollDim` is unused
PyTorch bug filed here. A comment that unrollDim does not impact the output and is unused has been included in the function description.
Updated title of PR to: "support aten._trilinear and improve einsum decomposition"
Looks good!
Tracking Issue: TorchToLinalg Op Support

Description
Aten_TrilinearOp is an implementation of a "trilinear einstein sum": essentially, just an einsum across 3 tensors.

There are a few inputs:
- Tensor inputs: i1, i2, i3
- Expands: expand1, expand2, expand3. These allow you to unsqueeze an input tensor at the specified dims as a pre-processing step to make the shapes compatible for the rest of the op.
- sumdim
- unroll_dim
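To make the semantics concrete, here is a minimal sketch of calling the op directly. The shapes and dim choices are illustrative only, not taken from this PR's tests:

```python
import torch

# Illustrative shapes (not from the PR). Each input is unsqueezed at its expand
# dim so all three broadcast to (4, 5, 6), then dim 0 is summed away.
i1 = torch.randn(4, 5)
i2 = torch.randn(4, 6)
i3 = torch.randn(5, 6)

out = torch.ops.aten._trilinear(i1, i2, i3, [2], [1], [0], [0])
print(out.shape)  # torch.Size([5, 6])
```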
Solution
After trying a bunch of more complicated approaches, this op actually ended up being quite simple; see PyTorch's _trilinear implementation:
_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)).sum(sumdim)
Wish I saw this earlier, but watcha gonna do: 🙃
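As a quick sanity check, a sketch (using the same illustrative shapes as above, not part of the PR) showing that the one-liner matches the native op when each expand list holds a single dim:

```python
import torch

i1, i2, i3 = torch.randn(4, 5), torch.randn(4, 6), torch.randn(5, 6)
expected = torch.ops.aten._trilinear(i1, i2, i3, [2], [1], [0], [0])

# The decomposition quoted above, specialized to single expand dims.
actual = (i1.unsqueeze(2) * i2.unsqueeze(1) * i3.unsqueeze(0)).sum(0)

torch.testing.assert_close(actual, expected)
```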
Not Reusing AtenEinsumOp
Frankly, I found multiple cases where valid inputs would produce numerical mismatches for EinsumOp, even when running tests against EinsumOp directly. I think it has something to do with the singleton dimensions. I will need to look into this further, but once I found the simplified approach, it proved more reliable and much simpler.
Either way (credit to @zjgarvey), there are improvements to the einsum op here. When I was originally trying to use the op, intermediate tensors were being flattened properly, but their 0th dimension was then being cast from a static dim to a dynamic dim because integers were not folding correctly in the MLIR. Figured it's worth keeping these improvements for future reusers of EinsumOp.
The zero'd out dim "bug"
For some reason, if you specify a dimension in all expands, the _trilinear op would report 0 for that dimension in the output shape, unless it was also included in sumdim. This goes against the implementation of torch.einsum and is just straight up incorrect mathematically. I considered "replacing" singleton dims with zeroed out dims, but that seemed like carrying over a bug. Instead, I included a test for the case, verified that the singleton dimensions were handled the way torch.einsum handles them (rather than the way torch._trilinear does), and xfailed it with a note as to why.
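A hypothetical reproducer for the discrepancy described above, where dim 2 appears in every expand list but not in sumdim (the shapes are made up for illustration, and the _trilinear behavior is only as reported in this PR):

```python
import torch

i1 = torch.randn(2, 3)
i2 = torch.randn(2, 3)
i3 = torch.randn(2, 3)

# einsum-style handling keeps the shared singleton dim: shape (3, 1).
einsum_like = (i1.unsqueeze(2) * i2.unsqueeze(2) * i3.unsqueeze(2)).sum(0)

# Per the bug described above, _trilinear reportedly reports size 0 for that
# dim instead (e.g. shape (3, 0)), unless dim 2 is also listed in sumdim.
trilinear_out = torch.ops.aten._trilinear(i1, i2, i3, [2], [2], [2], [0])

print(einsum_like.shape, trilinear_out.shape)
```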