From a484e4f26c624e52762a575ad658cbfaf5e6a63f Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 17 Dec 2024 21:33:52 -0800 Subject: [PATCH] manager: rename start_step to start_quorum and move step changes to should_commit (#46) --- src/manager.rs | 70 ++++++++++++++++++++++++++++++- torchft/manager.py | 46 ++++++++++++++------- torchft/manager_test.py | 91 ++++++++++++++++++++++++++++------------- torchft/optim.py | 2 +- torchft/optim_test.py | 2 +- 5 files changed, 164 insertions(+), 47 deletions(-) diff --git a/src/manager.rs b/src/manager.rs index 68233cf..275c6d3 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -239,7 +239,8 @@ impl ManagerService for Arc { .await .map_err(|e| Status::internal(e.to_string()))?; - let participants = &quorum.participants; + let mut participants = quorum.participants.clone(); + participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); let mut replica_rank = 10000000000; for (i, p) in participants.iter().enumerate() { @@ -266,7 +267,7 @@ impl ManagerService for Arc { // Decide whether we should be healing: // 1. if we're not at the max step // 2. if everyone is at the first step and we're not the primary - let heal = max_step != req.step || max_step == 1 && primary.replica_id != self.replica_id; + let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id; if heal { info!( "healing is required step={}, max_step={}", @@ -475,4 +476,69 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_get_quorum_heal_first_step() -> Result<()> { + let lighthouse = Lighthouse::new(LighthouseOpt { + bind: "[::]:0".to_string(), + join_timeout_ms: 100, + min_replicas: 2, + quorum_tick_ms: 100, + }) + .await?; + let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); + + let mut manager_futs: Vec>> = + Vec::new(); + + for replica_id in 0..2 { + let lighthouse_addr = lighthouse.address(); + manager_futs.push(tokio::spawn(async move { + let manager = Manager::new( + format!("rep_{}", replica_id), + lighthouse_addr, + "addr".to_string(), + "[::]:0".to_string(), + "store_addr".to_string(), + 1, // world size + ) + .await?; + let manager_fut = tokio::spawn(manager.clone().run()); + + let mut client = + manager_client_new(manager.address(), Duration::from_secs(10)).await?; + + let request = tonic::Request::new(ManagerQuorumRequest { + rank: 0, + step: 0, + checkpoint_server_addr: "addr".to_string(), + }); + + let result = client.quorum(request).await?.into_inner(); + + manager_fut.abort(); + + Ok(result) + })); + } + + let resp_a = manager_futs.swap_remove(0).await??; + let resp_b = manager_futs.swap_remove(0).await??; + + lighthouse_fut.abort(); + + assert_eq!(resp_a.quorum_id, 1); + assert_eq!(resp_a.max_step, 0); + assert_eq!(resp_a.replica_rank, 0); + assert_eq!(resp_a.replica_world_size, 2); + assert_eq!(resp_a.heal, false); + + assert_eq!(resp_b.quorum_id, 1); + assert_eq!(resp_b.max_step, 0); + assert_eq!(resp_b.replica_rank, 1); + assert_eq!(resp_b.replica_world_size, 2); + assert_eq!(resp_b.heal, true); + + Ok(()) + } } diff --git a/torchft/manager.py b/torchft/manager.py index 475849a..a1e5167 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -182,7 +182,6 @@ def _manager_state_dict() -> Dict[str, T]: self._batches_committed = 0 # first step is 1 - self._should_step = True self._participating_rank: Optional[int] = None self._participating_world_size: int = 0 @@ -218,8 +217,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso fut.set_result(grad) return fut - assert 
self._quorum_future is not None, "must call step before allreduce_grad" - self._quorum_future.result() + self.wait_quorum() if not self.is_participating(): grad.zero_() @@ -315,7 +313,7 @@ def callback( self._pending_work.append(cast(torch.futures.Future[object], fut)) return fut - def start_step(self) -> None: + def start_quorum(self, allow_heal: bool = True) -> None: """ .. note:: We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. @@ -323,13 +321,20 @@ def start_step(self) -> None: Computes a new quorum (potentially asynchronously) and readies the manager for a new step. - Must be called before the forwards pass of each step for best + It's best practice to call this before the forwards pass of each step for performance as computing quorum may take some time. + + If allow_heal is set, the manager will attempt to heal either + synchronously before returning or asynchronously prior to any network + calls. + + Args: + allow_heal: whether to allow healing at the beginning of the step """ - if self._should_step: - self._step += 1 - self._batches_committed += self.num_participants() + # wait for previous quorum to complete + if self._quorum_future is not None: + self._quorum_future.result() self._errored = None self._healing = False @@ -338,9 +343,9 @@ def start_step(self) -> None: # TODO: we should really be wrapping this whole section in a try-except # block to allow gracefully recovering from issues in PG setup and quorum. - self._quorum_future = self._executor.submit(self._async_quorum) + self._quorum_future = self._executor.submit(self._async_quorum, allow_heal) if not self._use_async_quorum: - self._quorum_future.result() + self.wait_quorum() if self._healing: # eagerly apply pending state_dict so we can run the forwards pass @@ -350,7 +355,18 @@ def start_step(self) -> None: # and don't need to zero_grad self._healing = False - def _async_quorum(self) -> None: + def wait_quorum(self) -> None: + """ + Wait for the quorum to complete. + + ProcessGroup will be in a healthy state after this returns. + """ + assert ( + self._quorum_future is not None + ), "must call start_quorum before wait_quorum" + self._quorum_future.result() + + def _async_quorum(self, allow_heal: bool) -> None: ( quorum_id, replica_rank, @@ -372,7 +388,7 @@ def _async_quorum(self) -> None: # workers will be healthy. 
self._participating_rank, self._participating_world_size = ( (max_rank, max_world_size) - if self._use_async_quorum + if self._use_async_quorum or not allow_heal else (replica_rank, replica_world_size) ) @@ -397,7 +413,7 @@ def _async_quorum(self) -> None: self._quorum_id = quorum_id # See manager.rs for healing conditions - if heal: + if heal and allow_heal: self._healing = True self._logger.info( f"healing required, fetching checkpoint server address from {address=} {max_step=}" @@ -475,7 +491,9 @@ def should_commit(self) -> bool: self._ckpt_server.disallow_checkpoint() # decide whether we're in a healthy state to increase the step count - self._should_step = should_commit + if should_commit: + self._step += 1 + self._batches_committed += self.num_participants() return should_commit diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 4617be3..ad6d5a5 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -99,21 +99,21 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None: ) self.assertEqual(manager._quorum_id, -1) - self.assertEqual(manager._step, 0) + self.assertEqual(manager.current_step(), 0) self.assertEqual(manager.batches_committed(), 0) - manager.start_step() + manager.start_quorum() manager.allreduce_grad(torch.tensor([1.0])).wait() self.assertEqual(len(manager._pending_work), 1) self.assertTrue(manager.should_commit()) self.assertEqual(len(manager._pending_work), 0) self.assertEqual(manager._quorum_id, 123) - self.assertEqual(manager._step, 1) + self.assertEqual(manager.current_step(), 1) # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) - manager.start_step() + manager.start_quorum() self.assertEqual(manager.batches_committed(), 2) @patch("torchft.manager.ManagerClient", autospec=True) @@ -133,14 +133,14 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: True, # heal ) # forceable increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(1) + manager._ckpt_server.allow_checkpoint(manager.current_step()) client_mock().checkpoint_address.return_value = manager._ckpt_server.address() self.assertEqual(manager._quorum_id, -1) - self.assertEqual(manager._step, 0) + self.assertEqual(manager.current_step(), 0) - manager.start_step() + manager.start_quorum() manager.allreduce_grad(torch.tensor([1.0])).wait() self.assertFalse(manager._healing) self.assertTrue(manager.is_participating()) @@ -148,7 +148,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: self.assertTrue(manager.should_commit()) self.assertEqual(manager._quorum_id, 123) - self.assertEqual(manager._step, 20) + self.assertEqual(manager.current_step(), 21) # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) # pyre-ignore[16]: _pg is mocked @@ -175,14 +175,14 @@ def test_quorum_heal_async_not_enough_participants( True, # heal ) # forceable increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(1) + manager._ckpt_server.allow_checkpoint(manager.current_step()) client_mock().checkpoint_address.return_value = manager._ckpt_server.address() self.assertEqual(manager._quorum_id, -1) - self.assertEqual(manager._step, 0) + self.assertEqual(manager.current_step(), 0) - manager.start_step() + manager.start_quorum() assert manager._quorum_future is not None manager._quorum_future.result() self.assertTrue(manager._healing) @@ -194,10 +194,10 @@ def test_quorum_heal_async_not_enough_participants( 
torch.testing.assert_close(grad, torch.zeros_like(grad))
         # don't commit since num_max < min_replica_size
         self.assertFalse(manager.should_commit())
-        self.assertFalse(manager._should_step)
+        self.assertEqual(manager.current_step(), 20)
 
         self.assertEqual(manager._quorum_id, 123)
-        self.assertEqual(manager._step, 20)
+        self.assertEqual(manager.current_step(), 20)
         # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.call_count, 1)
         # pyre-ignore[16]: _pg is mocked
@@ -206,8 +206,8 @@ def test_quorum_heal_async_not_enough_participants(
         self.assertEqual(self.load_state_dict.call_count, 1)
 
         # failed to commit so no step
-        manager.start_step()
-        self.assertEqual(manager._step, 20)
+        manager.start_quorum()
+        self.assertEqual(manager.current_step(), 20)
         self.assertEqual(manager.batches_committed(), 0)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
@@ -227,14 +227,14 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
             True,  # heal
         )
         # forceable increment checkpoint server to compute correct address
-        manager._ckpt_server.allow_checkpoint(1)
+        manager._ckpt_server.allow_checkpoint(manager.current_step())
         client_mock().checkpoint_address.return_value = manager._ckpt_server.address()
 
         self.assertEqual(manager._quorum_id, -1)
-        self.assertEqual(manager._step, 0)
+        self.assertEqual(manager.current_step(), 0)
 
-        manager.start_step()
+        manager.start_quorum()
         assert manager._quorum_future is not None
         manager._quorum_future.result()
         self.assertTrue(manager._healing)
@@ -245,10 +245,9 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
         # don't commit since num_max < min_replica_size
         self.assertTrue(manager.should_commit())
         self.assertEqual(manager.num_participants(), 1)
-        self.assertTrue(manager._should_step)
+        self.assertEqual(manager.current_step(), 21)
 
         self.assertEqual(manager._quorum_id, 123)
-        self.assertEqual(manager._step, 20)
         # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.call_count, 1)
         # pyre-ignore[16]: _pg is mocked
@@ -256,8 +255,8 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
 
         self.assertEqual(self.load_state_dict.call_count, 1)
 
-        manager.start_step()
-        self.assertEqual(manager._step, 21)
+        manager.start_quorum()
+        self.assertEqual(manager.current_step(), 21)
         self.assertEqual(manager.batches_committed(), 1)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
@@ -278,9 +277,9 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         )
 
         self.assertEqual(manager._quorum_id, -1)
-        self.assertEqual(manager._step, 0)
+        self.assertEqual(manager.current_step(), 0)
 
-        manager.start_step()
+        manager.start_quorum()
         manager.allreduce_grad(torch.tensor([1.0])).wait()
         # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.call_count, 1)
@@ -314,7 +313,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
             2,  # max_world_size
             False,  # heal
         )
-        manager.start_step()
+        manager.start_quorum()
 
         self.assertFalse(manager._errored)
 
@@ -343,7 +342,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
             False,  # heal
         )
 
-        manager.start_step()
+        manager.start_quorum()
         manager.allreduce_grad(torch.tensor([1.0])).wait()
         self.assertTrue(manager.should_commit())
 
@@ -372,17 +371,51 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
        )
 
         self.assertEqual(manager._quorum_id, -1)
-        self.assertEqual(manager._step, 0)
+        self.assertEqual(manager.current_step(), 0)
         self.assertEqual(manager.batches_committed(), 0)
 
-        
manager.start_step() + manager.start_quorum() manager.allreduce_grad(torch.tensor([1.0])).wait() self.assertEqual(manager.is_participating(), rank != 2) self.assertEqual(manager.num_participants(), 2) - manager.start_step() + self.assertTrue(manager.should_commit()) self.assertEqual(manager.batches_committed(), 2) + self.assertEqual(manager.current_step(), 1) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_quorum_no_healing(self, client_mock: MagicMock) -> None: + manager = self._create_manager( + min_replica_size=2, + ) + client_mock().should_commit = lambda rank, step, should_commit: should_commit + + client_mock().quorum.return_value = ( + 123, # quorum_id + 0, # replica_rank + 3, # replica_world + "manager address", + f"localhost:{self.store.port}", + 1, # max_step + None, # max_rank + 2, # max_world_size + True, # heal + ) + + self.assertEqual(manager._quorum_id, -1) + self.assertEqual(manager.current_step(), 0) + self.assertEqual(manager.batches_committed(), 0) + + manager.start_quorum(allow_heal=False) + manager.allreduce_grad(torch.tensor([1.0])).wait() + + self.assertFalse(manager.is_participating()) + self.assertEqual(manager.num_participants(), 2) + + self.assertTrue(manager.should_commit()) + self.assertEqual(manager.batches_committed(), 2) + self.assertEqual(manager.current_step(), 1) @patch("torchft.manager.ManagerClient", autospec=True) def test_manager_report_error(self, client_mock: MagicMock) -> None: diff --git a/torchft/optim.py b/torchft/optim.py index 0e2a3e9..ce24823 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -45,7 +45,7 @@ def state_dict(self) -> object: return self.optim.state_dict() def zero_grad(self, set_to_none: bool = True) -> None: - self.manager.start_step() + self.manager.start_quorum() self.optim.zero_grad(set_to_none) def step(self, closure: Optional[object] = None) -> None: diff --git a/torchft/optim_test.py b/torchft/optim_test.py index cf18219..50412d8 100644 --- a/torchft/optim_test.py +++ b/torchft/optim_test.py @@ -32,7 +32,7 @@ def test_optimizer_wrapper(self) -> None: optim.load_state_dict(optim.state_dict()) optim.zero_grad() - self.assertEqual(manager.start_step.call_count, 1) + self.assertEqual(manager.start_quorum.call_count, 1) manager.should_commit.return_value = True optim.step()
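
Reviewer note: a rough sketch of the intended call pattern after this rename,
for context. This is illustrative only and not part of the patch; model,
criterion and dataloader are hypothetical placeholders, and construction of
the Manager / OptimizerWrapper is elided. A real training script would
typically route gradients through a hook rather than a manual parameter loop.

    # optimizer: torchft.optim.OptimizerWrapper wrapping the local optimizer
    # manager:   the torchft.manager.Manager instance it wraps
    for inputs, labels in dataloader:
        # zero_grad() now calls manager.start_quorum() (formerly start_step()),
        # so the quorum computation can overlap with the forward pass.
        optimizer.zero_grad()

        loss = criterion(model(inputs), labels)
        loss.backward()

        # Gradients are allreduced through the manager; allreduce_grad()
        # waits on the quorum via wait_quorum() if it hasn't resolved yet.
        for p in model.parameters():
            if p.grad is not None:
                manager.allreduce_grad(p.grad).wait()

        # step() consults manager.should_commit(); with this patch the step
        # and batches_committed counters advance there, and only when the
        # commit succeeds, instead of at the start of the next quorum.
        optimizer.step()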