
[BugFix, Feature] Vmap randomness in losses #1740

Merged · 21 commits · Jan 9, 2024
Changes from 4 commits
8 changes: 6 additions & 2 deletions torchrl/objectives/cql.py
@@ -335,8 +335,12 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

-self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
-self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+self._vmap_qvalue_networkN0 = _vmap_func(
+    self.qvalue_network, (None, 0), randomness="different"
+)
+self._vmap_qvalue_network00 = _vmap_func(
+    self.qvalue_network, randomness="different"
+)

@property
def target_entropy(self):
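The change above switches `vmap` away from its default `randomness="error"` mode, which refuses to execute any random operation inside the vmapped function, to `randomness="different"`, which gives each vmapped call its own random draw. A minimal standalone sketch of the difference (illustration only, not part of this diff):

```python
import torch
from torch.func import vmap

dropout = torch.nn.Dropout(p=0.5)  # random in training mode
x = torch.ones(2, 4)

# Default mode: vmap raises on the first random op it encounters.
try:
    vmap(dropout)(x)
except RuntimeError as err:
    print(err)  # vmap: called random operation while in randomness error mode

# "different": each of the two vmapped calls samples its own dropout mask,
# which is the right semantics for an ensemble of independent Q-networks.
masked = vmap(dropout, randomness="different")(x)
```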
4 changes: 3 additions & 1 deletion torchrl/objectives/iql.py
@@ -286,7 +286,9 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+self._vmap_qvalue_networkN0 = _vmap_func(
+    self.qvalue_network, (None, 0), randomness="different"
+)

@property
def device(self) -> torch.device:
4 changes: 3 additions & 1 deletion torchrl/objectives/redq.py
@@ -318,7 +318,9 @@ def __init__(
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma

-self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+self._vmap_qvalue_network00 = _vmap_func(
+    self.qvalue_network, randomness="different"
+)
self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")

@property
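Besides the `randomness` fix, note the `func="get_dist_params"` call above: `_vmap_func` can vmap a named method of a module instead of its `forward`. A hedged standalone sketch of that idea, where `TinyPolicy` and all shapes are made up for illustration:

```python
import torch
from torch.func import vmap

class TinyPolicy(torch.nn.Module):
    """A toy policy exposing a method other than forward."""

    def __init__(self):
        super().__init__()
        self.loc = torch.nn.Linear(4, 2)
        self.scale = torch.nn.Linear(4, 2)

    def get_dist_params(self, obs):
        # Location and positive scale of a Gaussian policy head.
        return self.loc(obs), self.scale(obs).exp()

policy = TinyPolicy()
obs = torch.randn(3, 5, 4)  # 3 vmapped samples of a [5, 4] observation batch

# Equivalent of func="get_dist_params": vmap the method, not __call__.
locs, scales = vmap(policy.get_dist_params)(obs)
print(locs.shape, scales.shape)  # torch.Size([3, 5, 2]) twice
```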
6 changes: 4 additions & 2 deletions torchrl/objectives/sac.py
@@ -376,9 +376,11 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
+self._vmap_qnetworkN0 = _vmap_func(
+    self.qvalue_network, (None, 0), randomness="different"
+)
if self._version == 1:
-    self._vmap_qnetwork00 = _vmap_func(qvalue_network)
+    self._vmap_qnetwork00 = _vmap_func(qvalue_network, randomness="different")

@property
def target_entropy_buffer(self):
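As an aside on naming: the `N0`/`00` suffixes encode the `in_dims` handed to vmap. `N0` broadcasts the input (`None`) while mapping over a stack of per-critic parameters (dim `0`); `00` maps over both. A sketch of the `N0` pattern in plain `torch.func`, where the stacked-parameter dict below is an assumption standing in for how `LossModule` actually stores its critic ensemble:

```python
import torch
from torch.func import functional_call, vmap

critic = torch.nn.Linear(4, 1)

# Two independent sets of critic weights, stacked along dim 0.
stacked = {
    name: torch.stack([p.detach(), torch.randn_like(p)])
    for name, p in critic.named_parameters()
}

def run(obs, params):
    return functional_call(critic, params, (obs,))

obs = torch.randn(8, 4)
# in_dims=(None, 0): broadcast obs, map over the parameter stack ("N0").
qvals = vmap(run, in_dims=(None, 0))(obs, stacked)
print(qvals.shape)  # torch.Size([2, 8, 1]) -- one slice per critic
```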
51 changes: 48 additions & 3 deletions torchrl/objectives/td3.py
@@ -26,6 +26,25 @@
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

+# Modules that need vmap's randomness="different" mode: the dropout variants
+# draw random masks in training mode; the normalization layers are listed
+# conservatively.
+RANDOM_MODULE_LIST = (
+    torch.nn.Dropout,
+    torch.nn.Dropout2d,
+    torch.nn.Dropout3d,
+    torch.nn.AlphaDropout,
+    torch.nn.FeatureAlphaDropout,
+    torch.nn.SyncBatchNorm,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+    torch.nn.LocalResponseNorm,
+    torch.nn.InstanceNorm1d,
+    torch.nn.InstanceNorm2d,
+    torch.nn.InstanceNorm3d,
+)


class TD3Loss(LossModule):
"""TD3 Loss module.
@@ -201,6 +220,7 @@ class _AcceptedKeys:
"next_state_value",
"target_value",
]
+_vmap_randomness = None

def __init__(
self,
@@ -219,7 +239,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
) -> None:

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
@@ -296,8 +315,12 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
-self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
-self._vmap_actor_network00 = _vmap_func(self.actor_network)
+self._vmap_qvalue_network00 = _vmap_func(
+    self.qvalue_network, randomness=self.vmap_randomness
+)
+self._vmap_actor_network00 = _vmap_func(
+    self.actor_network, randomness=self.vmap_randomness
+)

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
@@ -343,6 +366,28 @@ def _cached_stack_actor_params(self):
[self.actor_network_params, self.target_actor_network_params], 0
)

+@property
+def vmap_randomness(self):
+    # Lazily inferred once, then cached: "different" if any sub-module is a
+    # known random layer, else vmap's strict default, "error".
+    if self._vmap_randomness is None:
+        do_break = False
+        for val in self.__dict__.values():
+            if isinstance(val, torch.nn.Module):
+                for module in val.modules():
+                    if isinstance(module, RANDOM_MODULE_LIST):
+                        self._vmap_randomness = "different"
+                        do_break = True
+                        break
+            if do_break:
+                # double break out of the nested loop
+                break
+        else:
+            self._vmap_randomness = "error"
+    return self._vmap_randomness
+
+def set_vmap_randomness(self, value):
+    self._vmap_randomness = value
def actor_loss(self, tensordict):
tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys)
with self.actor_network_params.to_module(self.actor_network):
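Unlike the other losses in this diff, TD3 infers the vmap mode: the `vmap_randomness` property scans the loss module's networks once, caches the answer, and only opts into `"different"` when a layer from `RANDOM_MODULE_LIST` is present, so deterministic networks keep vmap's strict error mode. A condensed standalone version of that rule (an illustration, not the class method itself):

```python
import torch

def infer_vmap_randomness(net: torch.nn.Module) -> str:
    # Abbreviated stand-in for RANDOM_MODULE_LIST.
    random_types = (torch.nn.Dropout, torch.nn.AlphaDropout)
    for module in net.modules():
        if isinstance(module, random_types):
            return "different"
    return "error"

plain = torch.nn.Linear(3, 3)
noisy = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Dropout(0.1))
print(infer_vmap_randomness(plain))  # error
print(infer_vmap_randomness(noisy))  # different
```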
28 changes: 19 additions & 9 deletions torchrl/objectives/utils.py
@@ -478,13 +478,23 @@ def new_fun(self, netname=None):


def _vmap_func(module, *args, func=None, **kwargs):
-def decorated_module(*module_args_params):
-    params = module_args_params[-1]
-    module_args = module_args_params[:-1]
-    with params.to_module(module):
-        if func is None:
-            return module(*module_args)
-        else:
-            return getattr(module, func)(*module_args)
-
-return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+try:
+
+    def decorated_module(*module_args_params):
+        params = module_args_params[-1]
+        module_args = module_args_params[:-1]
+        with params.to_module(module):
+            if func is None:
+                return module(*module_args)
+            else:
+                return getattr(module, func)(*module_args)
+
+    return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+
+except RuntimeError as err:
+    if "vmap: called random operation while in randomness error mode" in str(err):
+        raise RuntimeError(
+            "Please use loss_module.set_vmap_randomness('different') to "
+            "handle random operations during vmap."
+        ) from err
+    raise
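End to end, the pieces combine roughly as follows. This is a hedged usage sketch, assuming a `tensordict` version where `TensorDict.from_module` and the `to_module` context manager are available; the network, keys, and shapes are made up:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.utils import _vmap_func

qnet = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Dropout(0.1)),
    in_keys=["observation"],
    out_keys=["state_action_value"],
)

# Two critics' worth of parameters, stacked along dim 0.
params = TensorDict.from_module(qnet)
stacked = torch.stack([params, params.clone()])

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])

# (None, 0): broadcast the input tensordict, map over the parameter stack;
# randomness="different" lets each critic's Dropout draw its own mask.
vmapped = _vmap_func(qnet, (None, 0), randomness="different")
out = vmapped(td, stacked)
print(out["state_action_value"].shape)  # torch.Size([2, 8, 1])
```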