[Refactor] Deprecate recurrent_mode API to use decorators/CMs instead
ghstack-source-id: 0256235faa306edbba6e10544ef043df9b5cc1c8
Pull Request resolved: #2584
vmoens committed Nov 19, 2024
1 parent c0ba3ff commit 14924d7
Showing 23 changed files with 265 additions and 69 deletions.
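In short: configuring a recurrent module through the imperative `LSTMModule(...).set_recurrent_mode(True)` call is deprecated in favour of a `default_recurrent_mode` constructor argument plus a `set_recurrent_mode` context manager/decorator and a `recurrent_mode` getter, all exercised in the tests below. A minimal usage sketch (the module hyperparameters are illustrative, copied from the new tests):

from torchrl.modules import LSTMModule, set_recurrent_mode

# Deprecated pattern: LSTMModule(...).set_recurrent_mode(True)
# New pattern: pick a default at construction time ...
lstm = LSTMModule(
    input_size=3,
    hidden_size=12,
    batch_first=True,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
    default_recurrent_mode=True,
)
assert lstm.recurrent_mode

# ... and override it locally with the context manager.
with set_recurrent_mode(False):
    assert not lstm.recurrent_mode
assert lstm.recurrent_mode  # the constructor default is restored on exit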
2 changes: 2 additions & 0 deletions docs/source/reference/modules.rst
@@ -373,6 +373,8 @@ algorithms, such as DQN, DDPG or Dreamer.
OnlineDTActor
RSSMPosterior
RSSMPrior
set_recurrent_mode
recurrent_mode

Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6 changes: 3 additions & 3 deletions sota-implementations/dreamer/dreamer.py
@@ -186,7 +186,7 @@ def compile_rssms(module):
t_loss_model_init = time.time()
# update world model
with torch.autocast(
device_type=device.type,
dtype=torch.bfloat16,
) if use_autocast else contextlib.nullcontext():
model_loss_td, sampled_tensordict = world_model_loss(
@@ -215,7 +215,7 @@ def compile_rssms(module):
# update actor network
t_loss_actor_init = time.time()
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
actor_loss_td, sampled_tensordict = actor_loss(
sampled_tensordict.reshape(-1)
@@ -238,7 +238,7 @@ def compile_rssms(module):
# update value network
t_loss_critic_init = time.time()
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)

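All three hunks above use the same conditional-autocast idiom: enter `torch.autocast` only when mixed precision is requested, and fall back to a no-op context otherwise. A standalone sketch of the pattern, where `use_autocast`, `device` and the loss computation are placeholders rather than the Dreamer script's own objects:

import contextlib

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_autocast = device.type == "cuda"  # illustrative toggle

ctx = (
    torch.autocast(device_type=device.type, dtype=torch.bfloat16)
    if use_autocast
    else contextlib.nullcontext()
)
with ctx:
    # placeholder for the world-model / actor / value loss computation
    loss = (torch.randn(8, device=device) ** 2).mean()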
4 changes: 2 additions & 2 deletions sota-implementations/redq/utils.py
@@ -922,7 +922,7 @@ def make_redq_loss(
"""Builds the REDQ loss module."""
loss_kwargs = {}
loss_kwargs.update({"loss_function": cfg.loss.loss_function})
loss_kwargs.update({"delay_qvalue": cfg.loss.type == "double"})
loss_class = REDQLoss_deprecated
if isinstance(model, ActorValueOperator):
actor_model = model.get_policy_operator()
@@ -953,7 +953,7 @@ def make_target_updater(
cfg: "DictConfig", loss_module: LossModule # noqa: F821
) -> TargetNetUpdater | None:
"""Builds a target network weight update object."""
if cfg.loss.type == "double":
if not cfg.loss.hard_update:
target_net_updater = SoftUpdate(
loss_module, eps=1 - 1 / cfg.loss.value_network_update_interval
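In `make_target_updater`, the soft-update coefficient is derived from the configured interval as `eps = 1 - 1 / value_network_update_interval`; an interval of 1000 therefore gives `eps = 0.999`, i.e. the target network keeps 99.9% of its previous weights at each update. A small sketch of the same selection logic with a hypothetical plain config object standing in for `cfg.loss` (the hard-update branch, cut off above, is omitted):

from dataclasses import dataclass

from torchrl.objectives import SoftUpdate

@dataclass
class LossCfg:  # hypothetical stand-in for cfg.loss
    type: str = "double"
    hard_update: bool = False
    value_network_update_interval: int = 1000

def make_soft_target_updater(loss_cfg: LossCfg, loss_module):
    # Only double-Q setups keep a target network that needs updating.
    if loss_cfg.type != "double" or loss_cfg.hard_update:
        return None
    # interval=1000 -> eps=0.999: keep 99.9% of the old target weights per update
    return SoftUpdate(loss_module, eps=1 - 1 / loss_cfg.value_network_update_interval)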
24 changes: 24 additions & 0 deletions test/test_cost.py
@@ -47,6 +47,7 @@
DistributionalQValueActor,
OneHotCategorical,
QValueActor,
recurrent_mode,
SafeSequential,
WorldModelWrapper,
)
@@ -15507,6 +15508,29 @@ def test_set_deprecated_keys(self, adv, kwargs):


class TestBase:
def test_decorators(self):
class MyLoss(LossModule):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

def actor_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

def something_loss(self, tensordict: TensorDictBase) -> TensorDictBase:
assert recurrent_mode()
assert exploration_type() is ExplorationType.DETERMINISTIC
return TensorDict()

loss = MyLoss()
loss.forward(None)
loss.actor_loss(None)
loss.something_loss(None)
assert not recurrent_mode()

@pytest.mark.parametrize("expand_dim", [None, 2])
@pytest.mark.parametrize("compare_against", [True, False])
@pytest.mark.skipif(not _has_functorch, reason="functorch is needed for expansion")
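The new test above checks that `LossModule` runs `forward` and the `*_loss` methods with the recurrent-mode flag set and deterministic exploration, and that the flag is cleared again once the call returns. The same flag can be driven explicitly by user code; the decorator form below assumes `set_recurrent_mode` follows the same decorator-capable convention as `set_exploration_type`, which is not shown on this page:

from torchrl.modules import recurrent_mode, set_recurrent_mode

# Context-manager form, as used in the tests:
with set_recurrent_mode(True):
    assert recurrent_mode()
assert not recurrent_mode()

# Decorator form (assumed, see note above):
@set_recurrent_mode(True)
def rollout_loss_step():
    assert recurrent_mode()

rollout_loss_step()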
2 changes: 1 addition & 1 deletion test/test_env.py
@@ -590,7 +590,7 @@ def test_parallel_devices(
)
if parallel:
assert (
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

@pytest.mark.parametrize("start_method", [None, mp_ctx])
4 changes: 2 additions & 2 deletions test/test_rb.py
@@ -767,9 +767,9 @@ class TC:
storage = storage_type(max_size=10, device=device_storage)
storage.set(0, data)
if device_storage != "auto":
assert storage.get(0).device.type == device_storage.type
else:
assert storage.get(0).device.type == storage.device.type

@pytest.mark.parametrize("storage_in", ["tensor", "memmap"])
@pytest.mark.parametrize("storage_out", ["tensor", "memmap"])
54 changes: 53 additions & 1 deletion test/test_tensordictmodules.py
@@ -36,6 +36,7 @@
OnlineDTActor,
ProbabilisticActor,
SafeModule,
set_recurrent_mode,
TanhDelta,
TanhNormal,
ValueOperator,
@@ -729,6 +730,31 @@ def test_errs(self):
with pytest.raises(KeyError, match="is_init"):
lstm_module(td)

@pytest.mark.parametrize("default_val", [False, True, None])
def test_set_recurrent_mode(self, default_val):
lstm_module = LSTMModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
default_recurrent_mode=default_val,
)
assert lstm_module.recurrent_mode is bool(default_val)
with set_recurrent_mode(True):
assert lstm_module.recurrent_mode
with set_recurrent_mode(False):
assert not lstm_module.recurrent_mode
with set_recurrent_mode("recurrent"):
assert lstm_module.recurrent_mode
with set_recurrent_mode("sequential"):
assert not lstm_module.recurrent_mode
assert lstm_module.recurrent_mode
assert not lstm_module.recurrent_mode
assert lstm_module.recurrent_mode
assert lstm_module.recurrent_mode is bool(default_val)

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_set_temporal_mode(self):
lstm_module = LSTMModule(
input_size=3,
Expand All @@ -754,7 +780,8 @@ def test_python_cudnn(self):
num_layers=2,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
).set_recurrent_mode(True)
default_recurrent_mode=True,
)
obs = torch.rand(10, 20, 3)

hidden0 = torch.rand(10, 20, 2, 12)
@@ -1109,6 +1136,31 @@ def test_errs(self):
with pytest.raises(KeyError, match="is_init"):
gru_module(td)

@pytest.mark.parametrize("default_val", [False, True, None])
def test_set_recurrent_mode(self, default_val):
gru_module = GRUModule(
input_size=3,
hidden_size=12,
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
default_recurrent_mode=default_val,
)
assert gru_module.recurrent_mode is bool(default_val)
with set_recurrent_mode(True):
assert gru_module.recurrent_mode
with set_recurrent_mode(False):
assert not gru_module.recurrent_mode
with set_recurrent_mode("recurrent"):
assert gru_module.recurrent_mode
with set_recurrent_mode("sequential"):
assert not gru_module.recurrent_mode
assert gru_module.recurrent_mode
assert not gru_module.recurrent_mode
assert gru_module.recurrent_mode
assert gru_module.recurrent_mode is bool(default_val)

@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_set_temporal_mode(self):
gru_module = GRUModule(
input_size=3,
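Indentation is stripped in the listing above: the `set_recurrent_mode` blocks in both `test_set_recurrent_mode` tests are nested, each inner context overriding the outer one and each exit restoring the previous value, which is what the trailing asserts walk back through. Re-indented, the core of the LSTM variant reads roughly as follows (names refer to the test fixture; the GRU test has the same shape):

with set_recurrent_mode(True):
    assert lstm_module.recurrent_mode
    with set_recurrent_mode(False):
        assert not lstm_module.recurrent_mode
        with set_recurrent_mode("recurrent"):  # string alias behaving like True
            assert lstm_module.recurrent_mode
            with set_recurrent_mode("sequential"):  # string alias behaving like False
                assert not lstm_module.recurrent_mode
            assert lstm_module.recurrent_mode  # back to "recurrent"
        assert not lstm_module.recurrent_mode  # back to False
    assert lstm_module.recurrent_mode  # back to True
assert lstm_module.recurrent_mode is bool(default_val)  # back to the constructor default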
8 changes: 5 additions & 3 deletions test/test_transforms.py
@@ -9009,7 +9009,7 @@ def test_pin_mem(self, device):
td = TensorDict(
{key: torch.randn(3) for key in ["a", "b", "c"]}, [], device=device
)
if device.type == "cuda":
with pytest.raises(RuntimeError, match="cannot pin"):
pin_mem(td)
with pytest.raises(RuntimeError, match="cannot pin"):
@@ -10885,7 +10885,8 @@ def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"):
in_keys=["observation", "rhs", "is_init"],
out_keys=["output", ("next", "rhs")],
device=device,
).set_recurrent_mode(True)
default_recurrent_mode=True,
)

def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
return LSTMModule(
@@ -10895,7 +10896,8 @@ def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
in_keys=["observation", "rhs_h", "rhs_c", "is_init"],
out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")],
device=device,
).set_recurrent_mode(True)
default_recurrent_mode=True,
)

def _make_batch(self, batch_size: int = 2, sequence_length: int = 5):
observation = torch.randn(batch_size, sequence_length + 1, 4)
22 changes: 22 additions & 0 deletions torchrl/_utils.py
@@ -15,9 +15,11 @@
import os
import pickle
import sys
import threading
import time
import traceback
import warnings
from contextlib import nullcontext
from copy import copy
from distutils.util import strtobool
from functools import wraps
@@ -32,6 +34,10 @@
from tensordict.utils import NestedKey
from torch import multiprocessing as mp

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling
LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
logger = logging.getLogger("torchrl")
logger.setLevel(getattr(logging, LOGGING_LEVEL))
@@ -827,3 +833,19 @@ def _make_ordinal_device(device: torch.device):
if device.type == "mps" and device.index is None:
return torch.device("mps", index=0)
return device


class _ContextManager:
def __init__(self):
self._mode: Any | None = None
self._lock = threading.Lock()

def get_mode(self) -> Any | None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
return self._mode

def set_mode(self, type: Any | None) -> None:
cm = self._lock if not is_compiling() else nullcontext()
with cm:
self._mode = type
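The `_ContextManager` added above is simply a lock-protected holder for the currently active mode, skipping the lock when running under `torch.compile`. A plausible sketch of how a user-facing helper such as `set_recurrent_mode` could be layered on top of it; the class and names below are illustrative, not the implementation added elsewhere in this PR:

from torchrl._utils import _ContextManager

_recurrent_mode_holder = _ContextManager()  # module-level holder, illustrative name

def get_recurrent_mode() -> bool:
    # The stored mode defaults to None, which reads as "off".
    return bool(_recurrent_mode_holder.get_mode())

class set_recurrent_mode_sketch:
    """Sketch of a context manager flipping the global flag and restoring it on exit."""

    def __init__(self, mode: bool = True) -> None:
        self.mode = mode

    def __enter__(self) -> None:
        self._previous = _recurrent_mode_holder.get_mode()
        _recurrent_mode_holder.set_mode(self.mode)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        _recurrent_mode_holder.set_mode(self._previous)

with set_recurrent_mode_sketch(True):
    assert get_recurrent_mode()
assert not get_recurrent_mode()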
22 changes: 11 additions & 11 deletions torchrl/collectors/collectors.py
@@ -221,7 +221,7 @@ def map_weight(
weight = weight.data
if weight.device != policy_device:
weight = weight.to(policy_device)
elif weight.device.type in ("cpu",):
weight = weight.share_memory_()
if is_param:
weight = Parameter(weight, requires_grad=False)
@@ -582,41 +582,41 @@ def __init__(
)

self.storing_device = storing_device
if self.storing_device is not None and self.storing_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_storage = torch.cuda.synchronize
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
# Will break for older PT versions which don't have torch.mps
self._sync_storage = torch.mps.synchronize
elif self.storing_device.type == "cpu":
self._sync_storage = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_storage = _do_nothing

self.env_device = env_device
if self.env_device is not None and self.env_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_env = torch.cuda.synchronize
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
self._sync_env = torch.mps.synchronize
elif self.env_device.type == "cpu":
self._sync_env = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_env = _do_nothing
self.policy_device = policy_device
if self.policy_device is not None and self.policy_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_policy = torch.cuda.synchronize
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
self._sync_policy = torch.mps.synchronize
elif self.policy_device.type == "cpu":
self._sync_policy = _do_nothing
else:
raise RuntimeError("Non supported device")
@@ -1006,7 +1006,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
Yields: TensorDictBase objects containing (chunks of) trajectories
"""
if self.storing_device and self.storing_device.type == "cuda":
stream = torch.cuda.Stream(self.storing_device, priority=-1)
event = stream.record_event()
streams = [stream]
@@ -1025,7 +1025,7 @@ def cuda_check(tensor: torch.Tensor):
# This may be a bit dangerous as `torch.device("cuda")` may not have a precise
# device associated, whereas `tensor.device` always has
for spec in self.env.specs.values(True, True):
if spec.device.type == "cuda":
if ":" not in str(spec.device):
raise RuntimeError(
"A cuda spec did not have a device associated. Make sure to "
@@ -3038,9 +3038,9 @@ def _main_async_collector(
else:
# make sure each cpu tensor is shared - assuming non-cpu devices are shared
def cast_tensor(x, MPS_ERROR=MPS_ERROR):
if x.device.type in ("cpu",):
x.share_memory_()
if x.device.type in ("mps",):
RuntimeError(MPS_ERROR)

collected_tensordict.apply(cast_tensor, filter_empty=True)
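The collector hunks above repeat one selection three times: for a non-CUDA target device, pick a synchronization callable once at construction time (CUDA ordering is handled by the collector's streams, MPS needs `torch.mps.synchronize`, plain CPU needs nothing). A condensed sketch of that selection; `_do_nothing` is re-declared locally because the collector's helper is private:

import torch

def _do_nothing() -> None:
    pass

def pick_sync_fn(device: torch.device | None):
    # Pick how to wait for pending transfers onto `device` (sketch).
    if device is None or device.type == "cuda":
        return _do_nothing  # CUDA is synchronized via streams/events elsewhere
    if torch.cuda.is_available():
        return torch.cuda.synchronize
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        return torch.mps.synchronize
    if device.type == "cpu":
        return _do_nothing
    raise RuntimeError("Non supported device")

sync = pick_sync_fn(torch.device("cpu"))
sync()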
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -615,7 +615,7 @@ def __getstate__(self):
# If it's memmaped no worry in this case either.
# Only if the device is not "cpu" or "cuda" we may have a problem.
def assert_is_sharable(tensor):
if tensor.device is None or tensor.device.type in (
"cuda",
"cpu",
"meta",