diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 729b8a48171..5e49ad95f49 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -448,6 +448,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 73247df4b0c..98220727a45 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -304,6 +304,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 481fb70cc31..b90111763d7 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -315,6 +315,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device