From 70c650ec8c946f36fd8d57c11612548da2251128 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Oct 2023 12:08:40 -0400 Subject: [PATCH] [Deprecation] Deprecate ambiguous device for memmap replay buffer (#1624) --- torchrl/data/replay_buffers/storages.py | 16 ++++++++++++++-- tutorials/sphinx-tutorials/pretrained_models.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index f2b28f373b4..313163b96f8 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -586,8 +586,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: data.clone() .expand(self.max_size, *data.shape) .memmap_like(prefix=self.scratch_dir) - .to(self.device) ) + if self.device.type != "cpu": + warnings.warn( + "Support for Memmap device other than CPU will be deprecated in v0.4.0.", + category=DeprecationWarning, + ) + out = out.to(self.device).memmap_() + for key, tensor in sorted( out.items(include_nested=True, leaves_only=True), key=str ): @@ -603,8 +609,14 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: data.clone() .expand(self.max_size, *data.shape) .memmap_like(prefix=self.scratch_dir) - .to(self.device) ) + if self.device.type != "cpu": + warnings.warn( + "Support for Memmap device other than CPU will be deprecated in v0.4.0.", + category=DeprecationWarning, + ) + out = out.to(self.device).memmap_() + for key, tensor in sorted( out.items(include_nested=True, leaves_only=True), key=str ): diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py index 9404b7abd43..24c4dee726e 100644 --- a/tutorials/sphinx-tutorials/pretrained_models.py +++ b/tutorials/sphinx-tutorials/pretrained_models.py @@ -88,7 +88,7 @@ # from torchrl.data import LazyMemmapStorage, ReplayBuffer -storage = LazyMemmapStorage(1000, device=device) +storage = LazyMemmapStorage(1000) rb = ReplayBuffer(storage=storage, transform=r3m) ##############################################################################