Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 27, 2024
1 parent b96e151 commit c15f41a
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,8 +2026,11 @@ def test_slice_sampler(
assert too_short

assert len(trajs_unique_id) == 4
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()
terminated = info[("next", "terminated")]
assert (truncated | terminated).view(num_slices, -1)[:, -1].all()

@pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement])
def test_slice_sampler_at_capacity(self, sampler):
Expand Down Expand Up @@ -2167,8 +2170,10 @@ def test_slice_sampler_without_replacement(
trajs_unique_id = trajs_unique_id.union(
cur_episodes,
)
truncated = info[("next", "truncated")]
assert truncated.view(num_slices, -1)[:, -1].all()
done = info[("next", "done")]
assert done.view(num_slices, -1)[:, -1].all()
done_recon = info[("next", "truncated")] | info[("next", "terminated")]
assert done_recon.view(num_slices, -1)[:, -1].all()

def test_slicesampler_strictlength(self):

Expand Down

0 comments on commit c15f41a

Please sign in to comment.