[BUG] Cannot override __getitem__ or __setitem__ for tensorclass with non-tensor data #698

Closed
alexanderswerdlow opened this issue Mar 5, 2024 · 6 comments · Fixed by #699

@alexanderswerdlow

Describe the bug

As described in the linked issue, it should be possible to override __getitem__ and __setitem__ to implement indexing for non-tensor data types.

To Reproduce

Following the linked issue:

from typing import List

import torch

from tensordict import tensorclass

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]

    def __getitem__(self, item):
        c = super().__getitem__(item)
        c.captions = self.captions[item]
        return c

data = MyClass(torch.randn(2, 3, 64, 64), ["a", "b"], batch_size=[2])
print(data[0])

Expected behavior

The printed class should contain captions="a" (the caption for the first batch element).

Additional context

It appears that monkey patching works if I copy and modify _getitem from tensorclass.py.

MyClass.__getitem__ = _getitem
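
For reference, a minimal sketch of that workaround, assuming _getitem keeps its current (self, item) signature in tensordict.tensorclass (the wrapper name below is mine):

from tensordict.tensorclass import _getitem

def _getitem_with_captions(self, item):
    # Delegate tensor indexing to the library helper, then index the
    # non-tensor captions field by hand.
    out = _getitem(self, item)
    out.captions = self.captions[item]
    return out

MyClass.__getitem__ = _getitem_with_captions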

In addition, this might be outside the scope of this issue and better suited to a separate feature request, but it seems that even with this monkey patching, stacking/concatenation does not properly handle the non-tensor data. Is this something that is supported?

I very often have non-tensor data associated with batch elements [e.g., a caption for each image], and it would be very strange to cat/stack and have the captions left unchanged.
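
As an illustration of the behaviour I'd expect, a hypothetical helper (cat_with_captions is my name, not part of tensordict) might look like:

import torch

def cat_with_captions(items, dim=0):
    # Hypothetical: concatenate the tensor fields with torch.cat, then
    # stitch the caption lists together by hand (dim=0 only).
    assert dim == 0, "only the leading batch dimension is handled here"
    out = torch.cat(items, dim=dim)
    out.captions = [caption for item in items for caption in item.captions]
    return out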

alexanderswerdlow added the bug label on Mar 5, 2024

vmoens commented Mar 5, 2024

Oddly enough, I was just working on that :)
In general, I think #663 should solve this issue more generally, but we can patch it (though I'm not sure how I can make super() work, since there is no real inheritance...)!

@alexanderswerdlow (Author)

Wow thanks for the fast response! I'd be happy to test it out when you have a working version :)


vmoens commented Mar 6, 2024

Now your new method will be used, but super() won't work. I can make a follow-up PR to make this a documented option. Because we don't explicitly inherit from anything but use a decorator instead, I think that making super() work won't be a good idea anyway, since it would break Python convention and people would end up with a class that inherits from something they never asked for (in a way, you don't expect super() to do anything special with a dataclass, so why should it here?).
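
For illustration, the same thing happens with a plain dataclass, which is part of why I don't think super() should work here (a minimal, self-contained sketch):

from dataclasses import dataclass

@dataclass
class Plain:
    x: int

    def __getitem__(self, item):
        # The only base class is object, which defines no __getitem__,
        # so there is nothing for super() to dispatch to.
        return super().__getitem__(item)

Plain(1)[0]  # AttributeError: 'super' object has no attribute '__getitem__'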

To me the API should be more like

@tensorclass
class MyClass:
    X: torch.Tensor

    def __setitem__(self, name, value):
        # your code here
        self.__tensorclass_function__().__setitem__(name, value)

or

        self.__tensorclass_function__("__setitem__", name, value)

whichever people think is more appropriate.

@alexanderswerdlow (Author)

@vmoens Thanks so much. I tried playing around a bit, but I wasn't able to get things working as I'd expect. Perhaps my use case just isn't supported, but I wanted to let you know what I encountered.

I wasn't able to use __tensorclass_function__ as you described (perhaps this was meant as a generic placeholder, e.g., $tensorclass_function?), but I was able to override e.g. __getitem__ as shown below. I wasn't sure how to make things work for stack/cat, etc., so I tried overriding the _get and _get_at functions, but that didn't seem to do it.

from typing import List

import torch

from tensordict import tensorclass
from tensordict.tensorclass import _getitem, _setitem, _get, _get_at, _set, _set_at_

def my_setitem(self, item, value):
    print(f"Called setitem with {item} and type {type(value)}")
    _setitem(self, item, value)
    if isinstance(self.captions, list):
        self.captions[item] = value.captions
    else:
        raise ValueError(f"Invalid type for captions: {type(item)} and {type(self.captions)} and {type(value.captions)}")

def my_getitem(self, name):
    print(f"Called getitem with {name}")
    obj = _getitem(self, name)
    obj.captions = self.captions[name]
    return obj

def my_get(self, key):
    print("Called get")
    obj = _get(self, key)
    return obj

def my_get_at(self, key, idx):
    print("Called get_at")
    obj = _get_at(self, key, idx)
    return obj

def my_set_at_(self, key, value, idx):
    print("Called set_at_")
    _set_at_(self, key, value, idx)

def my_set(self, key, value):
    print("Called set")
    _set(self, key, value)

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]
    
MyClass.__setitem__ = my_setitem
MyClass.__getitem__ = my_getitem
MyClass.get = my_get
MyClass.get_at = my_get_at
# MyClass.set = my_set # Causes an error
MyClass.set_at_ = my_set_at_

data = MyClass(torch.ones(4, 3, 64, 64), ["a", "b", "c", "d"], batch_size=[4])

data.images = torch.randn(4, 3, 64, 64) # Works
data.captions = ["d", "c", "b", "a"] # Works

data[0].images = torch.randn(3, 64, 64) # Works
data[0].captions = "a" # Understandably does not work but unintuitive
print(data[0].captions) # Prints "d"

data[1] = MyClass(torch.randn(3, 64, 64), "b", batch_size=[]) # Works
data[2:4] = MyClass(torch.randn(2, 3, 64, 64), ["e", "f"], batch_size=[2]) # Works
print(data[2:4].captions) # Prints ["e", "f"]

cat_data = torch.cat([data, data], dim=0) # Does not modify captions
stack_data = torch.stack([data, data], dim=0) # Does not modify captions

breakpoint()


vmoens commented Mar 6, 2024

I wasn't able to use __tensorclass_function__ as you described (perhaps this was meant as a generic placeholder, e.g., $tensorclass_function?), but I was able to override e.g. __getitem__ as shown below. I wasn't sure how to make things work for stack/cat, etc., so I tried overriding the _get and _get_at functions, but that didn't seem to do it.

The important part of my message was this

Now your new method will be used but super won't work. I can make a follow-up PR to make this a documented option.

i.e., I have not implemented it yet. You can overwrite the function, but you won't be able to call super() or anything similar as of now. I was asking for feedback about the feature before implementing it :)

@alexanderswerdlow (Author)

Ah, my bad! I think it's a good workaround (personally I prefer the first option, where the method is called directly rather than by its string name, since the string form seems unfriendly to type checking), but it's not immediately clear to an end user how it would affect the variety of tensor operations that tensordict supports.

I think the extent of supported operations for non-tensor data should be well-defined [e.g., you need to implement these four methods to cover all supported operations]. Where some operations are not possible, ideally a [disableable] warning message is shown when you attempt one. Otherwise, I'd be concerned that some operations work just fine while unexpected corner cases cause surprises. In the past I've implemented my own very basic version of a tensorclass that supported only a couple of operations but worked in a very understandable way, which I think is critical.

I'm not sure how complicated it would be to support all the different kinds of advanced indexing possible on tensors. One possibility I can think of would be to only support modifying non-tensor data along a single batch dimension, although I can easily see situations where I'd want to implement, e.g., two dimensions.
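
For what it's worth, a sketch of what single-dimension indexing of the captions could look like (index_captions is hypothetical and mirrors only a subset of tensor indexing semantics):

import torch

def index_captions(captions, item):
    # Hypothetical: index a list of strings along one batch dimension.
    if isinstance(item, (int, slice)):
        return captions[item]
    if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
        # Boolean mask indexing.
        return [c for c, keep in zip(captions, item.tolist()) if keep]
    if isinstance(item, (list, torch.Tensor)):
        # Integer-array indexing.
        return [captions[int(i)] for i in item]
    raise NotImplementedError(f"Unsupported index type: {type(item)}")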
