torch.aten._trilinear #848

Open
Tracked by #347
vivekkhandelwal1 opened this issue Sep 26, 2024 · 0 comments

vivekkhandelwal1 commented Sep 26, 2024

Assigned to @stbaione

zjgarvey pushed a commit to llvm/torch-mlir that referenced this issue Oct 31, 2024
# Tracking
[Issue](nod-ai/SHARK-ModelDev#848)
[TorchToLinalg Op
Support](nod-ai/SHARK-ModelDev#347)

# Description

Aten_TrilinearOp implements a "trilinear Einstein sum": essentially, an
einsum across three tensors.

There are a few inputs:
## Tensor Inputs
- i1, i2, i3 - The three input tensors for the _trilinear op.
## Expands 
These inputs allow you to unsqueeze an input tensor at the specified
dims as a pre-processing step to make the shapes compatible for the rest
of the op:
- expand1: List[int], expand2: List[int], expand3: List[int]

## sumdim
- sumdim: List[int] - After the element-wise multiplication, the values
in sumdim denote which dimensions to collapse by summing over them.

## unroll_dim
- unroll_dim: int - In the PyTorch implementation, this specifies a
dimension along which you could slice the input tensors, multiply and
sum the slices, then concatenate the results into an output tensor. This
complicates the implementation significantly but doesn't change the
result, so I opted against it. In addition, a previously accepted path
for solving this involved reusing the AtenEinsumOp, which would also
ignore this input.
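
To illustrate why ignoring `unroll_dim` is safe, here is a minimal sketch in plain PyTorch. It is not the ATen implementation: `unsqueeze_all`, `trilinear_direct`, and `trilinear_unrolled` are hypothetical helpers written for this note, and the shapes are chosen to resemble a bilinear form. The sketch slices along `unroll_dim`, multiplies and sums each slice, then recombines the pieces (accumulating when `unroll_dim` is itself summed over, concatenating otherwise) and compares against the direct computation.

```python
import torch

def unsqueeze_all(t, dims):
    # Insert singleton dims in ascending order so each index refers to a
    # position in the final (expanded) shape.
    for d in sorted(dims):
        t = t.unsqueeze(d)
    return t

def trilinear_direct(i1, i2, i3, e1, e2, e3, sumdim):
    prod = unsqueeze_all(i1, e1) * unsqueeze_all(i2, e2) * unsqueeze_all(i3, e3)
    return prod.sum(dim=sumdim)

def trilinear_unrolled(i1, i2, i3, e1, e2, e3, sumdim, unroll_dim):
    # Assumes sumdim is non-empty; this is an illustration, not ATen's code.
    a, b, c = unsqueeze_all(i1, e1), unsqueeze_all(i2, e2), unsqueeze_all(i3, e3)
    size = max(t.size(unroll_dim) for t in (a, b, c))

    def take(t, k):
        # Broadcast (size-1) inputs are reused as-is; others are sliced.
        return t if t.size(unroll_dim) == 1 else t.narrow(unroll_dim, k, 1)

    slices = [
        (take(a, k) * take(b, k) * take(c, k)).sum(dim=sumdim, keepdim=True)
        for k in range(size)
    ]
    if unroll_dim in sumdim:
        out = torch.stack(slices).sum(dim=0)     # accumulate partial sums
    else:
        out = torch.cat(slices, dim=unroll_dim)  # stitch slices back together
    for d in sorted(sumdim, reverse=True):       # drop the kept summed dims
        out = out.squeeze(d)
    return out

torch.manual_seed(0)
x1 = torch.rand(2, 3, dtype=torch.float64)
w = torch.rand(5, 3, 4, dtype=torch.float64)
x2 = torch.rand(2, 4, dtype=torch.float64)
args = (x1, w, x2, [1, 3], [0], [1, 2], [2, 3])

direct = trilinear_direct(*args)
results = [torch.allclose(trilinear_unrolled(*args, d), direct) for d in range(4)]
print(results)
```

In this sketch every choice of `unroll_dim` reproduces the direct result up to floating-point tolerance, which matches the observation that the argument affects how the work is scheduled rather than what is computed.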


# Solution

After I tried a bunch of more complicated approaches, this op
actually ended up being quite simple: [See
_trilinear](https://dev-discuss.pytorch.org/t/defining-the-core-aten-opset/1464)

`_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) *
i3.unsqueeze(expand3)).sum(sumdim)`

Wish I saw this earlier, but watcha gonna do: 🙃
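
The one-liner above is shorthand, since `Tensor.unsqueeze` takes a single dimension rather than a list. A minimal runnable sketch of the same idea in eager PyTorch follows. `unsqueeze_all` and `trilinear_reference` are hypothetical names, not part of torch or torch-mlir, and the argument values are an assumption chosen so the result can be checked against an equivalent `torch.einsum`.

```python
import torch

def unsqueeze_all(t, dims):
    # Apply unsqueeze once per entry, in ascending order, so each index
    # refers to a position in the final (expanded) shape.
    for d in sorted(dims):
        t = t.unsqueeze(d)
    return t

def trilinear_reference(i1, i2, i3, expand1, expand2, expand3, sumdim):
    # Broadcasted element-wise product of the three expanded inputs.
    prod = (
        unsqueeze_all(i1, expand1)
        * unsqueeze_all(i2, expand2)
        * unsqueeze_all(i3, expand3)
    )
    # Collapse the requested dimensions; nothing to do for an empty sumdim.
    return prod.sum(dim=sumdim) if sumdim else prod

torch.manual_seed(0)
x1 = torch.rand(2, 3, dtype=torch.float64)    # (N, in1)
w = torch.rand(5, 3, 4, dtype=torch.float64)  # (out, in1, in2)
x2 = torch.rand(2, 4, dtype=torch.float64)    # (N, in2)

# Expanded shapes: (N,1,in1,1) * (1,out,in1,in2) * (N,1,1,in2), summed over in1, in2.
out = trilinear_reference(x1, w, x2, [1, 3], [0], [1, 2], [2, 3])
expected = torch.einsum("ni,oij,nj->no", x1, w, x2)

print(out.shape)
print(torch.allclose(out, expected))
```

The comparison against `torch.einsum` is only a sanity check for these particular shapes; it says nothing about the numerical mismatches seen with the MLIR EinsumOp lowering.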

## Not Reusing AtenEinsumOp
Frankly, I found multiple cases where valid inputs would have numerical
mismatches for EinsumOp, even when running tests against EinsumOp
directly. I think it has something to do with the singleton dimensions.
Will need to look into this further, but once I realized the simplified
approach, it appeared to be more reliable and much simpler.

Either way (credit to @zjgarvey), there are improvements to the einsum
op here. When I was originally trying to use the op, intermediate
tensors were being flattened properly, but then their 0th dimension was
being cast from a static dim to a dynamic dim because integers were not
folding correctly in the MLIR. Figured it's worth keeping these
improvements for future reusers of EinsumOp.

# The zero'd out dim "bug"

For some reason, if you specify a dimension in all `expands`, i.e.

```
[expand1=[0], expand2=[0], expand3=[0]],
[expand1=[1], expand2=[1], expand3=[1]]
```

the _trilinear op would specify `0` for that dimension in the output
shape, unless it was also included in `sumdim`. This goes against the
behavior of torch.einsum:

```
>>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)] # Simulate expand at dim=0 for all input tensors
>>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape
torch.Size([1, 3, 3, 3])
```

It is also just straight up incorrect mathematically. I considered
"replacing" singleton dims with zeroed out dims, but that seemed like
carrying over a bug. Instead, I included a test for the case, verified
that the singleton dimensions were handled the way that torch.einsum
handles them, rather than the way torch._trilinear does, and xfailed it
with a note as to why.
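
As a rough illustration of the expected shape behavior, the sketch below applies the unsqueeze-multiply-sum formulation directly with the same dimension in every expand list. `unsqueeze_all` is a hypothetical helper, and the snippet deliberately does not call torch._trilinear itself; it only shows what ordinary broadcasting produces.

```python
import torch

def unsqueeze_all(t, dims):
    # Insert singleton dims in ascending order of index.
    for d in sorted(dims):
        t = t.unsqueeze(d)
    return t

a, b, c = (torch.rand(3, 3, 3) for _ in range(3))

# Same dim in expand1, expand2, and expand3: every input gains a size-1 dim 0.
prod = unsqueeze_all(a, [0]) * unsqueeze_all(b, [0]) * unsqueeze_all(c, [0])
print(prod.shape)   # torch.Size([1, 3, 3, 3])

# Summing over that dim removes it, as with any other reduced dimension.
summed = prod.sum(dim=[0])
print(summed.shape)  # torch.Size([3, 3, 3])
```

Under plain broadcasting the shared singleton dimension stays at size 1 rather than becoming 0, which is the torch.einsum-style behavior the test checks for.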