[BugFix] Avoid reshape(-1) for inputs to DreamerActorLoss #2496

Merged 1 commit on Oct 18, 2024

Conversation

kurtamohler
Collaborator

Description

Avoid reshaping inputs to DreamerActorLoss.

Motivation and Context

Follow-up to #2494

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2496

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 6 Unrelated Failures

As of commit bca6b79 with merge base d894358:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 15, 2024
- loss_td, fake_data = loss_module(tensordict)
+ # NOTE: Input is reshaped because GRUCell (which is part of the
+ # RSSMPrior module in `mb_env`) expects input to be either 1D or 2D
+ loss_td, fake_data = loss_module(tensordict.reshape(-1))
Collaborator Author

@kurtamohler kurtamohler Oct 15, 2024

I'm not sure if there is a better way to fix this test. I suppose it could be possible to just reshape the direct input to the GRUCell?
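
As a rough illustration of the alternative raised here (reshaping only the direct input to the GRUCell), the sketch below wraps `torch.nn.GRUCell` so that any leading batch dimensions are flattened before the call and restored afterwards. `FlattenedGRUCell` is a hypothetical name used for illustration; it is not part of torchrl or of this PR.

```python
import torch
from torch import nn


class FlattenedGRUCell(nn.Module):
    """Hypothetical wrapper (not a torchrl class): lets nn.GRUCell accept
    inputs with any number of leading batch dimensions."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # Remember the leading batch dimensions so they can be restored.
        batch_shape = x.shape[:-1]
        # nn.GRUCell only accepts 1D or 2D input, so collapse the batch dims.
        x2d = x.reshape(-1, x.shape[-1])
        h2d = h.reshape(-1, h.shape[-1])
        out = self.cell(x2d, h2d)
        # Restore the original batch dimensions on the new hidden state.
        return out.reshape(*batch_shape, out.shape[-1])


cell = FlattenedGRUCell(input_size=5, hidden_size=8)
x = torch.randn(3, 4, 5)   # two leading batch dims
h = torch.zeros(3, 4, 8)
out = cell(x, h)
print(out.shape)
```

With this approach the caller would not need to reshape the whole tensordict, at the cost of touching the recurrent module itself.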

Contributor

If we need to reshape, we should reshape. Another option here would be to use vmap,
like:

if tensordict.ndim > 1:
    loss_td, fake_data = vmap(loss_module, (0,))(tensordict)
else:
    loss_td, fake_data = loss_module(tensordict)

(GRU works with vmap as long as you are using the Python-only version in torchrl.modules.)
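
To make the two options concrete, here is a minimal sketch comparing them on plain tensors. It assumes `torch.vmap` is available (PyTorch 2.x) and uses a hypothetical function `step_2d` as a stand-in for a module that only accepts 2D input; it does not use the actual loss module or GRU.

```python
import torch

weight = torch.randn(5, 8)


def step_2d(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a module that only accepts 2D input.
    if x.dim() != 2:
        raise ValueError(f"expected 2D input, got {x.dim()}D")
    return torch.tanh(x @ weight)


x = torch.randn(3, 4, 5)  # one extra leading batch dimension

# Option 1: flatten the batch dims, call once, restore the shape.
flat = step_2d(x.reshape(-1, 5)).reshape(3, 4, 8)

# Option 2: map the function over dim 0; inside the call each slice
# looks like a 2D tensor of shape (4, 5).
mapped = torch.vmap(step_2d, in_dims=(0,))(x)

print(torch.allclose(flat, mapped, atol=1e-5))
```

Both routes should give the same values here; vmap keeps the leading dimension explicit instead of folding it into the batch.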

Collaborator Author

@kurtamohler kurtamohler Oct 16, 2024

If I try using VmapModule, I get this error:

  File "/home/endoplasm/develop/torchrl-1/test/test_cost.py", line 10338, in test_dreamer_actor
    loss_td, fake_data = VmapModule(loss_module, (0,))(tensordict)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/modules/tensordict_module/common.py", line 454, in __init__
    self.in_keys = module.in_keys
                   ^^^^^^^^^^^^^^
  File "/home/endoplasm/develop/torchrl-1/torchrl/objectives/common.py", line 441, in __getattr__
    return super().__getattr__(item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/endoplasm/miniconda/envs/torchrl-1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'DreamerActorLoss' object has no attribute 'in_keys'

I'll probably just leave the reshape for now, but I would like to understand this.

Indeed, if I try to access loss_module.in_keys directly, I get the same error. But I can access the in_keys of the actor model and the world model within the loss module:

print(loss_module.actor_model.in_keys)
print(loss_module.model_based_env.world_model.in_keys)

Output:

['state', 'belief']
['state', 'belief', 'action']

So I'm wondering what would be the right way to make VmapModule and DreamerActorLoss compatible? Would we want to add an in_keys attribute to DreamerActorLoss that returns a combined list of the keys in the actor model and world model?
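
One possible shape for such an attribute, sketched in plain Python under the assumption that a simple order-preserving union of the two key lists is what is wanted. `combine_in_keys` and `SketchLoss` are hypothetical names for illustration only; this is not how torchrl defines `in_keys` on its losses.

```python
from typing import List, Sequence


def combine_in_keys(*key_lists: Sequence[str]) -> List[str]:
    """Merge several key lists, dropping duplicates but keeping first-seen order."""
    merged = {}
    for keys in key_lists:
        merged.update(dict.fromkeys(keys))
    return list(merged)


class SketchLoss:
    """Hypothetical stand-in for a loss module; not the torchrl class."""

    def __init__(self, actor_in_keys, world_model_in_keys):
        self._actor_in_keys = list(actor_in_keys)
        self._world_model_in_keys = list(world_model_in_keys)

    @property
    def in_keys(self) -> List[str]:
        # Union of the keys read by the actor and by the world model.
        return combine_in_keys(self._actor_in_keys, self._world_model_in_keys)


loss = SketchLoss(["state", "belief"], ["state", "belief", "action"])
combined = loss.in_keys
print(combined)  # ['state', 'belief', 'action']
```

A wrapper that reads `module.in_keys` in its constructor, as the traceback above shows VmapModule doing, would then find the attribute instead of raising AttributeError.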

Contributor

Oh yeah, DreamerActorLoss should have in_keys!
All losses should. Dreamer hasn't received much love lately, as you can see.
Let's take care of that in a separate PR then.

@vmoens
Contributor

vmoens commented Oct 16, 2024

The Dreamer implementation (in the examples workflow) is failing.

@kurtamohler
Collaborator Author

> The Dreamer implementation (in examples workflow) is failing

Should be fixed now

Contributor

@vmoens vmoens left a comment

LGTM thanks!

@vmoens vmoens merged commit a27514c into pytorch:main Oct 18, 2024
71 of 80 checks passed
Labels: CLA Signed, Objectives
3 participants