diff --git a/test/test_collector.py b/test/test_collector.py index 1309254ce2d..38191a46eaa 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3172,6 +3172,29 @@ def make_and_test_policy( ) +@pytest.mark.parametrize( + "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] +) +def test_no_stopiteration(ctype): + # Tests that there is no StopIteration raised and that the length of the collector is properly set + if ctype is SyncDataCollector: + envs = SerialEnv(16, CountingEnv) + else: + envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] + + collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) + try: + c_iter = iter(collector) + assert len(collector) == 2 + for i in range(len(collector)): # noqa: B007 + c = next(c_iter) + assert c is not None + assert i == 1 + finally: + collector.shutdown() + del collector + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index fd4ae2c5db8..16eb5904b84 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -740,8 +740,8 @@ def __init__( f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " "To silence this message, set the environment variable RL_WARNINGS to False." ) - self.requested_frames_per_batch = int(frames_per_batch) self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env self.exploration_type = ( exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE ) @@ -1656,6 +1656,7 @@ def __init__( self._get_weights_fn_dict[policy_device] = get_weights_fn self.policy = policy + remainder = 0 if total_frames is None or total_frames < 0: total_frames = float("inf") else: