
[Bug] Failing to slice the CatLinearOperator when indexes are negative or when using boolean array #79

Open
MoiseRousseau opened this issue Sep 14, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@MoiseRousseau

🐛 Bug

When slicing a CatLinearOperator with a negative index, the final shape of the slice does not match the expected shape, and an error is raised. This fails at least when concatenating ToeplitzLinearOperator, DiagLinearOperator, or IdentityLinearOperator instances.

To reproduce

Code snippet to reproduce

from linear_operator.operators import IdentityLinearOperator as Ops
from linear_operator.operators import cat as cat_ops

N = 8
base = cat_ops([Ops(N) for _ in range(2)], dim=1)
print(base.shape) #should be 8,16
print(base[:,3:base.shape[-1]-3].shape) #should be 8,10
print(base[:,3:-3].shape) #fail...

Stack trace/error message

torch.Size([8, 16])
torch.Size([8, 10])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    print(base[:,3:-3].shape) #fail...
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: CatLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([8, 10]), got torch.Size([8, 5]). This is a bug with LinearOperator, or your custom LinearOperator.

Expected Behavior

The slice should behave the same way it does with positive indices: base[:,3:-3] should have shape torch.Size([8, 10]), just like base[:,3:base.shape[-1]-3].
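For comparison, here is the dense equivalent in plain PyTorch. This is only a sketch. It uses `torch.eye(N)` in place of `IdentityLinearOperator(N)` and `torch.cat` in place of `cat_ops`, so it does not exercise linear_operator itself. It shows the shape a negative slice is expected to produce:

```python
import torch

N = 8
# Dense stand-in for cat_ops([IdentityLinearOperator(N) for _ in range(2)], dim=1)
dense_base = torch.cat([torch.eye(N) for _ in range(2)], dim=1)

print(dense_base.shape)  # torch.Size([8, 16])
# A negative stop is resolved against the full width (16), so 3:-3 is the same as 3:13
print(dense_base[:, 3:-3].shape)  # torch.Size([8, 10])
print(dense_base[:, 3:dense_base.shape[-1] - 3].shape)  # torch.Size([8, 10])
```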

System information

LinearOperator Version 0.5.2
PyTorch Version 2.0.1
Ubuntu 22.04

MoiseRousseau added the bug (Something isn't working) label Sep 14, 2023
@MoiseRousseau
Author

New finding: I also get a similar error when slicing with a boolean array, even without CatLinearOperator. For example:

from linear_operator.operators import IdentityLinearOperator as Ops

N = 4
cond = [True,False,False,True]
ops = Ops(N)
print(ops.shape)
ops[:,cond]

Which gives:

torch.Size([4, 4])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    ops[:,cond]
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: IdentityLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([4, 4]), got torch.Size([4, 2]). This is a bug with LinearOperator, or your custom LinearOperator.
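For reference, dense PyTorch treats a boolean mask as selecting only the `True` columns. The minimal check below uses `torch.eye(N)` as a stand-in for `IdentityLinearOperator(N)` and converts the mask to a bool tensor. It suggests that the shape actually produced above (`[4, 2]`) is the correct one, and that the expected-shape check (`[4, 4]`) is what disagrees:

```python
import torch

N = 4
mask = torch.tensor([True, False, False, True])
dense = torch.eye(N)  # dense stand-in for IdentityLinearOperator(N)

# Boolean indexing keeps only the columns where the mask is True
selected = dense[:, mask]
print(selected.shape)  # torch.Size([4, 2])
print(int(mask.sum()))  # 2, the number of selected columns
```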

MoiseRousseau changed the title [Bug] Failing to slice the CatLinearOperator when index are negative [Bug] Failing to slice the CatLinearOperator when indexes are negative Sep 14, 2023
MoiseRousseau changed the title [Bug] Failing to slice the CatLinearOperator when indexes are negative [Bug] Failing to slice the CatLinearOperator when indexes are negative or when using boolean array Sep 14, 2023
@Balandat
Collaborator

Looks like there may be a number of places where negative indexing isn't properly supported. I'll put up a fix for the CatLinearOperator case, but this should probably be audited more comprehensively.
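Below is one way a fix for the concatenated case could work. It is sketched in plain Python and assumes the operator knows the sizes of its blocks along the sliced dimension. The idea is to resolve the slice against the *total* length with `slice.indices` first, and only then split it into per-block local slices. `split_slice_across_blocks` is a hypothetical helper, not part of linear_operator, and it handles only unit steps:

```python
from typing import List, Tuple


def split_slice_across_blocks(s: slice, block_sizes: List[int]) -> List[Tuple[int, slice]]:
    """Hypothetical helper: map a slice over the concatenated dimension
    to (block index, local slice) pairs. Only step == 1 is handled."""
    total = sum(block_sizes)
    # slice.indices resolves None and negative bounds against the full length
    start, stop, step = s.indices(total)
    if step != 1:
        raise NotImplementedError("only unit steps are handled in this sketch")
    pieces = []
    offset = 0
    for i, size in enumerate(block_sizes):
        block_start, block_end = offset, offset + size
        # Intersect the global range [start, stop) with this block's range
        lo = max(start, block_start)
        hi = min(stop, block_end)
        if lo < hi:
            pieces.append((i, slice(lo - offset, hi - offset)))
        offset = block_end
    return pieces


block_sizes = [8, 8]  # two IdentityLinearOperator(8) blocks concatenated along dim=1
pieces = split_slice_across_blocks(slice(3, -3), block_sizes)
print(pieces)  # [(0, slice(3, 8, None)), (1, slice(0, 5, None))]
print(sum(len(range(block_sizes[i])[local]) for i, local in pieces))  # 10
```

Normalizing before splitting means a negative bound is never applied to an individual block, which is where a per-block interpretation could go wrong.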

I also don't think we've given much thought to supporting boolean indexing with linear_operator - @gpleiss is that right?

Balandat added a commit to Balandat/linear_operator that referenced this issue Sep 18, 2023
Addresses cornellius-gp#79

Looks like there may be a number of places where negative indexing isn't properly supported. This should probably be audited more comprehensively.
@gpleiss
Member

gpleiss commented Sep 20, 2023

Boolean indexing sounds tricky with linear operators. @MoiseRousseau do you have a good use case?

gpleiss pushed a commit that referenced this issue Sep 20, 2023
Addresses #79

Looks like there may be a number of places where negative indexing isn't properly supported. This should probably be audited more comprehensively.
@MoiseRousseau
Author

I found a workaround: convert the boolean array to integer indices with torch.argwhere(bool_array), then slice using those indices instead. I was just reporting the error. Maybe this could be a way to implement it, even if it is suboptimal?
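A minimal sketch of that mask-to-indices conversion follows. A dense tensor stands in for the operator here. One detail to watch is that `torch.argwhere` on a 1-D mask returns a `(k, 1)` tensor, so it has to be flattened before it can be used as a column index. According to the comment above, the resulting `idx` can then be passed as `ops[:, idx]`:

```python
import torch

mask = torch.tensor([True, False, False, True])
# argwhere returns shape (k, 1) for a 1-D input; squeeze to a flat index tensor
idx = torch.argwhere(mask).squeeze(-1)
print(idx)  # tensor([0, 3])

dense = torch.eye(4)  # dense stand-in for IdentityLinearOperator(4)
# Integer indexing with idx selects the same columns as boolean indexing with mask
assert torch.equal(dense[:, idx], dense[:, mask])
```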
