[BugFix, Feature] Vmap randomness in losses #1740

Merged · 21 commits · Jan 9, 2024
Changes from 16 commits
48 changes: 43 additions & 5 deletions test/test_cost.py
@@ -120,6 +120,7 @@
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import (
+    _vmap_func,
HardUpdate,
hold_out_net,
SoftUpdate,
@@ -233,6 +234,38 @@ def set_advantage_keys_through_loss_test(
)


@pytest.mark.parametrize("device", get_default_devices())
Contributor Author:
Let me know what you think! :)

@pytest.mark.parametrize("vmap_randomness", (None, "different", "same"))
@pytest.mark.parametrize("dropout", (0.0, 0.1))
def test_loss_vmap_random(device, vmap_randomness, dropout):
class VmapTestLoss(LossModule):
def __init__(self):
super().__init__()
layers = [nn.Linear(4, 4), nn.ReLU()]
if dropout > 0.0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(4, 4))
net = nn.Sequential(*layers).to(device)
model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"])
self.convert_to_functional(model, "model", expand_dim=4)
self.vmap_model = _vmap_func(
self.model, (None, 0), randomness=self.vmap_randomness
)

def forward(self, td):
out = self.vmap_model(td, self.model_params)
return {"loss": out["action"].mean()}

loss_module = VmapTestLoss()
td = TensorDict({"obs": torch.randn(3, 4).to(device)}, [3])

# If user sets vmap randomness to a specific value
if vmap_randomness in ("different", "same") and dropout > 0.0:
loss_module.set_vmap_randomness(vmap_randomness)

loss_module(td)["loss"]
Contributor:
Maybe let's test that things actually fail if we don't call set_vmap_randomness first?

Contributor Author:
If we don't call loss_module.set_vmap_randomness(vmap_randomness) and the module uses randomness, vmap_randomness defaults to "different". So the call is only needed when the user wants a specific vmap_randomness. I don't think there is any case in which we should expect an error, unless the user manually sets vmap_randomness to "error" and, for example, uses dropout.

I can add a test for that, but I'm not sure if that's what you meant.
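To illustrate the default being described, here is a minimal sketch reusing the hypothetical VmapTestLoss (with dropout > 0) from the test above; the assertion mirrors the author's description, it is not a test from the PR:

```python
# Sketch: no set_vmap_randomness() call is made. The vmap_randomness
# property scans the loss for random modules (e.g. nn.Dropout) and
# falls back to "different" when it finds one.
loss_module = VmapTestLoss()
assert loss_module.vmap_randomness == "different"
# With no random module present, it would instead resolve to "error".
```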

Contributor:
> only if the user sets vmap_randomness manually to "error" and uses dropout for example

Yes that is what I meant. Here we only test that the code runs, but we're not really checking that it would have been broken had we done things differently.

Contributor Author:
added it!
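For reference, a negative test along the lines discussed could look like this (a hypothetical sketch; the test actually added in the PR may differ):

```python
@pytest.mark.parametrize("device", get_default_devices())
def test_loss_vmap_random_error(device):
    # VmapTestLoss (with dropout > 0) as defined above; force "error" mode
    # and rebuild the vmapped call so the new mode actually takes effect
    # (the ctor baked the auto-detected mode into self.vmap_model).
    loss_module = VmapTestLoss()
    loss_module.set_vmap_randomness("error")
    loss_module.vmap_model = _vmap_func(
        loss_module.model, (None, 0), randomness=loss_module.vmap_randomness
    )
    td = TensorDict({"obs": torch.randn(3, 4).to(device)}, [3])
    # Dropout inside vmap under "error" mode should raise at call time.
    with pytest.raises(RuntimeError, match="random operation"):
        loss_module(td)
```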



class TestDQN(LossModuleTestBase):
seed = 0

@@ -1803,12 +1836,17 @@ def _create_mock_actor(
device="cpu",
in_keys=None,
out_keys=None,
+        dropout=0.0,
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
-        module = nn.Linear(obs_dim, action_dim)
+        module = nn.Sequential(
+            nn.Linear(obs_dim, obs_dim),
+            nn.Dropout(dropout),
+            nn.Linear(obs_dim, action_dim),
+        )
actor = Actor(
spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys
)
@@ -1984,6 +2022,7 @@ def _create_seq_mock_data_td3(
@pytest.mark.parametrize("noise_clip", [0.1, 1.0])
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("use_action_spec", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.1])
def test_td3(
self,
delay_actor,
@@ -1993,9 +2032,10 @@ def test_td3(
noise_clip,
td_est,
use_action_spec,
+        dropout,
):
torch.manual_seed(self.seed)
-        actor = self._create_mock_actor(device=device)
+        actor = self._create_mock_actor(device=device, dropout=dropout)
value = self._create_mock_value(device=device)
td = self._create_mock_data_td3(device=device)
if use_action_spec:
@@ -4876,7 +4916,6 @@ def test_cql(
device,
td_est,
):
-
torch.manual_seed(self.seed)
td = self._create_mock_data_cql(device=device)

@@ -6075,7 +6114,7 @@ def zero_param(p):
p.grad = None
loss_objective.backward()
named_parameters = loss_fn.named_parameters()
-        for (name, other_p) in named_parameters:
+        for name, other_p in named_parameters:
p = params.get(tuple(name.split(".")))
assert other_p.shape == p.shape
assert other_p.dtype == p.dtype
@@ -11137,7 +11176,6 @@ def test_set_deprecated_keys(self, adv, kwargs):
)

with pytest.warns(DeprecationWarning):
-
if adv is VTrace:
actor_net = TensorDictModule(
nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
27 changes: 25 additions & 2 deletions torchrl/objectives/common.py
@@ -16,10 +16,10 @@
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
from torch import nn
from torch.nn import Parameter

from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.objectives.utils import ValueEstimators

+from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators
from torchrl.objectives.value import ValueEstimatorBase


@@ -81,6 +81,7 @@ class _AcceptedKeys:

pass

+    _vmap_randomness = None
default_value_estimator: ValueEstimators = None
SEP = "."
TARGET_NET_WARNING = (
@@ -429,6 +430,28 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams

return self

@property
def vmap_randomness(self):
if self._vmap_randomness is None:
do_break = False
for val in self.__dict__.values():
if isinstance(val, torch.nn.Module):
for module in val.modules():
if isinstance(module, RANDOM_MODULE_LIST):
self._vmap_randomness = "different"
do_break = True
break
if do_break:
# double break
break
else:
self._vmap_randomness = "error"

return self._vmap_randomness

def set_vmap_randomness(self, value):
self._vmap_randomness = value

@staticmethod
def _make_meta_params(param):
is_param = isinstance(param, nn.Parameter)
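In spirit, the new property's scan is equivalent to checking each network held by the loss against RANDOM_MODULE_LIST (a sketch, not code from the PR):

```python
import torch.nn as nn
from torchrl.objectives.utils import RANDOM_MODULE_LIST

net_with_dropout = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
net_plain = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# A random submodule anywhere in a network triggers "different":
any(isinstance(m, RANDOM_MODULE_LIST) for m in net_with_dropout.modules())  # True
# No random submodule found anywhere -> the property returns "error":
any(isinstance(m, RANDOM_MODULE_LIST) for m in net_plain.modules())  # False
```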
8 changes: 6 additions & 2 deletions torchrl/objectives/cql.py
@@ -349,8 +349,12 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

-        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
-        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+        self._vmap_qvalue_networkN0 = _vmap_func(
+            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+        )
+        self._vmap_qvalue_network00 = _vmap_func(
+            self.qvalue_network, randomness=self.vmap_randomness
+        )

@property
def target_entropy(self):
4 changes: 3 additions & 1 deletion torchrl/objectives/iql.py
@@ -286,7 +286,9 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qvalue_networkN0 = _vmap_func(
+            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+        )

@property
def device(self) -> torch.device:
9 changes: 6 additions & 3 deletions torchrl/objectives/redq.py
@@ -255,7 +255,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
):
-
super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority_key=priority_key)
@@ -318,8 +317,12 @@ def __init__(
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma

-        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
-        self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")
+        self._vmap_qvalue_network00 = _vmap_func(
+            self.qvalue_network, randomness=self.vmap_randomness
+        )
+        self._vmap_getdist = _vmap_func(
+            self.actor_network, func="get_dist_params", randomness=self.vmap_randomness
+        )

@property
def target_entropy(self):
14 changes: 9 additions & 5 deletions torchrl/objectives/sac.py
@@ -376,9 +376,13 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-        self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qnetworkN0 = _vmap_func(
+            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+        )
        if self._version == 1:
-            self._vmap_qnetwork00 = _vmap_func(qvalue_network)
+            self._vmap_qnetwork00 = _vmap_func(
+                qvalue_network, randomness=self.vmap_randomness
+            )

@property
def target_entropy_buffer(self):
@@ -411,7 +415,6 @@ def target_entropy(self):
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
):
-
action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
else:
action_container_shape = action_spec.shape
@@ -753,7 +756,6 @@ def _value_loss(
return loss_value, {}

def _alpha_loss(self, log_prob: Tensor) -> Tensor:
-
if self.target_entropy is not None:
# we can compute this loss even if log_alpha is not a parameter
alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
@@ -1049,7 +1051,9 @@ def __init__(
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
-        self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qnetworkN0 = _vmap_func(
+            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+        )

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
9 changes: 6 additions & 3 deletions torchrl/objectives/td3.py
@@ -219,7 +219,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
) -> None:
-
super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
@@ -296,8 +295,12 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
-        self._vmap_actor_network00 = _vmap_func(self.actor_network)
+        self._vmap_qvalue_network00 = _vmap_func(
+            self.qvalue_network, randomness=self.vmap_randomness
+        )
+        self._vmap_actor_network00 = _vmap_func(
+            self.actor_network, randomness=self.vmap_randomness
+        )

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
37 changes: 28 additions & 9 deletions torchrl/objectives/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import functools
+import re
import warnings
from enum import Enum
from typing import Iterable, Optional, Union
@@ -29,6 +30,14 @@
"run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`."
)

RANDOM_MODULE_LIST = (
nn.Dropout,
nn.Dropout2d,
nn.Dropout3d,
Contributor:
Do these guys have a parent, common class?

Contributor Author:
Indeed, all the Dropouts have a common parent, _DropoutNd. I'll update it!

nn.AlphaDropout,
nn.FeatureAlphaDropout,
)
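As the thread above points out, every entry shares the private base class `_DropoutNd` from `torch.nn.modules.dropout`, so the tuple can be collapsed to it (a quick check, not code from the PR):

```python
import torch.nn as nn
from torch.nn.modules.dropout import _DropoutNd  # private base of all dropout layers

assert all(
    issubclass(cls, _DropoutNd)
    for cls in (
        nn.Dropout,
        nn.Dropout2d,
        nn.Dropout3d,
        nn.AlphaDropout,
        nn.FeatureAlphaDropout,
    )
)
# So the list could simply read: RANDOM_MODULE_LIST = (_DropoutNd,)
```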


class ValueEstimators(Enum):
"""Value function enumerator for custom-built estimators.
@@ -478,13 +487,23 @@ def new_fun(self, netname=None):


def _vmap_func(module, *args, func=None, **kwargs):
-    def decorated_module(*module_args_params):
-        params = module_args_params[-1]
-        module_args = module_args_params[:-1]
-        with params.to_module(module):
-            if func is None:
-                return module(*module_args)
-            else:
-                return getattr(module, func)(*module_args)
+    try:

-    return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+        def decorated_module(*module_args_params):
+            params = module_args_params[-1]
+            module_args = module_args_params[:-1]
+            with params.to_module(module):
+                if func is None:
+                    return module(*module_args)
+                else:
+                    return getattr(module, func)(*module_args)
+
+        return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+
+    except RuntimeError as err:
+        if re.match(
+            r"vmap: called random operation while in randomness error mode", str(err)
+        ):
+            raise RuntimeError(
+                "Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
+            ) from err
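For context, the three `torch.func.vmap` randomness modes this error message refers to behave as follows (a standalone sketch, independent of torchrl):

```python
import torch
from torch.func import vmap

drop = torch.nn.Dropout(p=0.5)  # a random operation while in training mode
x = torch.ones(4, 3)

same = vmap(drop, randomness="same")(x)       # one mask, shared by all 4 rows
diff = vmap(drop, randomness="different")(x)  # an independent mask per row

try:
    vmap(drop, randomness="error")(x)         # the default: random ops are forbidden
except RuntimeError as err:
    # "vmap: called random operation while in randomness error mode ..."
    print(err)
```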