[Bug] object has no attribute _differentiable_kwargs #101

Open
jaghili opened this issue Sep 13, 2024 · 0 comments
Labels
bug Something isn't working

🐛 Bug

  • Install torch==2.0.1
  • Install linear_operator 0.5.3 with pip

To reproduce

I took this snippet from the README:

import linear_operator
import torch

class DiagLinearOperator(linear_operator.LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector that defines the diagonal of the matrix
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # Diagonal matrices are symmetric

    # this function is optional, but it will accelerate computation
    def logdet(self):
        return self.diag.log().sum(dim=-1)
# ...

D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
#   [[1., 0., 0.],
#    [0., 2., 0.],
#    [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]

Stack trace/error message

Traceback (most recent call last):
  File "/home/jagh/codes/ng/src/a.py", line 31, in <module>
    torch.matmul(D, torch.tensor([4., 5., 6.]))
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2970, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1839, in matmul
    return Matmul.apply(self.representation_tree(), other, *self.representation())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2072, in representation_tree
    return LinearOperatorRepresentationTree(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py", line 8, in __init__
    self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys()
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DiagLinearOperator' object has no attribute '_differentiable_kwargs'

Expected Behavior

Snippet should return [4., 10., 18.]
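
As a sanity check on the expected value, multiplying a diagonal matrix by a vector reduces to an elementwise product of the diagonal with that vector. The sketch below uses plain Python lists rather than torch tensors, so it only illustrates the arithmetic and not the library's behavior:

```python
# Diagonal entries and the vector from the reproduction snippet.
diag = [1.0, 2.0, 3.0]
v = [4.0, 5.0, 6.0]

# diag(d) @ v is just d_i * v_i for each index i.
expected = [d * x for d, x in zip(diag, v)]
print(expected)  # [4.0, 10.0, 18.0]
```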

Additional context

I added self._differentiable_kwargs = { some dict }, which seems to bypass the problem, but then I get another error about self._nondifferentiable_kwargs, which I don't know how to set up. Did I miss something?
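
One possible explanation, offered as a guess rather than a confirmed diagnosis: the `__init__` in the snippet never calls the parent constructor, so any attributes the base class sets up there are missing. The sketch below reproduces that mechanism with a hypothetical stand-in base class (`BaseOperator` is not the real `linear_operator.LinearOperator`, and its attributes are only modeled on the names that appear in the traceback):

```python
from collections import OrderedDict


class BaseOperator:
    """Hypothetical stand-in for a base class that does bookkeeping in __init__."""

    def __init__(self, *args, **kwargs):
        self._args = args
        self._differentiable_kwargs = OrderedDict()
        self._nondifferentiable_kwargs = dict()

    def representation_tree(self):
        # Reads an attribute that only exists if the base __init__ ran,
        # similar in spirit to the failing line in the traceback.
        return list(self._differentiable_kwargs.keys())


class BrokenDiag(BaseOperator):
    def __init__(self, diag):
        # The base __init__ is never called, so the bookkeeping is skipped.
        self.diag = diag


class FixedDiag(BaseOperator):
    def __init__(self, diag):
        # Let the base class set up its own attributes first.
        super().__init__(diag)
        self.diag = diag


try:
    BrokenDiag([1.0, 2.0, 3.0]).representation_tree()
    broken_error = None
except AttributeError as exc:
    broken_error = str(exc)

fixed_tree = FixedDiag([1.0, 2.0, 3.0]).representation_tree()

print(broken_error)
print(fixed_tree)
```

If the real base class works the same way, the corresponding change in the snippet would be to add `super().__init__(diag)` as the first line of `DiagLinearOperator.__init__`, instead of setting `_differentiable_kwargs` and `_nondifferentiable_kwargs` by hand. That is an assumption the maintainers would need to confirm.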

jaghili added the bug label Sep 13, 2024