[BUG] Overly strict type checking #1131

Closed
davidegraff opened this issue Dec 6, 2024 · 1 comment · Fixed by #1143

davidegraff commented Dec 6, 2024

Describe the bug

I'm currently using hydra in combination with TensorDictModule and am running into a ValueError when building modules, because the underlying config uses its own container types (for example omegaconf's ListConfig) rather than the built-in list and dict.

Example

My hydra config looks like this:

# configs/config.yaml
linear:
  module:
    _target_: torch.nn.Linear
    in_features: 64
    out_features: 64
  in_keys: [x]
  out_keys: [h]

which can be parsed in python like so:

# app.py
import hydra
from omegaconf import DictConfig
from tensordict.nn import TensorDictModule

@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig):
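    # instantiate only the node that carries `_target_`
    linear = hydra.utils.instantiate(cfg.linear.module)
    # in_keys/out_keys live under `linear` in the config and are ListConfig objects
    td_linear = TensorDictModule(linear, cfg.linear.in_keys, cfg.linear.out_keys)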

if __name__ == "__main__":
    main()

but running this script gives a ValueError:

$ python app.py
ValueError: out_keys must be of type list, str or tuples of str.
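
To illustrate why the check rejects the value: as far as I can tell, the key lists read from the config implement the mutable-sequence interface without subclassing `list`, so an `isinstance(..., list)` test returns False. The sketch below reproduces that situation without hydra or omegaconf; `KeyList` is a hypothetical stand-in written for this example, not a class from either library.

```python
from collections.abc import MutableSequence


class KeyList(MutableSequence):
    """Hypothetical stand-in for a config container: list-like, but not a `list` subclass."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._items[index] = value

    def __delitem__(self, index):
        del self._items[index]

    def __len__(self):
        return len(self._items)

    def insert(self, index, value):
        self._items.insert(index, value)


out_keys = KeyList(["h"])

# The concrete-type check fails even though the object behaves like a list.
strict_ok = isinstance(out_keys, list)

# The abstract-base-class check accepts it.
abc_ok = isinstance(out_keys, MutableSequence)

# Workaround on the caller's side: copy into a built-in list before passing it on.
plain_keys = list(out_keys)
```

Copying with `list(...)` at the call site is the easy workaround; the proposal below would make it unnecessary.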

Proposed solution

Replace the concrete-type checks in tensordict/nn/common.py with their collections.abc counterparts, which is the usual Python recommendation for testing container interfaces. For example, change

# tensordict/nn/common.py#L928
        if isinstance(in_keys, dict):
            # write the kwargs and create a list instead
            _in_keys = []
            self._kwargs = []
            for key, value in in_keys.items():
                self._kwargs.append(value)
                _in_keys.append(key)
            in_keys = _in_keys
        else:
            if isinstance(in_keys, (str, tuple)):
                in_keys = [in_keys]
            elif not isinstance(in_keys, list):
                raise ValueError(self._IN_KEY_ERR)
            self._kwargs = None

to

        if isinstance(in_keys, collections.abc.Mapping):
            # write the kwargs and create a list instead
            _in_keys = []
            self._kwargs = []
            for key, value in in_keys.items():
                self._kwargs.append(value)
                _in_keys.append(key)
            in_keys = _in_keys
        else:
            if isinstance(in_keys, (str, tuple)):
                in_keys = [in_keys]
            elif not isinstance(in_keys, collections.abc.MutableSequence): # possibly even the more general `Iterable`
                raise ValueError(self._IN_KEY_ERR)
            self._kwargs = None
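
As a rough, standalone sketch of the proposed behaviour (not the actual tensordict code): `normalize_in_keys` is a hypothetical helper that mirrors the branch above as a plain function, and the error message is a placeholder since `_IN_KEY_ERR` is not shown here. `UserList` and `MappingProxyType` serve as examples of containers that are not `list` or `dict` subclasses.

```python
from collections import UserList
from collections.abc import Mapping, MutableSequence
from types import MappingProxyType


def normalize_in_keys(in_keys):
    """Return (keys, kwargs) the way the proposed branch would (hypothetical helper)."""
    if isinstance(in_keys, Mapping):
        # Mapping: keys become the in_keys, values become the kwargs.
        return list(in_keys.keys()), list(in_keys.values())
    if isinstance(in_keys, (str, tuple)):
        # A single key (plain or nested) is wrapped in a list.
        return [in_keys], None
    if not isinstance(in_keys, MutableSequence):
        # Placeholder message; the real one lives in `_IN_KEY_ERR`.
        raise ValueError("in_keys must be a mapping, a str, a tuple or a mutable sequence.")
    return list(in_keys), None


# A list-like object that is not a `list` subclass is accepted.
seq_keys, seq_kwargs = normalize_in_keys(UserList(["x"]))

# A read-only mapping that is not a `dict` subclass is accepted too.
map_keys, map_kwargs = normalize_in_keys(MappingProxyType({"x": "input"}))

# Anything else is still rejected.
try:
    normalize_in_keys({"x"})
    rejected = False
except ValueError:
    rejected = True
```

The order of the checks matters: `str` and `tuple` are handled before the sequence test, so a single key is never mistaken for a collection of keys even if the test were widened to `Sequence` or `Iterable`.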

I don't think this is critical, as it's not even a "bug" per se and is easy to work around, but it would be a nice QOL change. Thanks as always!

davidegraff added the bug label Dec 6, 2024

vmoens (Contributor) commented Dec 16, 2024

Sure, we can make this change! I'll take care of it.

vmoens linked a pull request Dec 17, 2024 that will close this issue
vmoens closed this as completed Dec 18, 2024