-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix, Feature] Vmap randomness in losses #1740
Changes from 16 commits
3206ddb
82ec897
29a5cda
53ca415
50aa8a4
3f54635
c4ff926
1c44c35
bc03b92
9aa65ed
3ff3ef8
febf277
3af72c5
08c83c1
f122825
d85d63b
92c3e40
7cdc1e6
f737172
1bfcc27
6ac65e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,6 +120,7 @@ | |
from torchrl.objectives.redq import REDQLoss | ||
from torchrl.objectives.reinforce import ReinforceLoss | ||
from torchrl.objectives.utils import ( | ||
_vmap_func, | ||
HardUpdate, | ||
hold_out_net, | ||
SoftUpdate, | ||
|
@@ -233,6 +234,38 @@ def set_advantage_keys_through_loss_test( | |
) | ||
|
||
|
||
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("vmap_randomness", (None, "different", "same"))
@pytest.mark.parametrize("dropout", (0.0, 0.1))
def test_loss_vmap_random(device, vmap_randomness, dropout):
    """Smoke-test vmap randomness handling in a minimal ``LossModule``.

    A tiny loss module — optionally containing a ``Dropout`` layer, i.e. a
    random operation — is built and called once.  When the user explicitly
    selects "different"/"same" randomness, ``set_vmap_randomness`` is
    exercised; otherwise the loss module's default handling is used.
    """

    class VmapTestLoss(LossModule):
        def __init__(self):
            super().__init__()
            modules = [nn.Linear(4, 4), nn.ReLU()]
            if dropout > 0.0:
                # Inject a random operation into the network.
                modules.append(nn.Dropout(dropout))
            modules.append(nn.Linear(4, 4))
            backbone = nn.Sequential(*modules).to(device)
            td_module = TensorDictModule(
                backbone, in_keys=["obs"], out_keys=["action"]
            )
            # Expand the functional params to a batch of 4 stacked copies.
            self.convert_to_functional(td_module, "model", expand_dim=4)
            # vmap over the params batch (dim 0); data is shared (None).
            self.vmap_model = _vmap_func(
                self.model, (None, 0), randomness=self.vmap_randomness
            )

        def forward(self, td):
            result = self.vmap_model(td, self.model_params)
            return {"loss": result["action"].mean()}

    loss_module = VmapTestLoss()
    td = TensorDict({"obs": torch.randn(3, 4).to(device)}, [3])

    # If the user sets vmap randomness to a specific value, honor it.
    if dropout > 0.0 and vmap_randomness in ("different", "same"):
        loss_module.set_vmap_randomness(vmap_randomness)

    loss_module(td)["loss"]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe let's test that things actually fail if we don't call the `set_vmap_randomness` method. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we don't do that, I can add a test for it, but I'm not sure if that's what you meant. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yes that is what I meant. Here we only test that the code runs, but we're not really checking that it would have been broken had we done things differently. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added it! |
||
|
||
|
||
class TestDQN(LossModuleTestBase): | ||
seed = 0 | ||
|
||
|
@@ -1803,12 +1836,17 @@ def _create_mock_actor( | |
device="cpu", | ||
in_keys=None, | ||
out_keys=None, | ||
dropout=0.0, | ||
): | ||
# Actor | ||
action_spec = BoundedTensorSpec( | ||
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,) | ||
) | ||
module = nn.Linear(obs_dim, action_dim) | ||
module = nn.Sequential( | ||
nn.Linear(obs_dim, obs_dim), | ||
nn.Dropout(dropout), | ||
nn.Linear(obs_dim, action_dim), | ||
) | ||
actor = Actor( | ||
spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys | ||
) | ||
|
@@ -1984,6 +2022,7 @@ def _create_seq_mock_data_td3( | |
@pytest.mark.parametrize("noise_clip", [0.1, 1.0]) | ||
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) | ||
@pytest.mark.parametrize("use_action_spec", [True, False]) | ||
@pytest.mark.parametrize("dropout", [0.0, 0.1]) | ||
def test_td3( | ||
self, | ||
delay_actor, | ||
|
@@ -1993,9 +2032,10 @@ def test_td3( | |
noise_clip, | ||
td_est, | ||
use_action_spec, | ||
dropout, | ||
): | ||
torch.manual_seed(self.seed) | ||
actor = self._create_mock_actor(device=device) | ||
actor = self._create_mock_actor(device=device, dropout=dropout) | ||
value = self._create_mock_value(device=device) | ||
td = self._create_mock_data_td3(device=device) | ||
if use_action_spec: | ||
|
@@ -4876,7 +4916,6 @@ def test_cql( | |
device, | ||
td_est, | ||
): | ||
|
||
torch.manual_seed(self.seed) | ||
td = self._create_mock_data_cql(device=device) | ||
|
||
|
@@ -6075,7 +6114,7 @@ def zero_param(p): | |
p.grad = None | ||
loss_objective.backward() | ||
named_parameters = loss_fn.named_parameters() | ||
for (name, other_p) in named_parameters: | ||
for name, other_p in named_parameters: | ||
p = params.get(tuple(name.split("."))) | ||
assert other_p.shape == p.shape | ||
assert other_p.dtype == p.dtype | ||
|
@@ -11137,7 +11176,6 @@ def test_set_deprecated_keys(self, adv, kwargs): | |
) | ||
|
||
with pytest.warns(DeprecationWarning): | ||
|
||
if adv is VTrace: | ||
actor_net = TensorDictModule( | ||
nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# LICENSE file in the root directory of this source tree. | ||
|
||
import functools | ||
import re | ||
import warnings | ||
from enum import Enum | ||
from typing import Iterable, Optional, Union | ||
|
@@ -29,6 +30,14 @@ | |
"run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`." | ||
) | ||
|
||
# Modules whose forward pass draws random numbers; used to decide whether a
# loss module's vmap calls need randomness="different"/"same" instead of the
# default error mode.
# All of torch's dropout variants (Dropout, Dropout1d/2d/3d, AlphaDropout,
# FeatureAlphaDropout) derive from the common private base class
# ``_DropoutNd`` — listing the base class covers every current entry of the
# previous explicit tuple plus any future dropout subclasses (change
# suggested in review).
RANDOM_MODULE_LIST = (nn.modules.dropout._DropoutNd,)
|
||
|
||
class ValueEstimators(Enum): | ||
"""Value function enumerator for custom-built estimators. | ||
|
@@ -478,13 +487,23 @@ def new_fun(self, netname=None): | |
|
||
|
||
def _vmap_func(module, *args, func=None, **kwargs): | ||
def decorated_module(*module_args_params): | ||
params = module_args_params[-1] | ||
module_args = module_args_params[:-1] | ||
with params.to_module(module): | ||
if func is None: | ||
return module(*module_args) | ||
else: | ||
return getattr(module, func)(*module_args) | ||
try: | ||
|
||
return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 | ||
def decorated_module(*module_args_params): | ||
params = module_args_params[-1] | ||
module_args = module_args_params[:-1] | ||
with params.to_module(module): | ||
if func is None: | ||
return module(*module_args) | ||
else: | ||
return getattr(module, func)(*module_args) | ||
|
||
return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 | ||
|
||
except RuntimeError as err: | ||
if re.match( | ||
r"vmap: called random operation while in randomness error mode", str(err) | ||
): | ||
raise RuntimeError( | ||
"Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap." | ||
) from err |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me know what you think! :)