[BugFix, Feature] Vmap randomness in losses #1740
Merged
Changes from 18 commits
21 commits
3206ddb add randomness param to vmap for offpolicy algos (BY571)
82ec897 update td3 with vmap_randomness (BY571)
29a5cda undo vmap randomness zip actor and critic modules (BY571)
53ca415 expand random module list (BY571)
50aa8a4 move random_module_list to utils (BY571)
3f54635 update sac losses (BY571)
c4ff926 update iql objective (BY571)
1c44c35 update redq objective (BY571)
bc03b92 update conti cql (BY571)
9aa65ed move vmap_randomness to LossModule (BY571)
3ff3ef8 fix (BY571)
febf277 fix (BY571)
3af72c5 add test example (BY571)
08c83c1 Merge branch 'main' into vmap_dropout (BY571)
f122825 fix (BY571)
d85d63b add vmap randomness test (BY571)
92c3e40 update ranodm_module_list (BY571)
7cdc1e6 add fail case for vmap (BY571)
f737172 Merge branch 'main' into vmap_dropout (BY571)
1bfcc27 update vmap fail case test (BY571)
6ac65e1 Merge remote-tracking branch 'origin/main' into vmap_dropout (vmoens)
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+import re
 import warnings
 from enum import Enum
 from typing import Iterable, Optional, Union
@@ -13,6 +14,7 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn, Tensor
 from torch.nn import functional as F
+from torch.nn.modules import dropout

 try:
     from torch import vmap
@@ -29,6 +31,8 @@
     "run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`."
 )

+RANDOM_MODULE_LIST = (dropout._DropoutNd,)
+

 class ValueEstimators(Enum):
     """Value function enumerator for custom-built estimators.

Review comment on `RANDOM_MODULE_LIST`: Should we keep it a tuple in case we are going to extend it in the future?
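To illustrate why a tuple is convenient here: `isinstance` accepts a tuple of types, so a constant like `RANDOM_MODULE_LIST` can be passed to it directly and extended later without touching the check. The following is only a minimal sketch. The helper name `infer_vmap_randomness` is hypothetical, and the way the loss modules actually consume the list is not visible in this part of the diff. The `"error"` fallback mirrors the default `randomness` mode of `torch.vmap`.

```python
from torch import nn
from torch.nn.modules import dropout

# Same constant as in the diff: a tuple, so more module types can be appended later.
RANDOM_MODULE_LIST = (dropout._DropoutNd,)


def infer_vmap_randomness(module: nn.Module) -> str:
    """Hypothetical helper: pick a vmap randomness mode for ``module``.

    Returns "different" if any submodule is an instance of a type listed in
    RANDOM_MODULE_LIST, and "error" (the torch.vmap default) otherwise.
    """
    for submodule in module.modules():  # walks the module and all its children
        if isinstance(submodule, RANDOM_MODULE_LIST):
            return "different"
    return "error"


with_dropout = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
without_dropout = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

mode_with_dropout = infer_vmap_randomness(with_dropout)
mode_without_dropout = infer_vmap_randomness(without_dropout)
```

Supporting another stochastic layer would then only require adding its class to the tuple.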
@@ -478,13 +482,23 @@ def new_fun(self, netname=None):


 def _vmap_func(module, *args, func=None, **kwargs):
-    def decorated_module(*module_args_params):
-        params = module_args_params[-1]
-        module_args = module_args_params[:-1]
-        with params.to_module(module):
-            if func is None:
-                return module(*module_args)
-            else:
-                return getattr(module, func)(*module_args)
+    try:

-    return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+        def decorated_module(*module_args_params):
+            params = module_args_params[-1]
+            module_args = module_args_params[:-1]
+            with params.to_module(module):
+                if func is None:
+                    return module(*module_args)
+                else:
+                    return getattr(module, func)(*module_args)
+
+        return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
+
+    except RuntimeError as err:
+        if re.match(
+            r"vmap: called random operation while in randomness error mode", str(err)
+        ):
+            raise RuntimeError(
+                "Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
+            ) from err
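The error-translation pattern in this hunk (match the message with `re.match`, then re-raise with `from err`) can be tried in isolation. The sketch below uses only the standard library and makes several assumptions: `translate_vmap_error` and `fake_vmapped_call` are hypothetical names, and the stand-in function merely mimics the message matched by the regex in the diff. Two choices are specific to this sketch and not taken from the PR: the `try` wraps the call itself, because that is where the stand-in raises, and errors that do not match are re-raised unchanged.

```python
import re

# Message prefix matched by the regex in the diff.
_VMAP_RANDOM_ERR = r"vmap: called random operation while in randomness error mode"


def translate_vmap_error(fn):
    """Hypothetical wrapper: re-raise the vmap randomness error with a hint."""

    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RuntimeError as err:
            if re.match(_VMAP_RANDOM_ERR, str(err)):
                # Chain the original error so the full traceback stays available.
                raise RuntimeError(
                    "Please use <loss_module>.set_vmap_randomness('different') "
                    "to handle random operations during vmap."
                ) from err
            raise  # sketch-specific: unrelated RuntimeErrors propagate unchanged

    return wrapped


def fake_vmapped_call(x):
    # Stand-in for a vmapped function that hits a random op in "error" mode.
    raise RuntimeError(
        "vmap: called random operation while in randomness error mode."
    )


def unrelated_failure(x):
    # Stand-in for any other runtime failure.
    raise RuntimeError("shape mismatch")


def message_of(fn, *args):
    """Call fn and return the message of the RuntimeError it raises."""
    try:
        fn(*args)
    except RuntimeError as err:
        return str(err)
    return None
```

Calling `message_of(translate_vmap_error(fake_vmapped_call), 1)` returns the hint about `set_vmap_randomness('different')`, while the unrelated error keeps its original message.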
Let me know what you think! :)