Skip to content

Commit

Permalink
[BugFix] Robust sync for non_blocking=True (#2034)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Mar 26, 2024
1 parent e57d0bc commit 2b95b41
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 94 deletions.
54 changes: 42 additions & 12 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
Expand Down Expand Up @@ -472,8 +472,45 @@ def __init__(
)

self.storing_device = storing_device
if self.storing_device is not None and self.storing_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_storage = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_storage = torch.mps.synchronize
elif self.storing_device.type == "cpu":
self._sync_storage = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_storage = _do_nothing

self.env_device = env_device
if self.env_device is not None and self.env_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_env = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_env = torch.mps.synchronize
elif self.env_device.type == "cpu":
self._sync_env = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_env = _do_nothing
self.policy_device = policy_device
if self.policy_device is not None and self.policy_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_policy = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_policy = torch.mps.synchronize
elif self.policy_device.type == "cpu":
self._sync_policy = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_policy = _do_nothing
self.device = device
# Check if we need to cast things from device to device
# If the policy has a None device and the env too, no need to cast (we don't know
Expand Down Expand Up @@ -503,7 +540,7 @@ def __init__(
if self.env_device:
self.env: EnvBase = self.env.to(self.env_device)
elif self.env.device is not None:
# we we did not receive an env device, we use the device of the env
# we did not receive an env device, we use the device of the env
self.env_device = self.env.device

# If the storing device is not the same as the policy device, we have
Expand Down Expand Up @@ -915,6 +952,7 @@ def rollout(self) -> TensorDictBase:
policy_input = self._shuttle.to(
self.policy_device, non_blocking=True
)
self._sync_policy()
elif self.policy_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
Expand All @@ -933,6 +971,7 @@ def rollout(self) -> TensorDictBase:
if self._cast_to_env_device:
if self.env_device is not None:
env_input = self._shuttle.to(self.env_device, non_blocking=True)
self._sync_env()
elif self.env_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
Expand All @@ -954,6 +993,7 @@ def rollout(self) -> TensorDictBase:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=True)
)
self._sync_storage()
else:
tensordicts.append(self._shuttle)

Expand Down Expand Up @@ -1000,16 +1040,6 @@ def rollout(self) -> TensorDictBase:
)
return self._final_rollout

@staticmethod
def _update_device_wise(tensor0, tensor1):
# given 2 tensors, returns tensor0 if their identity matches,
# or a copy of tensor1 on the device of tensor0 otherwise
if tensor1 is None or tensor1 is tensor0:
return tensor0
if tensor1.device == tensor0.device:
return tensor1
return tensor1.to(tensor0.device, non_blocking=True)

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
# to be deprecated in v0.4
def map_device(tensor):
if tensor.device != self.device:
return tensor.to(self.device, non_blocking=True)
return tensor.to(self.device, non_blocking=False)
return tensor

if is_tensor_collection(result):
Expand Down
Loading

0 comments on commit 2b95b41

Please sign in to comment.