From 408cf7d04705f18e6a1d58f4b2b7255d67a443d9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:40:58 +0000 Subject: [PATCH] [BugFix] requested_frames_per_batch in distributed collectors ghstack-source-id: 49289de6956460d9aed13d982eb8003eafc35118 Pull Request resolved: https://github.com/pytorch/rl/pull/2579 --- torchrl/collectors/distributed/generic.py | 1 + torchrl/collectors/distributed/rpc.py | 1 + torchrl/collectors/distributed/sync.py | 1 + 3 files changed, 3 insertions(+) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 729b8a48171..5e49ad95f49 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -448,6 +448,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 73247df4b0c..98220727a45 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -304,6 +304,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 481fb70cc31..b90111763d7 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -315,6 +315,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device