[BUG] Cannot override __getitem__ or __setitem__ for tensorclass with non-tensor data #698
Comments
Oddly I was just working on that :)
Wow thanks for the fast response! I'd be happy to test it out when you have a working version :)
Now your new method will be used but
To me the API should be more something like

    @tensorclass
    class MyClass:
        X: torch.Tensor

        def __setitem__(self, name, value):
            # your code here
            self.__tensorclass_function__().__setitem__(name, value)

or self.__tensorclass_function__("__setitem__", name, value), whichever people think is more appropriate.
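The delegation pattern proposed above (override a method, then hand off to the generated default) can be illustrated without tensordict. The sketch below is pure Python under stated assumptions: `batch_class`, `default_setitem` and the list-backed `x` field are hypothetical names invented for this illustration, not tensordict's API or its implementation.

```python
def batch_class(cls):
    """Hypothetical decorator: adds a default __setitem__ but keeps a user override."""

    def default_setitem(self, index, value):
        # Default behaviour: only the list-backed "tensor" field is written.
        self.x[index] = value.x

    # Always expose the generated method under a separate name so that a
    # user-defined override can still reach it.
    cls.default_setitem = default_setitem
    # Install the default only if the class body did not define its own.
    if "__setitem__" not in cls.__dict__:
        cls.__setitem__ = default_setitem
    return cls


@batch_class
class Captioned:
    def __init__(self, x, captions):
        self.x = list(x)                # stand-in for tensor data
        self.captions = list(captions)  # non-tensor data, one entry per element

    def __setitem__(self, index, value):
        # User code: handle the non-tensor field first...
        self.captions[index] = value.captions
        # ...then delegate to the generated default for everything else.
        self.default_setitem(index, value)


batch = Captioned([1, 2, 3], ["a", "b", "c"])
batch[0:2] = Captioned([10, 20], ["x", "y"])
```

Calling the default directly, as in the first option, keeps the call visible to type checkers, whereas the string-name variant would hide it behind a lookup.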
@vmoens Thanks so much. I tried playing around a bit but I wasn't able to get things working as I'd expect. Perhaps my use-case just isn't supported but just wanted to let you know what I encountered. I wasn't able to use
The important part of my message was this
i.e. I did not implement it yet. You can overwrite the function but you won't be able to call super() or anything similar as of now. I was asking for feedback about the feature before implementing it :)
Ah my bad! I think it's a good workaround (personally I like the first option where it's called directly and not with the string name, as that seems unfriendly to type checking), but it's not immediately clear as an end-user how it would affect the variety of tensor operations that tensordict supports.
I think the extent of supported operations for non-tensor data should be well-defined [e.g., you need to implement these 4 methods to cover all supported operations], and if some are not possible, ideally there's a [disableable] warning message that's shown when you perform an unsupported operation. Otherwise, I'd be concerned that some operations work just fine but there are some unexpected corner cases that cause surprises. In the past I've implemented my own very basic version of a tensorclass which only supported a couple of operations but worked in a very understandable way, which I think is critical.
I'm not sure how complicated it would be to support all the different kinds of advanced indexing possible on tensors. One possibility I could think of would be to only support modifying non-tensor data for a single batch dimension, although I can easily see situations where I'd want to implement, e.g., 2 dimensions.
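The disableable warning suggested above could look roughly like the following sketch. It relies only on Python's standard `warnings` module; the category `UnsupportedNonTensorOpWarning`, the `SUPPORTED_OPS` set and `check_supported` are hypothetical names chosen for illustration, and the four listed operations are an assumption, not a list taken from tensordict.

```python
import warnings


class UnsupportedNonTensorOpWarning(UserWarning):
    """Hypothetical warning category for operations that skip non-tensor fields."""


# Hypothetical set of operations a user-defined class declares support for.
SUPPORTED_OPS = frozenset({"getitem", "setitem", "stack", "cat"})


def check_supported(op_name):
    """Return True if op_name is supported; otherwise warn and return False."""
    if op_name in SUPPORTED_OPS:
        return True
    warnings.warn(
        f"'{op_name}' is not defined for non-tensor fields; they are left unchanged.",
        UnsupportedNonTensorOpWarning,
        stacklevel=2,
    )
    return False


# Record warnings to show that the unsupported call is reported.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok_cat = check_supported("cat")
    ok_permute = check_supported("permute")

# Users who accept the limitation can silence the whole category.
with warnings.catch_warnings(record=True) as silenced:
    warnings.simplefilter("ignore", UnsupportedNonTensorOpWarning)
    check_supported("permute")
```

Using a dedicated warning category is what makes the message disableable: a single filter on that category turns it off without hiding unrelated warnings.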
Describe the bug
As described here, it should be possible to override __getitem__ and __setitem__ to implement indexing for non-tensor data types.
To Reproduce
Following the linked issue:
Expected behavior
The printed class should contain captions=["a"].
Additional context
It appears that monkey patching works if I copy and modify _getitem from tensorclass.py.
In addition, this might be outside the scope of this issue and a separate feature request, but it seems that even with this monkey patching, stacking/concat does not properly handle this non-tensor data. Is this something that is supported?
I very often have non-tensor data that is associated with batch elements [e.g., a caption for each image], and it would be very weird to cat/stack and have the captions be unchanged.
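To make the desired behaviour concrete, here is a minimal pure-Python sketch of a batch whose captions follow the batch elements under indexing and concatenation. `CaptionedBatch` and `cat` are hypothetical names for this illustration, and plain lists stand in for tensors; this is not how tensordict handles non-tensor data.

```python
from dataclasses import dataclass


@dataclass
class CaptionedBatch:
    images: list    # stand-in for a tensor with one row per batch element
    captions: list  # non-tensor data, one caption per batch element

    def __post_init__(self):
        if len(self.images) != len(self.captions):
            raise ValueError("images and captions must have the same length")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if isinstance(index, slice):
            return CaptionedBatch(self.images[index], self.captions[index])
        # A single integer selects one element; keep it as a batch of one.
        return CaptionedBatch([self.images[index]], [self.captions[index]])


def cat(batches):
    """Concatenate batches along the batch dimension, captions included."""
    images, captions = [], []
    for batch in batches:
        images.extend(batch.images)
        captions.extend(batch.captions)
    return CaptionedBatch(images, captions)


first = CaptionedBatch([[0.0], [1.0]], ["a", "b"])
second = CaptionedBatch([[2.0]], ["c"])
merged = cat([first, second])
```

With this design, `merged[0]` carries `captions=["a"]`, and concatenation yields one caption per image rather than leaving the captions of the first batch unchanged.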