
Why is Tensorclass implemented as a decorator? #663

Open
kurt-stolle opened this issue Feb 6, 2024 · 10 comments

Comments

@kurt-stolle
Contributor

kurt-stolle commented Feb 6, 2024

Current Tensorclass interface

To define a Tensorclass, we could write:

@tensorclass
class MyTensorclass:
    foo: Tensor
    bar: Tensor

The @tensorclass decorator then generates a dataclass-like container class, adding various methods that implement (parts of) the TensorDictBase abstraction. Notice that unlike dataclasses, where the __init__ signature can be inferred entirely from the class attribute annotations, the current decorator also augments __init__ with extra keyword arguments batch_size, device, and names.
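
For illustration (a sketch based on tensordict's documented behavior, not code from this issue), constructing an instance of the class above would look like:

import torch

# The generated __init__ accepts the annotated fields plus the extra
# keyword arguments injected by the decorator:
data = MyTensorclass(
    foo=torch.zeros(2, 3),
    bar=torch.ones(2, 3),
    batch_size=[2],  # leading batch dimension(s) shared by all entries
    device="cpu",
)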

Considering that the decorator has significant (side-)effects on the resulting class, I am unsure why the API has been defined in this way. In this issue, I ask for clarification on this design choice.

Simplified interface (conceptual)

Specifically, would a simplified representation using subclassing not suffice? For example, consider the following notation:

class MyTensorclass(Tensorclass):
    foo: Tensor
    bar: Tensor

The Tensorclass base class then has the same effects as @tensorclass, and additionally:

  1. Enables a more straightforward implementation in tensordict
  2. Works with builtins issubclass and isinstance
  3. Is compatible with static type checking (including the extra keyword arguments batch_size etc.)
  4. Is easily understood as canonical Python

Polyfill

To further clarify the effects of the subclass above, here's a quick polyfill that allows tensorclasses to be defined with the subclass-based API above on top of the current decorator-based paradigm:

from types import resolve_bases
from typing import TYPE_CHECKING, Sequence
from typing import dataclass_transform  # Python 3.11+; use typing_extensions on older versions
from dataclasses import KW_ONLY
import torch
from torch.types import Device
from tensordict import tensorclass

@dataclass_transform()
class _TensorclassMeta(type):
    def __new__(metacls, name, bases, ns, **kwds):
        # Build the class as usual, then run it through the existing decorator.
        bases = resolve_bases(bases)
        cls = super().__new__(metacls, name, tuple(bases), ns, **kwds)
        return tensorclass(cls)

@dataclass_transform()
class Tensorclass(metaclass=_TensorclassMeta):  # Or: TensorDictBase subclass
    # Demonstrates point (3) above: the extra keyword arguments are visible
    # to static type checkers but have no runtime footprint.
    if TYPE_CHECKING:
        _: KW_ONLY
        batch_size: torch.Size | Sequence[int]
        device: Device | str | None = None
        ...  # Other methods

@vmoens
Contributor

vmoens commented Feb 6, 2024

Hey, thanks for proposing this!
I agree with all your points: inheritance is easy and understandable for most of the python community, and far less hacky (hackiness being the one thing I do not like about dataclass, and consequently about tensorclass too).

We thought about that initially, but there are a couple of reasons that made us go for the @tensorclass decorator.
It has mainly to do with the fact that a large community likes @dataclass for the very reason that it does not inherit from anything, you can build a class that is the parent of all subclasses with little trouble.
The idea of tensorclass is to have a dataclass on steroids. If we'd made a TensorClass base, it would have been a mixture of @dataclass and inheritance, a strange class to play with for people used to dataclasses.

Note that we could implement a class & metaclass that implement the isinstance/issubclass checks, if that makes things easier?

from tensordict import tensorclass, is_tensorclass
import torch

class SomeClassMeta(type):
    # Delegate instance and subclass checks to tensordict's own predicate.
    def __instancecheck__(self, instance):
        return is_tensorclass(instance)

    def __subclasscheck__(self, subclass):
        return is_tensorclass(subclass)

class TensorClass(metaclass=SomeClassMeta):
    pass

@tensorclass
class MyDC:
    x: torch.Tensor

c = MyDC(1, batch_size=[])
assert isinstance(c, TensorClass)
assert issubclass(type(c), TensorClass)

That only partially fixes your issues, though (type checking in particular).

I'm personally not a huge fan of dataclass but I understand why people like them. I think that most of the time it has to do with typing (i.e., the content is more explicit than with a dictionary). Your solution still offers that, so I wouldn't put it aside straight away, but if we consider this we must make sure that it won't harm adoption of the feature for the target audience (i.e., users who are accustomed to the @dataclass decorator).

RE this point

Notice that unlike dataclasses, where the __init__ signature can be inferred entirely from the class attribute annotations, the current decorator also augments __init__ with extra keyword arguments batch_size, device, and names.

I don't see how that relates to the four points you raised about the benefits of subclassing. To me, @dataclass suffers from those four pitfalls too, doesn't it?

cc @shagunsodhani @apbard @tcbegley who may have some other thoughts to share on this

@kurt-stolle
Contributor Author

kurt-stolle commented Feb 6, 2024

Thanks for your swift reply @vmoens! This clears up my main questions on the design choices made.

While I did not intend to advocate immediate alterations to the status quo, I am curious to learn what exactly the intended role of tensorclasses is. From the user's perspective, the main reasons listed for using a tensorclass come from the keys being explicitly defined. In that case, weighing the library design in a way that prefers users' code aesthetics (i.e. being similar to @dataclass) over their static type checking capabilities seems a bit off-balance.

Second, was a solution that specifically addresses typing, à la typing.TypedDict, ever explored?
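
For reference, this is the kind of static key typing TypedDict provides for plain dicts (purely illustrative; tensordict has no such construct, and MyTensors is a hypothetical name):

from typing import TypedDict

import torch
from torch import Tensor

class MyTensors(TypedDict):
    foo: Tensor
    bar: Tensor

# Keys and value types are checked statically; no runtime class generation:
tensors: MyTensors = {"foo": torch.zeros(3), "bar": torch.ones(3)}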

RE:

It has mainly to do with the fact that a large community likes @dataclass for the very reason that it does not inherit from anything, you can build a class that is the parent of all subclasses with little trouble.

It is my perception that 'not inheriting from anything' is more relevant to dataclasses than it is to tensorclasses in this context. I would argue:

  1. Dataclasses are a method for defining a class without many implications for functionality. As such, I could understand why one would not want two functionally different dataclasses to share a common base, as the only thing they share is how they were defined and not what they represent. Thus, @dataclass can be interpreted as a code generator that outputs the body of a class.
  2. Tensorclasses are also a method for defining a class, but they have significant implications for functionality. This added functionality mostly takes the form of bound class methods, as would normally be provided via a superclass.

Consequently, two tensorclasses will have a larger degree of shared functionality than two dataclasses. This is difficult to define in an exact manner, though.

RE:

I don't see how that relates to the four points you raised about the benefits of subclassing. To me, @dataclass suffers from those four pitfalls too, doesn't it?

It does indeed. Let me elaborate further on why I single out this property of the __init__ signature. It is mostly related to problematic typing:

  1. Classes made with @dataclass already have good static type check support (e.g. as shown in Pyright and Mypy)
  2. Until recently (PEP 681), this behavior could not be extended to other classes with dataclass-like behavior.
    • To add typing support for @tensorclass, we could follow PEP 681 and decorate with @dataclass_transform from typing.
    • However, this would result in batch_size and device missing from the signature at type-check time. To my knowledge, there is no way around this using the decorator-based approach (see the sketch after this list).
  3. If defined as a subclass, we can use @dataclass_transform and TYPE_CHECKING to define the interface of the __init__ method statically in code, providing the user with proper typing using simple and canonical Python.
    • The added batch_size and device are now recognized by type checkers as being part of the keyword arguments of __init__.
    • See code block in the original issue above.
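
To illustrate that last point about the decorator-based approach, here is a minimal sketch (hypothetical typing, not tensordict's actual stubs) of what PEP 681 lets us express for a decorator, and where it stops:

from typing import TypeVar, dataclass_transform

T = TypeVar("T")

# Type checkers now synthesize __init__ for decorated classes from their own
# annotations, exactly as for @dataclass. There is no hook to append extra
# parameters such as batch_size or device to that synthesized signature.
@dataclass_transform()
def tensorclass(cls: type[T]) -> type[T]:
    ...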

@vmoens
Contributor

vmoens commented Feb 7, 2024

Thank you for your comments, a lot of good points there.

To provide some context on typing: tensordict, like pytorch, is not strongly typed. We do not currently use mypy as part of our CI, and it's unlikely that we will in the future.
Regardless of my personal views on type checking (which are not pertinent to this discussion), the potential integration of tensordict's core features into the pytorch repo makes it highly improbable that it will ever be type-checked. However, we understand that some users may wish to use mypy with tensordict, and we must consider how to support this.

Proposed Solution

Your suggestion of dataclass_transform looks very cool to me. If I understand correctly, the type checker will only be satisfied with inheritance, correct? This could be a feasible path forward.

Given this, we would find ourselves in the following scenarios based on user requirements:

  • If a user wants a decorator like dataclass (current status) without typing: all Python versions work.
  • If a user wants a decorator like dataclass (current status) with typing: this is not possible.
  • If a user is open to subclassing (or to combining the decorator with subclassing) and wants typing: this requires Python 3.11.

Would this status quo be acceptable to you? Is the dual usage (as a decorator or subclass) even desirable? It's worth noting that offering two methods to achieve the same result can lead to confusion. Many users are currently familiar with using @tensorclass, so any decision to deprecate it must be carefully considered.

We now need to evaluate the feasibility and implications of these options...

@vmoens
Contributor

vmoens commented Oct 31, 2024

Hey @kurt-stolle I was wondering if you had thoughts about this? #1067

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

The only problem I see atm is that autocomplete doesn't really work (you can have either the tensordict args or the dataclass fields, but you can't have both).

@kurt-stolle
Contributor Author

kurt-stolle commented Nov 12, 2024

Hi @vmoens, you mention that autocomplete does not work. When I use my implementation from the original issue above, I do get completions in the editor.

It looks like this in Neovim and VSCode:
[screenshots: Neovim LSP and VSCode completions omitted]

Just to verify, are we talking about the same thing? I did not test the code in #1067.

@vmoens
Contributor

vmoens commented Nov 12, 2024

Hmmm, I cannot make this work with Python 3.10.
What works currently is that all the methods from tensordict / tensorclass appear on the tensorclass object, but the attributes don't appear during construction:
[screenshots omitted]

@kurt-stolle
Contributor Author

I am not sure how to replicate autocomplete behavior in interactive python, as my experience is limited there. I am also using Python 3.10, however.

I suppose it would work if you remove the __init__(...) function from the interface definition file tensordict/tensorclass.pyi (line 69) and replace it with mocked dataclass fields that represent the keyword arguments. This can be done using KW_ONLY (docs).

Practically, the __init__ definition in tensordict/tensorclass.pyi is removed and the following dataclass field definitions are added/mocked:

_: KW_ONLY
batch_size: Sequence[int] | torch.Size | int | None = None
device: DeviceType | None = None
names: Sequence[str] | None = None
non_blocking: bool | None = None
lock: bool = False

These fields represent the keyword arguments passed to the constructor.

This has the added side-effect that batch_size, device, etc. would become recognized as fields by the language server. I am not sure whether that is desired, e.g. whether all these parameters are also accessible as fields after initialization. If not, they could be wrapped with InitVar (docs), as sketched below.
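
A hypothetical variant of the stub fields above using InitVar (a sketch, not the actual stub; it assumes the same Sequence/torch/DeviceType imports as the surrounding .pyi file) would be:

from dataclasses import InitVar, KW_ONLY

# InitVar tells type checkers these names are accepted by __init__
# but do not become instance attributes afterwards:
_: KW_ONLY
batch_size: InitVar[Sequence[int] | torch.Size | int | None] = None
device: InitVar[DeviceType | None] = None
names: InitVar[Sequence[str] | None] = None
non_blocking: InitVar[bool | None] = None
lock: InitVar[bool] = False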

I am a bit limited in time currently, but if desired I could investigate this and submit an alternative PR to #1067 if I manage to replicate and address the autocomplete issue you describe.

@kurt-stolle
Contributor Author

Upon further inspection, I also see that @dataclass_transform() is not added to the TensorClass in the interface file tensordict/tensorclass.pyi on #1067. In my example, I added this decorator to both the Tensorclass and its metaclass.

@vmoens
Contributor

vmoens commented Nov 12, 2024

I am not sure how to replicate autocomplete behavior in interactive python, as my experience is limited there. I am also using Python 3.10, however.

Same here, I was just using that to quickly experiment with autocompletion. Writing a script isn't any different, though.

I am a bit limited in time currently, but if desired I could investigate this and submit an alternative PR to #1067 if I manage to replicate and address the autocomplete issue you describe.

I'm happy with anything as long as existing tests pass!

Upon further inspection, I also see that @dataclass_transform() is not added to the TensorClass in the interface file tensordict/tensorclass.pyi on #1067. In my example, I added this decorator to both the Tensorclass and its metaclass.

Yeah I tried all combinations (around TensorClass, around the metaclass and around the class in the stub file) and none helped with autocomplete...

By the way, how did you get dataclass_transform to work without parentheses (i.e., @dataclass_transform rather than @dataclass_transform())?

@kurt-stolle
Contributor Author

I'm happy with anything as long as existing tests pass!

I will submit a PR once I find some time to have a closer look.

By the way, how did you get dataclass_transform to work without parentheses (i.e., @dataclass_transform rather than @dataclass_transform())?

That does not work! I see that I made an error while cleaning an excerpt from my own code. It has been corrected.
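
For completeness: typing.dataclass_transform is a decorator factory. It takes only keyword configuration arguments (such as kw_only_default) and returns the actual decorator, so the parentheses are required. A minimal sketch:

from typing import dataclass_transform

@dataclass_transform()  # correct: calling the factory yields the decorator
class Base:
    ...

# @dataclass_transform   # incorrect: the factory itself is not a decorator;
# class Broken: ...      # applying it directly would fail at runtime, since
#                        # it accepts no positional arguments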
