[BugFix] Fix strict_length=True in SliceSampler (#2037)
vmoens authored Mar 24, 2024
1 parent cd540bf commit e835770
Showing 2 changed files with 45 additions and 4 deletions.
47 changes: 44 additions & 3 deletions test/test_rb.py
@@ -2050,9 +2050,6 @@ def test_slice_sampler_at_capacity(self, sampler):

        for s in rb:
            if (s["steps"] == 9).any():
-                n = (s["steps"] == 9).nonzero()
-                assert ((s["steps"] == 0).nonzero() == n + 1).all()
-                assert ((s["steps"] == 1).nonzero() == n + 2).all()
                break
        else:
            raise AssertionError
@@ -2172,6 +2169,50 @@ def test_slice_sampler_without_replacement(
        truncated = info[("next", "truncated")]
        assert truncated.view(num_slices, -1)[:, -1].all()

+    def test_slicesampler_strictlength(self):
+
+        torch.manual_seed(0)
+
+        data = TensorDict(
+            {
+                "traj": torch.cat(
+                    [
+                        torch.ones(2, dtype=torch.int),
+                        torch.zeros(10, dtype=torch.int),
+                    ],
+                    dim=0,
+                ),
+                "x": torch.arange(12),
+            },
+            [12],
+        )
+
+        buffer = ReplayBuffer(
+            storage=LazyTensorStorage(12),
+            sampler=SliceSampler(num_slices=2, strict_length=True, traj_key="traj"),
+            batch_size=8,
+        )
+        buffer.extend(data)
+
+        for _ in range(50):
+            sample = buffer.sample()
+            assert sample.shape == torch.Size([8])
+            assert (sample["traj"] == 0).all()
+
+        buffer = ReplayBuffer(
+            storage=LazyTensorStorage(12),
+            sampler=SliceSampler(num_slices=2, strict_length=False, traj_key="traj"),
+            batch_size=8,
+        )
+        buffer.extend(data)
+
+        for _ in range(50):
+            sample = buffer.sample()
+            if sample.shape == torch.Size([6]):
+                assert (sample["traj"] != 0).any()
+            else:
+                assert len(sample["traj"].unique()) == 1
+

def test_prioritized_slice_sampler_doc_example():
    sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -1055,7 +1055,7 @@ def get_traj_idx(maxval):

        if (lengths < seq_length).any():
            if self.strict_length:
-                idx = lengths == seq_length
+                idx = lengths >= seq_length
                if not idx.any():
                    raise RuntimeError(
                        f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
