Skip to content

Commit

Permalink
init
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 27, 2024
1 parent f439b54 commit b96e151
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
40 changes: 39 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch

from _utils_internal import get_default_devices, make_tc

from mocking_classes import CountingEnv
from packaging import version
from packaging.version import parse
from tensordict import (
Expand All @@ -30,7 +32,6 @@
)
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map

from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
Expand Down Expand Up @@ -2792,6 +2793,43 @@ def test_rb_multidim_collector(
print(f"rb {rb}") # noqa: T201
raise

@pytest.mark.parametrize("strict_length", [True, False])
def test_done_slicesampler(self, strict_length):
    """Check the done flags of slices sampled from a 2-dim storage.

    Three ``CountingEnv`` instances (``max_steps`` of 31, 32 and 33) are
    rolled out in a ``SerialEnv`` and written to a replay buffer backed by
    ``LazyTensorStorage(200, ndim=2)`` with a ``SliceSampler``. After each
    extension a batch is sampled and two properties are asserted:

    * the total number of ``("next", "done")`` flags in the batch equals
      ``batch_size // slice_len``;
    * after ``split_trajectories``, each trajectory sums to exactly one
      ``("next", "done")`` flag along dim ``-2``.

    Args:
        strict_length: forwarded to ``SliceSampler(strict_length=...)``.
    """
    batch_size = 128
    slice_len = 32

    env_makers = [
        lambda: CountingEnv(max_steps=31),
        lambda: CountingEnv(max_steps=32),
        lambda: CountingEnv(max_steps=33),
    ]
    env = SerialEnv(3, env_makers)
    full_action_spec = CountingEnv(max_steps=32).full_action_spec

    def policy(td):
        # Write a constant action into td: the zero action for the 3
        # sub-envs with 1 added to every entry.
        action = full_action_spec.zero((3,)).apply_(lambda x: x + 1)
        return td.update(action)

    storage = LazyTensorStorage(200, ndim=2)
    sampler = SliceSampler(
        slice_len=slice_len,
        strict_length=strict_length,
        truncated_key=("next", "truncated"),
    )
    rb = TensorDictReplayBuffer(
        storage=storage,
        sampler=sampler,
        batch_size=batch_size,
    )

    for i in range(50):
        rollout = env.rollout(50, policy=policy, break_when_any_done=False)
        # Mark the last step of every rollout as done before storing it.
        rollout["next", "done"][:, -1] = 1
        rb.extend(rollout)

        sample = rb.sample()

        done = sample["next", "done"]
        # The iteration index is included in the failure message to help
        # locate the first extension at which the property breaks.
        assert done.sum() == batch_size // slice_len, (i, done.sum())
        per_traj_done = split_trajectories(sample)["next", "done"].sum(-2)
        assert (per_traj_done == 1).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
8 changes: 4 additions & 4 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,9 @@ def _get_index(
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
terminated.view(num_slices, -1)[traj_terminated, -1] = 1
else:
truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated
return index.to(torch.long).unbind(-1), {
Expand Down Expand Up @@ -1726,9 +1726,9 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
terminated = torch.zeros_like(truncated)
if traj_terminated.any():
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[traj_terminated] = 1
terminated.view(num_slices, -1)[:, traj_terminated] = 1
else:
truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
terminated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1
truncated = truncated & ~terminated
done = terminated | truncated

Expand Down

0 comments on commit b96e151

Please sign in to comment.